Merwan6 commited on
Commit
716abff
·
1 Parent(s): 3eae2ef

modif onglet

Browse files
Files changed (1) hide show
  1. app.py +42 -32
app.py CHANGED
@@ -6,23 +6,21 @@ from scripts.inference import (
6
  base_model_inference,
7
  fine_tuned_inference
8
  )
 
9
 
10
  def predict_with_model(text, model_type):
11
  """
12
- Applique la stratégie de classification sélectionnée sur un texte donné
13
- et retourne la catégorie prédite avec les scores de confiance.
14
 
15
  Args:
16
- text (str): Le texte à analyser (actualité).
17
- model_type (str): Le type de modèle sélectionné ("Zero-shot", "Few-shot", etc.).
18
 
19
  Returns:
20
  tuple:
21
  - str: Catégorie prédite.
22
- - pandas.DataFrame: Tableau des scores de confiance par classe.
23
  """
24
-
25
- #Sélection du modèle d'inférence en fonction du choix utilisateur
26
  if model_type == "Zero-shot":
27
  prediction, scores = zero_shot_inference(text)
28
  elif model_type == "Few-shot":
@@ -34,41 +32,53 @@ def predict_with_model(text, model_type):
34
  else:
35
  return "Modèle inconnu", pd.DataFrame()
36
 
37
- #Convertit les scores (dict) en DataFrame pour affichage dans Gradio
38
  scores_df = pd.DataFrame([
39
  {"Classe": label, "Score": score} for label, score in scores.items()
40
  ])
41
-
42
  return prediction, scores_df
43
 
44
- #Définition de l'interface utilisateur avec Gradio
45
- iface = gr.Interface(
46
- fn=predict_with_model, #Fonction appelée au clic de l'utilisateur
47
- inputs=[
48
- gr.Textbox(
 
 
 
 
 
 
49
  lines=4,
50
- placeholder="Entrez une phrase d'actualité ici...",
51
  label="Texte à classifier"
52
- ),
53
- gr.Radio(
 
54
  choices=["Base model", "Zero-shot", "Few-shot", "Fine-tuned"],
55
  label="Choisir le modèle",
56
- value="Base model" #Valeur par défaut
57
  )
58
- ],
59
- outputs=[
60
- gr.Label(label="Catégorie prédite"), #Affiche la prédiction principale
61
- gr.BarPlot( #Affiche les scores de confiance
62
- label="Scores de confiance",
63
- x="Classe",
64
- y="Score",
65
- color="Classe"
66
  )
67
- ],
68
- title="Classification AG News (4 stratégies)",
69
- description="Comparer un modèle préentraîné, Zero-shot, Few-shot et Fine-tuned sur AG News"
70
- )
71
 
72
- #Lancement de l'application
 
 
 
 
 
 
 
 
 
 
73
  if __name__ == "__main__":
74
- iface.launch()
 
6
  base_model_inference,
7
  fine_tuned_inference
8
  )
9
+ from pathlib import Path
10
 
11
  def predict_with_model(text, model_type):
12
  """
13
+ Applique une stratégie de classification sur un texte donné.
 
14
 
15
  Args:
16
+ text (str): Le texte d’actualité à analyser.
17
+ model_type (str): Le modèle choisi ("Zero-shot", "Few-shot", etc.).
18
 
19
  Returns:
20
  tuple:
21
  - str: Catégorie prédite.
22
+ - pandas.DataFrame: Score de confiance par classe.
23
  """
 
 
24
  if model_type == "Zero-shot":
25
  prediction, scores = zero_shot_inference(text)
26
  elif model_type == "Few-shot":
 
32
  else:
33
  return "Modèle inconnu", pd.DataFrame()
34
 
35
+ # Convertir le dictionnaire des scores en DataFrame pour affichage
36
  scores_df = pd.DataFrame([
37
  {"Classe": label, "Score": score} for label, score in scores.items()
38
  ])
 
39
  return prediction, scores_df
40
 
41
+ # === Interface Gradio avec deux onglets ===
42
+ with gr.Blocks(title="Classification AG News (4 stratégies)") as app:
43
+
44
+ gr.Markdown("# 📰 Classification de textes AG News")
45
+ gr.Markdown("Cette application compare plusieurs approches NLP pour classer des actualités.")
46
+
47
+ with gr.Tab("🧠 Inférence"):
48
+ gr.Markdown("### ✍️ Entrez une phrase d'actualité à analyser")
49
+
50
+ # Entrées utilisateur
51
+ input_text = gr.Textbox(
52
  lines=4,
53
+ placeholder="Ex: Apple lance un nouveau produit basé sur l'intelligence artificielle...",
54
  label="Texte à classifier"
55
+ )
56
+
57
+ model_choice = gr.Radio(
58
  choices=["Base model", "Zero-shot", "Few-shot", "Fine-tuned"],
59
  label="Choisir le modèle",
60
+ value="Base model"
61
  )
62
+
63
+ predict_button = gr.Button("🔍 Prédire")
64
+
65
+ # Sorties
66
+ label_output = gr.Label(label="🧾 Catégorie prédite")
67
+ scores_output = gr.BarPlot(
68
+ label="📊 Scores de confiance",
69
+ x="Classe", y="Score", color="Classe"
70
  )
 
 
 
 
71
 
72
+ # Action sur clic bouton
73
+ predict_button.click(
74
+ fn=predict_with_model,
75
+ inputs=[input_text, model_choice],
76
+ outputs=[label_output, scores_output]
77
+ )
78
+
79
+ with gr.Tab("📄 Documentation"):
80
+ gr.Markdown(Path("README.md").read_text())
81
+
82
+ # Lancer l'app
83
  if __name__ == "__main__":
84
+ app.launch()