File size: 4,770 Bytes
c7e64f8 883f14c c7e64f8 883f14c c7e64f8 9a5f5dd 883f14c c7e64f8 883f14c 926a499 883f14c 926a499 9a5f5dd 883f14c 9a5f5dd 926a499 883f14c 926a499 883f14c 9a5f5dd 883f14c 9a5f5dd 883f14c 9a5f5dd 883f14c 9a5f5dd 883f14c 9a5f5dd 883f14c 9a5f5dd 883f14c 9a5f5dd 883f14c fc989ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import streamlit as st
import pandas as pd
from io import StringIO
from model_utils import predict_proba
st.set_page_config(page_title="StackOverflow Tagger", layout="wide")
st.title("🔖 StackOverflow Tag Predictor")
# ---- Choix du modèle dans la sidebar ----
MODEL_OPTIONS = {
"BERT Overflow (maxcasado/BERT_overflow)": "maxcasado/BERT_overflow",
"Wendy Tags (wendyserver/predict_tags)": "wendyserver/predict_tags",
}
st.sidebar.header("⚙️ Configuration")
model_label = st.sidebar.selectbox(
"Choisir le modèle",
list(MODEL_OPTIONS.keys()),
)
selected_model = MODEL_OPTIONS[model_label]
st.sidebar.write(f"Modèle sélectionné : `{selected_model}`")
# ---- Tabs : single question / CSV ----
tab_single, tab_csv = st.tabs(["Question unique", "CSV batch"])
with tab_single:
st.write(
"Entrez une question (titre + éventuellement description) "
"et récupérez les probabilités des tags StackOverflow prédits par le modèle."
)
question = st.text_area(
"Question StackOverflow",
height=200,
placeholder="Ex: How to fine-tune BERT for multi-label classification?",
)
top_k = st.slider(
"Nombre de tags à afficher (top_k)",
1,
20,
5,
key="topk_single",
)
if st.button("Prédire", key="predict_single"):
if not question.strip():
st.warning("Merci d'entrer une question.")
else:
with st.spinner(f"Prédiction en cours avec {selected_model}..."):
tags = predict_proba(
question,
top_k=top_k,
model_name=selected_model,
)
if not tags:
st.warning("Pas de tags prédits.")
else:
st.subheader("Résultats")
for t in tags:
st.write(f"- **{t['label']}** — probabilité : `{t['score']:.4f}`")
st.subheader("Distribution des probabilités (top_k)")
scores = {t["label"]: t["score"] for t in tags}
st.bar_chart(scores)
with tab_csv:
st.write(
"Uploade un fichier CSV contenant des questions. "
"On ajoutera une colonne avec le tag principal prédit pour chaque ligne."
)
uploaded_file = st.file_uploader("Choisir un fichier CSV", type=["csv"])
if uploaded_file is not None:
df = pd.read_csv(uploaded_file)
st.write("Aperçu du CSV :")
st.dataframe(df.head())
text_column = st.selectbox(
"Colonne contenant la question",
options=list(df.columns),
)
top_k_batch = st.slider(
"Nombre de tags à considérer (pour choisir le meilleur)",
1,
20,
5,
key="topk_batch",
)
if st.button("Lancer la prédiction sur le CSV"):
if df[text_column].isnull().all():
st.error("La colonne choisie ne contient pas de texte.")
else:
preds_best_tag = []
preds_best_score = []
with st.spinner(f"Prédiction batch avec {selected_model}..."):
for text in df[text_column].fillna(""):
s = str(text).strip()
if not s:
preds_best_tag.append(None)
preds_best_score.append(None)
continue
tags = predict_proba(
s,
top_k=top_k_batch,
model_name=selected_model,
)
if len(tags) == 0:
preds_best_tag.append(None)
preds_best_score.append(None)
else:
best = tags[0]
preds_best_tag.append(best["label"])
preds_best_score.append(best["score"])
df["predicted_tag"] = preds_best_tag
df["predicted_score"] = preds_best_score
st.subheader("Résultats enrichis")
st.dataframe(df.head())
csv_buffer = StringIO()
df.to_csv(csv_buffer, index=False)
csv_bytes = csv_buffer.getvalue().encode("utf-8")
st.download_button(
label="📥 Télécharger le CSV avec tags prédits",
data=csv_bytes,
file_name="questions_with_tags.csv",
mime="text/csv",
)
else:
st.info("Uploade un fichier CSV pour lancer la prédiction batch.")
|