File size: 4,770 Bytes
c7e64f8
883f14c
 
c7e64f8
883f14c
c7e64f8
 
 
 
9a5f5dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883f14c
c7e64f8
883f14c
 
 
 
 
926a499
883f14c
 
 
 
 
926a499
9a5f5dd
 
 
 
 
 
 
883f14c
 
 
 
 
9a5f5dd
 
 
 
 
 
926a499
 
883f14c
926a499
 
 
 
 
 
 
 
 
883f14c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a5f5dd
883f14c
 
 
 
 
 
 
 
 
 
9a5f5dd
 
 
 
883f14c
9a5f5dd
 
 
 
883f14c
9a5f5dd
 
 
 
 
 
 
883f14c
9a5f5dd
 
883f14c
 
9a5f5dd
 
883f14c
9a5f5dd
 
883f14c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc989ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import streamlit as st
import pandas as pd
from io import StringIO

from model_utils import predict_proba

st.set_page_config(page_title="StackOverflow Tagger", layout="wide")
st.title("🔖 StackOverflow Tag Predictor")

# ---- Choix du modèle dans la sidebar ----

MODEL_OPTIONS = {
    "BERT Overflow (maxcasado/BERT_overflow)": "maxcasado/BERT_overflow",
    "Wendy Tags (wendyserver/predict_tags)": "wendyserver/predict_tags",
}

st.sidebar.header("⚙️ Configuration")
model_label = st.sidebar.selectbox(
    "Choisir le modèle",
    list(MODEL_OPTIONS.keys()),
)
selected_model = MODEL_OPTIONS[model_label]

st.sidebar.write(f"Modèle sélectionné : `{selected_model}`")

# ---- Tabs : single question / CSV ----

tab_single, tab_csv = st.tabs(["Question unique", "CSV batch"])

with tab_single:
    st.write(
        "Entrez une question (titre + éventuellement description) "
        "et récupérez les probabilités des tags StackOverflow prédits par le modèle."
    )

    question = st.text_area(
        "Question StackOverflow",
        height=200,
        placeholder="Ex: How to fine-tune BERT for multi-label classification?",
    )

    top_k = st.slider(
        "Nombre de tags à afficher (top_k)",
        1,
        20,
        5,
        key="topk_single",
    )

    if st.button("Prédire", key="predict_single"):
        if not question.strip():
            st.warning("Merci d'entrer une question.")
        else:
            with st.spinner(f"Prédiction en cours avec {selected_model}..."):
                tags = predict_proba(
                    question,
                    top_k=top_k,
                    model_name=selected_model,
                )

            if not tags:
                st.warning("Pas de tags prédits.")
            else:
                st.subheader("Résultats")
                for t in tags:
                    st.write(f"- **{t['label']}** — probabilité : `{t['score']:.4f}`")

                st.subheader("Distribution des probabilités (top_k)")
                scores = {t["label"]: t["score"] for t in tags}
                st.bar_chart(scores)

with tab_csv:
    st.write(
        "Uploade un fichier CSV contenant des questions. "
        "On ajoutera une colonne avec le tag principal prédit pour chaque ligne."
    )

    uploaded_file = st.file_uploader("Choisir un fichier CSV", type=["csv"])

    if uploaded_file is not None:
        df = pd.read_csv(uploaded_file)

        st.write("Aperçu du CSV :")
        st.dataframe(df.head())

        text_column = st.selectbox(
            "Colonne contenant la question",
            options=list(df.columns),
        )

        top_k_batch = st.slider(
            "Nombre de tags à considérer (pour choisir le meilleur)",
            1,
            20,
            5,
            key="topk_batch",
        )

        if st.button("Lancer la prédiction sur le CSV"):
            if df[text_column].isnull().all():
                st.error("La colonne choisie ne contient pas de texte.")
            else:
                preds_best_tag = []
                preds_best_score = []

                with st.spinner(f"Prédiction batch avec {selected_model}..."):
                    for text in df[text_column].fillna(""):
                        s = str(text).strip()
                        if not s:
                            preds_best_tag.append(None)
                            preds_best_score.append(None)
                            continue

                        tags = predict_proba(
                            s,
                            top_k=top_k_batch,
                            model_name=selected_model,
                        )

                        if len(tags) == 0:
                            preds_best_tag.append(None)
                            preds_best_score.append(None)
                        else:
                            best = tags[0]
                            preds_best_tag.append(best["label"])
                            preds_best_score.append(best["score"])

                df["predicted_tag"] = preds_best_tag
                df["predicted_score"] = preds_best_score

                st.subheader("Résultats enrichis")
                st.dataframe(df.head())

                csv_buffer = StringIO()
                df.to_csv(csv_buffer, index=False)
                csv_bytes = csv_buffer.getvalue().encode("utf-8")

                st.download_button(
                    label="📥 Télécharger le CSV avec tags prédits",
                    data=csv_bytes,
                    file_name="questions_with_tags.csv",
                    mime="text/csv",
                )
    else:
        st.info("Uploade un fichier CSV pour lancer la prédiction batch.")