WissamH commited on
Commit
d53ae44
·
1 Parent(s): 1638f0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -6
app.py CHANGED
@@ -147,7 +147,7 @@ st.markdown('<p class="small-note">Astuce: plus le texte est long, plus la class
147
  # ---------------------------
148
  # API_URL comporte l'adresse de l'API pour la prédiction du modèle
149
  # exemple: https://[your_HF_name]-[your_space_name].hf.space/predict
150
- API_URL = os.environ["API_URL"]
151
 
152
  verify = st.button("Vérifier la nouvelle", type="primary", use_container_width=True)
153
 
@@ -180,6 +180,31 @@ def render_indicator(kind: str, main_label: str, conf: float | None):
180
  </div>
181
  """, unsafe_allow_html=True)
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  if verify:
184
  text = (st.session_state.user_text or "").strip()
185
 
@@ -196,9 +221,8 @@ if verify:
196
  response.raise_for_status()
197
  result = response.json()
198
 
199
- prediction = result.get("prediction", "unknown")
200
- conf = result.get("confidence", None)
201
- reasons = result.get("reasons", None)
202
 
203
  # Big indicator + message
204
  if prediction == 2:
@@ -230,8 +254,6 @@ if verify:
230
 
231
  st.markdown("</div>", unsafe_allow_html=True)
232
 
233
-
234
-
235
  except requests.exceptions.RequestException as e:
236
  st.error(f"Une erreur est survenue lors de la vérification de l'article: {e}")
237
 
 
147
  # ---------------------------
148
  # API_URL comporte l'adresse de l'API pour la prédiction du modèle
149
  # exemple: https://[your_HF_name]-[your_space_name].hf.space/predict
150
+ API_URL = os.environ["API_URL"] + '/predict'
151
 
152
  verify = st.button("Vérifier la nouvelle", type="primary", use_container_width=True)
153
 
 
180
  </div>
181
  """, unsafe_allow_html=True)
182
 
183
+ def normalize_api_response(result):
184
+ # Cas scikit-learn : {"prediction": 0}
185
+ if isinstance(result.get("prediction"), int):
186
+ prediction = result["prediction"]
187
+ conf = result.get("confidence", None)
188
+ reasons = result.get("reasons", None)
189
+
190
+ # Cas transformers : {"prediction": "LABEL_2", "score": 0.95}
191
+ elif isinstance(result.get("prediction"), str):
192
+ label_to_int = {
193
+ "LABEL_0": 0,
194
+ "LABEL_1": 1,
195
+ "LABEL_2": 2,
196
+ }
197
+ prediction = label_to_int.get(result["prediction"], -1)
198
+ conf = result.get("score", None)
199
+ reasons = result.get("reasons", None)
200
+
201
+ else:
202
+ prediction = -1 # Valeur par défaut pour "non classé"
203
+ conf = None
204
+ reasons = None
205
+
206
+ return prediction, conf, reasons
207
+
208
  if verify:
209
  text = (st.session_state.user_text or "").strip()
210
 
 
221
  response.raise_for_status()
222
  result = response.json()
223
 
224
+ # Normalise la réponse
225
+ prediction, conf, reasons = normalize_api_response(result)
 
226
 
227
  # Big indicator + message
228
  if prediction == 2:
 
254
 
255
  st.markdown("</div>", unsafe_allow_html=True)
256
 
 
 
257
  except requests.exceptions.RequestException as e:
258
  st.error(f"Une erreur est survenue lors de la vérification de l'article: {e}")
259