BenTouss commited on
Commit
a37f4ce
·
verified ·
1 Parent(s): 2c05295

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -12
app.py CHANGED
@@ -2,28 +2,104 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
 
 
5
  tokenizer = AutoTokenizer.from_pretrained("BenTouss/mdeberta-eurochef")
6
  model = AutoModelForSequenceClassification.from_pretrained("BenTouss/mdeberta-eurochef")
7
 
8
- def get_labels_table(text):
 
 
 
 
 
 
 
 
9
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
 
10
 
11
  with torch.no_grad():
12
  outputs = model(**inputs)
13
- probs = torch.sigmoid(outputs.logits)[0]
14
 
15
- rows = []
16
  for idx, prob in enumerate(probs):
17
- if prob > 0.6:
18
- rows.append([model.config.id2label[idx], float(prob)])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- rows.sort(key=lambda x: x[1], reverse=True)
21
- return rows
22
 
23
- demo = gr.Interface(
24
- fn=get_labels_table,
25
- inputs=gr.Textbox(lines=3, label="Text"),
26
- outputs=gr.Dataframe(headers=["label", "score"]),
27
- )
 
28
 
29
  demo.launch()
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
+ APP_NAME = "EuroChef"
6
+
7
  tokenizer = AutoTokenizer.from_pretrained("BenTouss/mdeberta-eurochef")
8
  model = AutoModelForSequenceClassification.from_pretrained("BenTouss/mdeberta-eurochef")
9
 
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model.to(device)
12
+ model.eval()
13
+
14
+ def predict(text: str, threshold: float = 0.6, top_k: int = 8, only_above: bool = True):
15
+ text = (text or "").strip()
16
+ if not text:
17
+ return "_Colle un message à gauche pour lancer l’analyse._", []
18
+
19
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
20
+ inputs = {k: v.to(device) for k, v in inputs.items()}
21
 
22
  with torch.no_grad():
23
  outputs = model(**inputs)
24
+ probs = torch.sigmoid(outputs.logits)[0].detach().cpu()
25
 
26
+ items = []
27
  for idx, prob in enumerate(probs):
28
+ score = float(prob)
29
+ label = model.config.id2label[idx]
30
+ if (not only_above) or (score >= threshold):
31
+ items.append((label, score))
32
+
33
+ items.sort(key=lambda x: x[1], reverse=True)
34
+ items = items[: max(1, int(top_k))]
35
+
36
+ rows = [[lbl, float(f"{sc:.3f}")] for lbl, sc in items]
37
+
38
+ if rows:
39
+ best_lbl, best_sc = rows[0][0], rows[0][1]
40
+ summary = f"**Top label :** `{best_lbl}` • **score :** `{best_sc}` \n**Results :** {len(rows)} • **threshold :** `{threshold:.2f}`"
41
+ else:
42
+ summary = f"_No label (thresold `{threshold:.2f}`)._"
43
+
44
+ return summary, rows
45
+
46
+
47
+ CSS = """
48
+ #title { margin-bottom: 0.25rem; }
49
+ #subtitle { margin-top: 0; opacity: 0.8; }
50
+ .footer { opacity: 0.7; font-size: 0.85rem; text-align: center; margin-top: 0.75rem; }
51
+ """
52
+
53
+ with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
54
+ gr.Markdown(f"# 🍳 {APP_NAME}", elem_id="title")
55
+ gr.Markdown("Customer support message → labels + scores.", elem_id="subtitle")
56
+
57
+ with gr.Row():
58
+ with gr.Column(scale=6):
59
+ text = gr.Textbox(
60
+ label="Customer support message",
61
+ placeholder="Ex: Bonjour, je n’arrive pas à lancer les vidéos…",
62
+ lines=10,
63
+ )
64
+
65
+ with gr.Row():
66
+ threshold = gr.Slider(0.0, 1.0, value=0.6, step=0.01, label="Threshold")
67
+ top_k = gr.Slider(1, 20, value=8, step=1, label="Top-K")
68
+ only_above = gr.Checkbox(value=True, label="Only ≥ threshold")
69
+
70
+ with gr.Row():
71
+ run = gr.Button("Analyze", variant="primary")
72
+ clear = gr.ClearButton(value="Clear")
73
+
74
+ gr.Examples(
75
+ examples=[
76
+ "Bonjour, je n’arrive pas à lancer les vidéos : écran noir et chargement infini. Je suis Premium mais certaines recettes restent verrouillées…",
77
+ "Je veux annuler mon abonnement mais je ne trouve pas où le faire dans l’app.",
78
+ "Paiement refusé alors que ma carte fonctionne ailleurs. Pouvez-vous vérifier ?",
79
+ "L’application plante dès que je lance une vidéo en Chromecast.",
80
+ ],
81
+ inputs=[text],
82
+ label="Exemples",
83
+ )
84
+
85
+ with gr.Column(scale=6):
86
+ summary = gr.Markdown(label="Summary")
87
+ table = gr.Dataframe(
88
+ headers=["label", "score"],
89
+ datatype=["str", "number"],
90
+ label="Predictions",
91
+ wrap=True,
92
+ interactive=False,
93
+ height=320,
94
+ )
95
 
96
+ gr.Markdown(f"<div class='footer'>Made with ❤️ by Ben • {APP_NAME}</div>")
 
97
 
98
+ run.click(
99
+ fn=predict,
100
+ inputs=[text, threshold, top_k, only_above],
101
+ outputs=[summary, table],
102
+ )
103
+ clear.add([text, summary, table])
104
 
105
  demo.launch()