maxcasado commited on
Commit
9a5f5dd
·
verified ·
1 Parent(s): bb37dca

Update frontend.py

Browse files
Files changed (1) hide show
  1. frontend.py +53 -15
frontend.py CHANGED
@@ -7,6 +7,24 @@ from model_utils import predict_proba
7
  st.set_page_config(page_title="StackOverflow Tagger", layout="wide")
8
  st.title("🔖 StackOverflow Tag Predictor")
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  tab_single, tab_csv = st.tabs(["Question unique", "CSV batch"])
11
 
12
  with tab_single:
@@ -21,14 +39,24 @@ with tab_single:
21
  placeholder="Ex: How to fine-tune BERT for multi-label classification?",
22
  )
23
 
24
- top_k = st.slider("Nombre de tags à afficher (top_k)", 1, 20, 5, key="topk_single")
 
 
 
 
 
 
25
 
26
  if st.button("Prédire", key="predict_single"):
27
  if not question.strip():
28
  st.warning("Merci d'entrer une question.")
29
  else:
30
- with st.spinner("Prédiction en cours..."):
31
- tags = predict_proba(question, top_k=top_k)
 
 
 
 
32
 
33
  if not tags:
34
  st.warning("Pas de tags prédits.")
@@ -61,7 +89,7 @@ with tab_csv:
61
  )
62
 
63
  top_k_batch = st.slider(
64
- "Nombre de tags à considérer pour le batch (pour choisir le meilleur)",
65
  1,
66
  20,
67
  5,
@@ -72,23 +100,33 @@ with tab_csv:
72
  if df[text_column].isnull().all():
73
  st.error("La colonne choisie ne contient pas de texte.")
74
  else:
75
- preds = []
76
- with st.spinner("Prédiction en cours sur le CSV..."):
 
 
77
  for text in df[text_column].fillna(""):
78
- if not str(text).strip():
79
- preds.append({"best_tag": None, "best_score": None})
 
 
80
  continue
81
- tags = predict_proba(str(text), top_k=top_k_batch)
 
 
 
 
 
 
82
  if len(tags) == 0:
83
- preds.append({"best_tag": None, "best_score": None})
 
84
  else:
85
  best = tags[0]
86
- preds.append(
87
- {"best_tag": best["label"], "best_score": best["score"]}
88
- )
89
 
90
- df["predicted_tag"] = [p["best_tag"] for p in preds]
91
- df["predicted_score"] = [p["best_score"] for p in preds]
92
 
93
  st.subheader("Résultats enrichis")
94
  st.dataframe(df.head())
 
7
  st.set_page_config(page_title="StackOverflow Tagger", layout="wide")
8
  st.title("🔖 StackOverflow Tag Predictor")
9
 
10
+ # ---- Choix du modèle dans la sidebar ----
11
+
12
+ MODEL_OPTIONS = {
13
+ "BERT Overflow (maxcasado/BERT_overflow)": "maxcasado/BERT_overflow",
14
+ "Wendy Tags (wendyserver/predict_tags)": "wendyserver/predict_tags",
15
+ }
16
+
17
+ st.sidebar.header("⚙️ Configuration")
18
+ model_label = st.sidebar.selectbox(
19
+ "Choisir le modèle",
20
+ list(MODEL_OPTIONS.keys()),
21
+ )
22
+ selected_model = MODEL_OPTIONS[model_label]
23
+
24
+ st.sidebar.write(f"Modèle sélectionné : `{selected_model}`")
25
+
26
+ # ---- Tabs : single question / CSV ----
27
+
28
  tab_single, tab_csv = st.tabs(["Question unique", "CSV batch"])
29
 
30
  with tab_single:
 
39
  placeholder="Ex: How to fine-tune BERT for multi-label classification?",
40
  )
41
 
42
+ top_k = st.slider(
43
+ "Nombre de tags à afficher (top_k)",
44
+ 1,
45
+ 20,
46
+ 5,
47
+ key="topk_single",
48
+ )
49
 
50
  if st.button("Prédire", key="predict_single"):
51
  if not question.strip():
52
  st.warning("Merci d'entrer une question.")
53
  else:
54
+ with st.spinner(f"Prédiction en cours avec {selected_model}..."):
55
+ tags = predict_proba(
56
+ question,
57
+ top_k=top_k,
58
+ model_name=selected_model,
59
+ )
60
 
61
  if not tags:
62
  st.warning("Pas de tags prédits.")
 
89
  )
90
 
91
  top_k_batch = st.slider(
92
+ "Nombre de tags à considérer (pour choisir le meilleur)",
93
  1,
94
  20,
95
  5,
 
100
  if df[text_column].isnull().all():
101
  st.error("La colonne choisie ne contient pas de texte.")
102
  else:
103
+ preds_best_tag = []
104
+ preds_best_score = []
105
+
106
+ with st.spinner(f"Prédiction batch avec {selected_model}..."):
107
  for text in df[text_column].fillna(""):
108
+ s = str(text).strip()
109
+ if not s:
110
+ preds_best_tag.append(None)
111
+ preds_best_score.append(None)
112
  continue
113
+
114
+ tags = predict_proba(
115
+ s,
116
+ top_k=top_k_batch,
117
+ model_name=selected_model,
118
+ )
119
+
120
  if len(tags) == 0:
121
+ preds_best_tag.append(None)
122
+ preds_best_score.append(None)
123
  else:
124
  best = tags[0]
125
+ preds_best_tag.append(best["label"])
126
+ preds_best_score.append(best["score"])
 
127
 
128
+ df["predicted_tag"] = preds_best_tag
129
+ df["predicted_score"] = preds_best_score
130
 
131
  st.subheader("Résultats enrichis")
132
  st.dataframe(df.head())