Mattimax commited on
Commit
19dd1f8
·
verified ·
1 Parent(s): 8d47b06

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import gradio as gr
4
+ from datasets import load_dataset
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import pandas as pd
7
+
8
+
9
+ # =========================
10
+ # Configurazione benchmark
11
+ # =========================
12
+
13
+ MAX_MODELS = 5
14
+ DEFAULT_NUM_SAMPLES = 50 # numero di esempi da usare per il benchmark
15
+
16
+
17
+ def get_device():
18
+ if torch.cuda.is_available():
19
+ return "cuda"
20
+ return "cpu"
21
+
22
+
23
+ def load_boolq_dataset(num_samples=DEFAULT_NUM_SAMPLES):
24
+ """
25
+ Carica un subset del dataset BoolQ.
26
+ BoolQ: domande sì/no con un breve contesto.
27
+ """
28
+ ds = load_dataset("boolq", split="validation")
29
+ if num_samples is not None and num_samples < len(ds):
30
+ ds = ds.select(range(num_samples))
31
+ return ds
32
+
33
+
34
+ def build_boolq_prompt(passage, question):
35
+ """
36
+ Costruisce un prompt generico per LLM per BoolQ.
37
+ Il modello deve rispondere solo 'yes' o 'no'.
38
+ """
39
+ prompt = (
40
+ "You are a question answering system. "
41
+ "Answer strictly with 'yes' or 'no'.\n\n"
42
+ f"Passage: {passage}\n"
43
+ f"Question: {question}\n"
44
+ "Answer:"
45
+ )
46
+ return prompt
47
+
48
+
49
+ def parse_yes_no(output_text):
50
+ """
51
+ Estrae 'yes' o 'no' dall'output del modello.
52
+ Se non è chiaro, restituisce None.
53
+ """
54
+ text = output_text.strip().lower()
55
+ # prendi solo la prima parola
56
+ first = text.split()[0] if text else ""
57
+ if first.startswith("yes"):
58
+ return True
59
+ if first.startswith("no"):
60
+ return False
61
+ return None
62
+
63
+
64
+ def evaluate_model_on_boolq(model_name, num_samples=DEFAULT_NUM_SAMPLES, max_new_tokens=5):
65
+ """
66
+ Esegue il benchmark di un modello su BoolQ.
67
+ Ritorna:
68
+ - accuracy
69
+ - numero di esempi valutati
70
+ - tempo medio per esempio
71
+ """
72
+ device = get_device()
73
+ start_total = time.time()
74
+
75
+ # Caricamento modello e tokenizer
76
+ try:
77
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
78
+ model = AutoModelForCausalLM.from_pretrained(model_name)
79
+ except Exception as e:
80
+ raise RuntimeError(f"Errore nel caricamento del modello '{model_name}': {e}")
81
+
82
+ model.to(device)
83
+ model.eval()
84
+
85
+ ds = load_boolq_dataset(num_samples=num_samples)
86
+
87
+ correct = 0
88
+ total = 0
89
+ times = []
90
+
91
+ for example in ds:
92
+ passage = example["passage"]
93
+ question = example["question"]
94
+ label = example["answer"] # True/False
95
+
96
+ prompt = build_boolq_prompt(passage, question)
97
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
98
+
99
+ t0 = time.time()
100
+ with torch.no_grad():
101
+ output_ids = model.generate(
102
+ **inputs,
103
+ max_new_tokens=max_new_tokens,
104
+ do_sample=False,
105
+ temperature=0.0,
106
+ )
107
+ t1 = time.time()
108
+ gen_text = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
109
+
110
+ pred = parse_yes_no(gen_text)
111
+ if pred is not None:
112
+ if pred == label:
113
+ correct += 1
114
+ total += 1
115
+ times.append(t1 - t0)
116
+
117
+ if total == 0:
118
+ accuracy = 0.0
119
+ avg_time = None
120
+ else:
121
+ accuracy = correct / total
122
+ avg_time = sum(times) / len(times) if times else None
123
+
124
+ total_time = time.time() - start_total
125
+
126
+ return {
127
+ "model_name": model_name,
128
+ "num_samples": total,
129
+ "accuracy": accuracy,
130
+ "avg_time_per_sample_sec": avg_time,
131
+ "total_time_sec": total_time,
132
+ }
133
+
134
+
135
+ # =========================
136
+ # Funzioni per la UI
137
+ # =========================
138
+
139
+ def add_model_field(current_count):
140
+ """
141
+ Aumenta il numero di campi modello visibili, fino a MAX_MODELS.
142
+ """
143
+ if current_count < MAX_MODELS:
144
+ current_count += 1
145
+ return current_count
146
+
147
+
148
+ def get_visible_textboxes(model_count):
149
+ """
150
+ Ritorna la visibilità dei 5 campi modello in base a model_count.
151
+ """
152
+ visibility = []
153
+ for i in range(1, MAX_MODELS + 1):
154
+ visibility.append(gr.update(visible=(i <= model_count)))
155
+ return visibility
156
+
157
+
158
+ def run_benchmark_ui(
159
+ model_1,
160
+ model_2,
161
+ model_3,
162
+ model_4,
163
+ model_5,
164
+ model_count,
165
+ num_samples,
166
+ ):
167
+ """
168
+ Funzione chiamata dal pulsante 'Esegui benchmark'.
169
+ Raccoglie i nomi dei modelli, esegue il benchmark e ritorna:
170
+ - tabella risultati
171
+ - log testuale
172
+ """
173
+ # Raccogli i modelli attivi
174
+ model_names = []
175
+ all_models = [model_1, model_2, model_3, model_4, model_5]
176
+ for i in range(model_count):
177
+ name = (all_models[i] or "").strip()
178
+ if name:
179
+ model_names.append(name)
180
+
181
+ if len(model_names) < 2:
182
+ return (
183
+ pd.DataFrame(),
184
+ "Devi specificare almeno due modelli validi."
185
+ )
186
+
187
+ results = []
188
+ logs = []
189
+
190
+ logs.append(f"Avvio benchmark su BoolQ con {num_samples} esempi...")
191
+ logs.append(f"Modelli: {', '.join(model_names)}")
192
+ logs.append("Device: " + get_device())
193
+ logs.append("====================================")
194
+
195
+ for name in model_names:
196
+ logs.append(f"\n[MODELLO] {name}")
197
+ try:
198
+ res = evaluate_model_on_boolq(name, num_samples=num_samples)
199
+ results.append(res)
200
+ logs.append(
201
+ f" - Esempi valutati: {res['num_samples']}\n"
202
+ f" - Accuracy: {res['accuracy']:.3f}\n"
203
+ f" - Tempo medio per esempio (s): "
204
+ f"{res['avg_time_per_sample_sec']:.3f}" if res['avg_time_per_sample_sec'] is not None else "N/A"
205
+ )
206
+ except Exception as e:
207
+ logs.append(f" ERRORE: {e}")
208
+
209
+ if results:
210
+ df = pd.DataFrame(results)
211
+ # Ordina per accuracy decrescente
212
+ df = df.sort_values(by="accuracy", ascending=False)
213
+ else:
214
+ df = pd.DataFrame()
215
+
216
+ log_text = "\n".join(str(l) for l in logs)
217
+ return df, log_text
218
+
219
+
220
+ # =========================
221
+ # Costruzione interfaccia Gradio
222
+ # =========================
223
+
224
+ with gr.Blocks(title="LLM Benchmark Space - BoolQ") as demo:
225
+ gr.Markdown(
226
+ """
227
+ # 🔍 LLM Benchmark Space (BoolQ)
228
+
229
+ Inserisci i nomi dei modelli Hugging Face (es. `meta-llama/Meta-Llama-3-8B-Instruct`)
230
+ e confrontali su un subset del dataset **BoolQ** (domande sì/no).
231
+
232
+ - Minimo **2 modelli**
233
+ - Puoi aggiungere fino a **5 modelli** con il pulsante **"+ Aggiungi modello"**
234
+ - Output: tabella con **accuracy**, numero di esempi e tempi
235
+ """
236
+ )
237
+
238
+ with gr.Row():
239
+ with gr.Column():
240
+ model_count_state = gr.State(value=2)
241
+
242
+ model_1 = gr.Textbox(
243
+ label="Modello 1",
244
+ placeholder="es. meta-llama/Meta-Llama-3-8B-Instruct",
245
+ value="",
246
+ visible=True,
247
+ )
248
+ model_2 = gr.Textbox(
249
+ label="Modello 2",
250
+ placeholder="es. mistralai/Mistral-7B-Instruct-v0.3",
251
+ value="",
252
+ visible=True,
253
+ )
254
+ model_3 = gr.Textbox(
255
+ label="Modello 3",
256
+ placeholder="Modello opzionale",
257
+ value="",
258
+ visible=False,
259
+ )
260
+ model_4 = gr.Textbox(
261
+ label="Modello 4",
262
+ placeholder="Modello opzionale",
263
+ value="",
264
+ visible=False,
265
+ )
266
+ model_5 = gr.Textbox(
267
+ label="Modello 5",
268
+ placeholder="Modello opzionale",
269
+ value="",
270
+ visible=False,
271
+ )
272
+
273
+ add_button = gr.Button("+ Aggiungi modello")
274
+
275
+ num_samples = gr.Slider(
276
+ minimum=10,
277
+ maximum=200,
278
+ step=10,
279
+ value=DEFAULT_NUM_SAMPLES,
280
+ label="Numero di esempi BoolQ da usare",
281
+ )
282
+
283
+ run_button = gr.Button("🚀 Esegui benchmark", variant="primary")
284
+
285
+ with gr.Column():
286
+ results_df = gr.Dataframe(
287
+ headers=[
288
+ "model_name",
289
+ "num_samples",
290
+ "accuracy",
291
+ "avg_time_per_sample_sec",
292
+ "total_time_sec",
293
+ ],
294
+ label="Risultati benchmark",
295
+ interactive=False,
296
+ )
297
+ logs_box = gr.Textbox(
298
+ label="Log esecuzione",
299
+ lines=20,
300
+ interactive=False,
301
+ )
302
+
303
+ # Logica pulsante "+ Aggiungi modello"
304
+ def on_add_model(model_count):
305
+ new_count = add_model_field(model_count)
306
+ visibility_updates = get_visible_textboxes(new_count)
307
+ return [new_count] + visibility_updates
308
+
309
+ add_button.click(
310
+ fn=on_add_model,
311
+ inputs=[model_count_state],
312
+ outputs=[model_count_state, model_1, model_2, model_3, model_4, model_5],
313
+ )
314
+
315
+ # Logica pulsante "Esegui benchmark"
316
+ run_button.click(
317
+ fn=run_benchmark_ui,
318
+ inputs=[
319
+ model_1,
320
+ model_2,
321
+ model_3,
322
+ model_4,
323
+ model_5,
324
+ model_count_state,
325
+ num_samples,
326
+ ],
327
+ outputs=[results_df, logs_box],
328
+ )
329
+
330
+
331
+ if __name__ == "__main__":
332
+ demo.launch()