maxcasado commited on
Commit
883f14c
·
verified ·
1 Parent(s): c9c177e

Update frontend.py

Browse files
Files changed (1) hide show
  1. frontend.py +86 -30
frontend.py CHANGED
@@ -1,43 +1,37 @@
1
- # frontend.py
2
- import os
3
- import requests
4
  import streamlit as st
 
 
5
 
6
- API_URL = os.getenv("API_URL", "http://localhost:8000")
7
 
8
  st.set_page_config(page_title="StackOverflow Tagger", layout="wide")
9
  st.title("🔖 StackOverflow Tag Predictor")
10
 
11
- st.write(
12
- "Entrez une question (titre + éventuellement description) "
13
- "et récupérez les probabilités des tags StackOverflow."
14
- )
15
 
16
- question = st.text_area(
17
- "Question StackOverflow",
18
- height=200,
19
- placeholder="Ex: How to fine-tune BERT for multi-label classification?",
20
- )
21
 
22
- top_k = st.slider("Nombre de tags à afficher (top_k)", 1, 20, 5)
 
 
 
 
23
 
24
- if st.button("Prédire"):
25
- if not question.strip():
26
- st.warning("Merci d'entrer une question.")
27
- else:
28
- try:
 
29
  with st.spinner("Prédiction en cours..."):
30
- resp = requests.post(
31
- f"{API_URL}/predict",
32
- json={"text": question, "top_k": top_k},
33
- timeout=60,
34
- )
35
- resp.raise_for_status()
36
- data = resp.json()
37
- tags = data.get("tags", [])
38
 
39
  if not tags:
40
- st.warning("Aucun tag renvoyé par l'API.")
41
  else:
42
  st.subheader("Résultats")
43
  for t in tags:
@@ -47,5 +41,67 @@ if st.button("Prédire"):
47
  scores = {t["label"]: t["score"] for t in tags}
48
  st.bar_chart(scores)
49
 
50
- except Exception as e:
51
- st.error(f"Erreur lors de l'appel à l'API : {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ from io import StringIO
4
 
5
+ from model_utils import predict_proba
6
 
7
  st.set_page_config(page_title="StackOverflow Tagger", layout="wide")
8
  st.title("🔖 StackOverflow Tag Predictor")
9
 
10
+ tab_single, tab_csv = st.tabs(["Question unique", "CSV batch"])
 
 
 
11
 
12
+ with tab_single:
13
+ st.write(
14
+ "Entrez une question (titre + éventuellement description) "
15
+ "et récupérez les probabilités des tags StackOverflow prédits par le modèle."
16
+ )
17
 
18
+ question = st.text_area(
19
+ "Question StackOverflow",
20
+ height=200,
21
+ placeholder="Ex: How to fine-tune BERT for multi-label classification?",
22
+ )
23
 
24
+ top_k = st.slider("Nombre de tags à afficher (top_k)", 1, 20, 5, key="topk_single")
25
+
26
+ if st.button("Prédire", key="predict_single"):
27
+ if not question.strip():
28
+ st.warning("Merci d'entrer une question.")
29
+ else:
30
  with st.spinner("Prédiction en cours..."):
31
+ tags = predict_proba(question, top_k=top_k)
 
 
 
 
 
 
 
32
 
33
  if not tags:
34
+ st.warning("Pas de tags prédits.")
35
  else:
36
  st.subheader("Résultats")
37
  for t in tags:
 
41
  scores = {t["label"]: t["score"] for t in tags}
42
  st.bar_chart(scores)
43
 
44
+ with tab_csv:
45
+ st.write(
46
+ "Uploade un fichier CSV contenant des questions. "
47
+ "On ajoutera une colonne avec le tag principal prédit pour chaque ligne."
48
+ )
49
+
50
+ uploaded_file = st.file_uploader("Choisir un fichier CSV", type=["csv"])
51
+
52
+ if uploaded_file is not None:
53
+ df = pd.read_csv(uploaded_file)
54
+
55
+ st.write("Aperçu du CSV :")
56
+ st.dataframe(df.head())
57
+
58
+ text_column = st.selectbox(
59
+ "Colonne contenant la question",
60
+ options=list(df.columns),
61
+ )
62
+
63
+ top_k_batch = st.slider(
64
+ "Nombre de tags à considérer pour le batch (pour choisir le meilleur)",
65
+ 1,
66
+ 20,
67
+ 5,
68
+ key="topk_batch",
69
+ )
70
+
71
+ if st.button("Lancer la prédiction sur le CSV"):
72
+ if df[text_column].isnull().all():
73
+ st.error("La colonne choisie ne contient pas de texte.")
74
+ else:
75
+ preds = []
76
+ with st.spinner("Prédiction en cours sur le CSV..."):
77
+ for text in df[text_column].fillna(""):
78
+ if not str(text).strip():
79
+ preds.append({"best_tag": None, "best_score": None})
80
+ continue
81
+ tags = predict_proba(str(text), top_k=top_k_batch)
82
+ if len(tags) == 0:
83
+ preds.append({"best_tag": None, "best_score": None})
84
+ else:
85
+ best = tags[0]
86
+ preds.append(
87
+ {"best_tag": best["label"], "best_score": best["score"]}
88
+ )
89
+
90
+ df["predicted_tag"] = [p["best_tag"] for p in preds]
91
+ df["predicted_score"] = [p["best_score"] for p in preds]
92
+
93
+ st.subheader("Résultats enrichis")
94
+ st.dataframe(df.head())
95
+
96
+ csv_buffer = StringIO()
97
+ df.to_csv(csv_buffer, index=False)
98
+ csv_bytes = csv_buffer.getvalue().encode("utf-8")
99
+
100
+ st.download_button(
101
+ label="📥 Télécharger le CSV avec tags prédits",
102
+ data=csv_bytes,
103
+ file_name="questions_with_tags.csv",
104
+ mime="text/csv",
105
+ )
106
+ else:
107
+ st.info("Uploade un fich