|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
from io import StringIO |
|
|
|
|
|
from model_utils import predict_proba |
|
|
|
|
|
# --- Page setup -------------------------------------------------------------
st.set_page_config(page_title="StackOverflow Tagger", layout="wide")
st.title("🔖 StackOverflow Tag Predictor")

# Hugging Face model identifiers selectable from the sidebar.
# Keys are the human-readable labels shown in the UI; values are repo ids
# passed through to `predict_proba` as `model_name`.
MODEL_OPTIONS = {
    "BERT Overflow (maxcasado/BERT_overflow)": "maxcasado/BERT_overflow",
    "Wendy Tags (wendyserver/predict_tags)": "wendyserver/predict_tags",
}

# --- Sidebar: model selection ------------------------------------------------
st.sidebar.header("⚙️ Configuration")
model_label = st.sidebar.selectbox(
    "Choisir le modèle",
    # Iterating a dict yields its keys, so list(d) == list(d.keys()).
    options=list(MODEL_OPTIONS),
)
selected_model = MODEL_OPTIONS[model_label]
st.sidebar.write(f"Modèle sélectionné : `{selected_model}`")

# Two workflows: predict for a single question, or for a whole CSV file.
tab_single, tab_csv = st.tabs(["Question unique", "CSV batch"])
|
|
|
|
|
with tab_single:
    # Single-question workflow: free text in, top-k tag probabilities out.
    st.write(
        "Entrez une question (titre + éventuellement description) "
        "et récupérez les probabilités des tags StackOverflow prédits par le modèle."
    )

    question = st.text_area(
        "Question StackOverflow",
        height=200,
        placeholder="Ex: How to fine-tune BERT for multi-label classification?",
    )

    top_k = st.slider(
        "Nombre de tags à afficher (top_k)",
        1,
        20,
        5,
        key="topk_single",
    )

    if st.button("Prédire", key="predict_single"):
        if not question.strip():
            # Nothing to predict on — ask the user for input instead.
            st.warning("Merci d'entrer une question.")
        else:
            with st.spinner(f"Prédiction en cours avec {selected_model}..."):
                tags = predict_proba(
                    question,
                    top_k=top_k,
                    model_name=selected_model,
                )

            if not tags:
                st.warning("Pas de tags prédits.")
            else:
                st.subheader("Résultats")
                for tag in tags:
                    st.write(f"- **{tag['label']}** — probabilité : `{tag['score']:.4f}`")

                # Bar chart of the same top-k scores, keyed by tag label.
                st.subheader("Distribution des probabilités (top_k)")
                score_by_label = {tag["label"]: tag["score"] for tag in tags}
                st.bar_chart(score_by_label)
|
|
|
|
|
with tab_csv:
    # Batch workflow: upload a CSV, predict the best tag for every row of the
    # chosen text column, then offer the enriched file for download.
    st.write(
        "Uploade un fichier CSV contenant des questions. "
        "On ajoutera une colonne avec le tag principal prédit pour chaque ligne."
    )

    uploaded_file = st.file_uploader("Choisir un fichier CSV", type=["csv"])

    if uploaded_file is not None:
        df = pd.read_csv(uploaded_file)

        st.write("Aperçu du CSV :")
        st.dataframe(df.head())

        # Which column holds the question text.
        text_column = st.selectbox(
            "Colonne contenant la question",
            options=list(df.columns),
        )

        # top_k passed to the model; only the best (first) tag is kept per row.
        top_k_batch = st.slider(
            "Nombre de tags à considérer (pour choisir le meilleur)",
            1,
            20,
            5,
            key="topk_batch",
        )

        if st.button("Lancer la prédiction sur le CSV"):
            if df[text_column].isnull().all():
                st.error("La colonne choisie ne contient pas de texte.")
            else:
                preds_best_tag = []
                preds_best_score = []

                with st.spinner(f"Prédiction batch avec {selected_model}..."):
                    for text in df[text_column].fillna(""):
                        s = str(text).strip()
                        if not s:
                            # Empty/blank cell: no prediction possible.
                            preds_best_tag.append(None)
                            preds_best_score.append(None)
                            continue

                        tags = predict_proba(
                            s,
                            top_k=top_k_batch,
                            model_name=selected_model,
                        )

                        # Truthiness check instead of `len(tags) == 0`.
                        if not tags:
                            preds_best_tag.append(None)
                            preds_best_score.append(None)
                        else:
                            # tags[0] is assumed to be the highest-probability
                            # tag — TODO confirm predict_proba sorts by score.
                            best = tags[0]
                            preds_best_tag.append(best["label"])
                            preds_best_score.append(best["score"])

                df["predicted_tag"] = preds_best_tag
                df["predicted_score"] = preds_best_score

                st.subheader("Résultats enrichis")
                st.dataframe(df.head())

                # to_csv() without a path returns the CSV as a str, so no
                # intermediate StringIO buffer is needed.
                csv_bytes = df.to_csv(index=False).encode("utf-8")

                st.download_button(
                    label="📥 Télécharger le CSV avec tags prédits",
                    data=csv_bytes,
                    file_name="questions_with_tags.csv",
                    mime="text/csv",
                )
    else:
        st.info("Uploade un fichier CSV pour lancer la prédiction batch.")
|
|
|