POC2PROD / frontend.py
maxcasado's picture
Update frontend.py
9a5f5dd verified
import streamlit as st
import pandas as pd
from io import StringIO
from model_utils import predict_proba
st.set_page_config(page_title="StackOverflow Tagger", layout="wide")
st.title("🔖 StackOverflow Tag Predictor")
# ---- Choix du modèle dans la sidebar ----
MODEL_OPTIONS = {
"BERT Overflow (maxcasado/BERT_overflow)": "maxcasado/BERT_overflow",
"Wendy Tags (wendyserver/predict_tags)": "wendyserver/predict_tags",
}
st.sidebar.header("⚙️ Configuration")
model_label = st.sidebar.selectbox(
"Choisir le modèle",
list(MODEL_OPTIONS.keys()),
)
selected_model = MODEL_OPTIONS[model_label]
st.sidebar.write(f"Modèle sélectionné : `{selected_model}`")
# ---- Tabs : single question / CSV ----
tab_single, tab_csv = st.tabs(["Question unique", "CSV batch"])
with tab_single:
st.write(
"Entrez une question (titre + éventuellement description) "
"et récupérez les probabilités des tags StackOverflow prédits par le modèle."
)
question = st.text_area(
"Question StackOverflow",
height=200,
placeholder="Ex: How to fine-tune BERT for multi-label classification?",
)
top_k = st.slider(
"Nombre de tags à afficher (top_k)",
1,
20,
5,
key="topk_single",
)
if st.button("Prédire", key="predict_single"):
if not question.strip():
st.warning("Merci d'entrer une question.")
else:
with st.spinner(f"Prédiction en cours avec {selected_model}..."):
tags = predict_proba(
question,
top_k=top_k,
model_name=selected_model,
)
if not tags:
st.warning("Pas de tags prédits.")
else:
st.subheader("Résultats")
for t in tags:
st.write(f"- **{t['label']}** — probabilité : `{t['score']:.4f}`")
st.subheader("Distribution des probabilités (top_k)")
scores = {t["label"]: t["score"] for t in tags}
st.bar_chart(scores)
with tab_csv:
st.write(
"Uploade un fichier CSV contenant des questions. "
"On ajoutera une colonne avec le tag principal prédit pour chaque ligne."
)
uploaded_file = st.file_uploader("Choisir un fichier CSV", type=["csv"])
if uploaded_file is not None:
df = pd.read_csv(uploaded_file)
st.write("Aperçu du CSV :")
st.dataframe(df.head())
text_column = st.selectbox(
"Colonne contenant la question",
options=list(df.columns),
)
top_k_batch = st.slider(
"Nombre de tags à considérer (pour choisir le meilleur)",
1,
20,
5,
key="topk_batch",
)
if st.button("Lancer la prédiction sur le CSV"):
if df[text_column].isnull().all():
st.error("La colonne choisie ne contient pas de texte.")
else:
preds_best_tag = []
preds_best_score = []
with st.spinner(f"Prédiction batch avec {selected_model}..."):
for text in df[text_column].fillna(""):
s = str(text).strip()
if not s:
preds_best_tag.append(None)
preds_best_score.append(None)
continue
tags = predict_proba(
s,
top_k=top_k_batch,
model_name=selected_model,
)
if len(tags) == 0:
preds_best_tag.append(None)
preds_best_score.append(None)
else:
best = tags[0]
preds_best_tag.append(best["label"])
preds_best_score.append(best["score"])
df["predicted_tag"] = preds_best_tag
df["predicted_score"] = preds_best_score
st.subheader("Résultats enrichis")
st.dataframe(df.head())
csv_buffer = StringIO()
df.to_csv(csv_buffer, index=False)
csv_bytes = csv_buffer.getvalue().encode("utf-8")
st.download_button(
label="📥 Télécharger le CSV avec tags prédits",
data=csv_bytes,
file_name="questions_with_tags.csv",
mime="text/csv",
)
else:
st.info("Uploade un fichier CSV pour lancer la prédiction batch.")