ligaments-dev committed
Commit 18c6ce2 · verified · 1 Parent(s): ce462eb

Update evaluation script with new token

Files changed (1)
  1. model_evaluation.py +366 -0
model_evaluation.py ADDED
@@ -0,0 +1,366 @@
+ # /// script
+ # dependencies = [
+ #     "transformers>=4.40.0",
+ #     "datasets>=2.18.0",
+ #     "torch>=2.0.0",
+ #     "rouge-score>=0.1.2",
+ #     "evaluate>=0.4.0",
+ #     "numpy>=1.24.0",
+ #     "pandas>=2.0.0",
+ #     "scikit-learn>=1.3.0",
+ #     "huggingface-hub>=0.20.0",
+ #     "accelerate>=0.27.0",
+ #     "trackio"
+ # ]
+ # ///
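+
+ # NOTE: the block above is PEP 723 inline script metadata, so a PEP 723-aware
+ # runner (for example `uv run model_evaluation.py`) can resolve the dependencies
+ # automatically; with plain `python` they must already be installed. The upload
+ # step near the end of the script also expects an HF_TOKEN environment variable
+ # (or a cached `huggingface-cli login`).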
+
+ import os
+ import json
+ import pandas as pd
+ import numpy as np
+ from datetime import datetime
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ from rouge_score import rouge_scorer
+ from sklearn.metrics import f1_score
+ import re
+ import trackio
+ from huggingface_hub import HfApi, upload_file
+ import torch
+
+ def normalize_text(text):
+     """Normalize text for comparison"""
+     if not isinstance(text, str):
+         return ""
+     # Remove extra whitespace and normalize
+     text = re.sub(r'\s+', ' ', text.strip())
+     return text.lower()
+
+ def compute_exact_match(pred, true):
+     """Compute exact match score"""
+     return float(normalize_text(pred) == normalize_text(true))
+
+ def compute_f1_score(pred, true):
+     """Compute token-level F1 score"""
+     pred_tokens = normalize_text(pred).split()
+     true_tokens = normalize_text(true).split()
+
+     if len(pred_tokens) == 0 and len(true_tokens) == 0:
+         return 1.0
+     if len(pred_tokens) == 0 or len(true_tokens) == 0:
+         return 0.0
+
+     # Convert to sets for intersection
+     pred_set = set(pred_tokens)
+     true_set = set(true_tokens)
+
+     if len(pred_set) == 0 and len(true_set) == 0:
+         return 1.0
+
+     intersection = pred_set.intersection(true_set)
+     precision = len(intersection) / len(pred_set) if pred_set else 0
+     recall = len(intersection) / len(true_set) if true_set else 0
+
+     if precision + recall == 0:
+         return 0.0
+
+     f1 = 2 * (precision * recall) / (precision + recall)
+     return f1
+
+ def compute_rouge_l(pred, true):
+     """Compute ROUGE-L score"""
+     scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+     scores = scorer.score(normalize_text(true), normalize_text(pred))
+     return scores['rougeL'].fmeasure
+
+ def evaluate_model():
+     # Initialize Trackio
+     trackio.init()
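+     # NOTE: depending on the installed trackio version, init() may expect a
+     # project name, e.g. trackio.init(project="sec-eval"); the name here is
+     # illustrative only.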
+
+     print("🚀 Starting model evaluation...")
+
+     # Configuration
+     model_name = "ligaments-enterprise/llama3.2-1b-instruct-sec-finetuned"
+     dataset_name = "ligaments-enterprise/sec-data"
+
+     print(f"📊 Loading dataset: {dataset_name}")
+     try:
+         # Try to load the dataset
+         dataset = load_dataset(dataset_name, split="train")
+         print(f"✅ Dataset loaded successfully. Size: {len(dataset)}")
+     except Exception as e:
+         print(f"❌ Error loading dataset: {e}")
+         # Try different splits
+         try:
+             dataset = load_dataset(dataset_name)
+             if isinstance(dataset, dict):
+                 # Use the first available split
+                 split_name = list(dataset.keys())[0]
+                 dataset = dataset[split_name]
+                 print(f"✅ Using split '{split_name}'. Size: {len(dataset)}")
+         except Exception as e2:
+             print(f"❌ Failed to load dataset: {e2}")
+             return
+
+     # Inspect dataset structure
+     print(f"📋 Dataset columns: {dataset.column_names}")
+     print(f"📋 First example: {dataset[0]}")
+
+     # Determine input/output columns
+     possible_input_cols = ['prompt', 'input', 'question', 'instruction', 'text']
+     possible_output_cols = ['response', 'output', 'answer', 'completion', 'target']
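+     # The column names above are heuristics for common SFT dataset schemas; if
+     # none match, the 'messages' branch below or the error check further down
+     # handles it.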
+
+     input_col = None
+     output_col = None
+
+     for col in possible_input_cols:
+         if col in dataset.column_names:
+             input_col = col
+             break
+
+     for col in possible_output_cols:
+         if col in dataset.column_names:
+             output_col = col
+             break
+
+     # Handle messages format
+     if 'messages' in dataset.column_names:
+         print("📋 Detected messages format, extracting prompts and responses...")
+         def extract_from_messages(example):
+             messages = example['messages']
+             if isinstance(messages, list) and len(messages) >= 2:
+                 # Find the last user message and assistant response
+                 user_msg = None
+                 assistant_msg = None
+                 for msg in messages:
+                     if msg.get('role') == 'user':
+                         user_msg = msg.get('content', '')
+                     elif msg.get('role') == 'assistant':
+                         assistant_msg = msg.get('content', '')
+
+                 return {
+                     'input_text': user_msg or '',
+                     'target_text': assistant_msg or ''
+                 }
+             return {'input_text': '', 'target_text': ''}
+
+         dataset = dataset.map(extract_from_messages)
+         input_col = 'input_text'
+         output_col = 'target_text'
+
+     if not input_col or not output_col:
+         print(f"❌ Could not identify input/output columns. Available: {dataset.column_names}")
+         return
+
+     print(f"✅ Using input column: {input_col}, output column: {output_col}")
+
+     print(f"🤖 Loading model: {model_name}")
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.float16,
+             device_map="auto",
+             trust_remote_code=True
+         )
+
+         # Set pad token if not set
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         print("✅ Model loaded successfully")
+     except Exception as e:
+         print(f"❌ Error loading model: {e}")
+         return
+
+     # Create text generation pipeline
+     generator = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         torch_dtype=torch.float16,
+         device_map="auto"
+     )
+
+     # Limit evaluation to reasonable size for demonstration
+     eval_size = min(100, len(dataset))
+     eval_dataset = dataset.select(range(eval_size))
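+     # NOTE: this evaluates the first `eval_size` rows of the loaded split (train
+     # by default), so the scores may reflect data the model was fine-tuned on
+     # unless a separate held-out split exists.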
188
+ print(f"๐Ÿ“Š Evaluating on {eval_size} samples...")
189
+
190
+ results = []
191
+
192
+ for i, example in enumerate(eval_dataset):
193
+ if i % 10 == 0:
194
+ print(f"๐Ÿ“ˆ Processing sample {i+1}/{eval_size}")
195
+
196
+ input_text = example[input_col]
197
+ target_text = example[output_col]
198
+
199
+ if not input_text or not target_text:
200
+ continue
201
+
202
+ # Generate prediction
203
+ try:
204
+ # Format prompt appropriately
205
+ if not input_text.strip().endswith(('?', '.', '!', ':')):
206
+ formatted_prompt = f"{input_text.strip()}:"
207
+ else:
208
+ formatted_prompt = input_text.strip()
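+             # NOTE: the prompt is sent as raw text; for an instruct-tuned model it
+             # may be closer to the fine-tuning format to build the prompt with
+             # tokenizer.apply_chat_template([{"role": "user", "content": ...}],
+             # add_generation_prompt=True, tokenize=False) instead.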
+
+             generated = generator(
+                 formatted_prompt,
+                 max_new_tokens=256,
+                 do_sample=False,  # Deterministic (greedy) decoding for evaluation; temperature dropped since it is ignored when do_sample=False
+                 pad_token_id=tokenizer.eos_token_id,
+                 return_full_text=False
+             )
+
+             prediction = generated[0]['generated_text'].strip()
+
+             # Compute metrics
+             exact_match = compute_exact_match(prediction, target_text)
+             f1 = compute_f1_score(prediction, target_text)
+             rouge_l = compute_rouge_l(prediction, target_text)
+
+             # Error analysis
+             error_type = "correct" if exact_match == 1.0 else "incorrect"
+             if exact_match == 0 and f1 > 0.5:
+                 error_type = "partial_match"
+             elif exact_match == 0 and rouge_l > 0.3:
+                 error_type = "semantic_similarity"
+             elif len(prediction.split()) > len(target_text.split()) * 2:
+                 error_type = "too_verbose"
+             elif len(prediction.split()) < len(target_text.split()) * 0.5:
+                 error_type = "too_brief"
+
+             result = {
+                 'sample_id': i,
+                 'input': input_text,
+                 'target': target_text,
+                 'prediction': prediction,
+                 'exact_match': exact_match,
+                 'f1_score': f1,
+                 'rouge_l': rouge_l,
+                 'error_type': error_type,
+                 'input_length': len(input_text.split()),
+                 'target_length': len(target_text.split()),
+                 'prediction_length': len(prediction.split())
+             }
+
+             results.append(result)
+
+         except Exception as e:
+             print(f"⚠️ Error processing sample {i}: {e}")
+             continue
+
+     if not results:
+         print("❌ No results generated")
+         return
+
+     # Compute summary statistics
+     df_results = pd.DataFrame(results)
+
+     summary_metrics = {
+         'evaluation_timestamp': datetime.now().isoformat(),
+         'model_name': model_name,
+         'dataset_name': dataset_name,
+         'total_samples': len(results),
+         'exact_match_avg': df_results['exact_match'].mean(),
+         'f1_score_avg': df_results['f1_score'].mean(),
+         'rouge_l_avg': df_results['rouge_l'].mean(),
+         'exact_match_std': df_results['exact_match'].std(),
+         'f1_score_std': df_results['f1_score'].std(),
+         'rouge_l_std': df_results['rouge_l'].std(),
+         'perfect_matches': int(df_results['exact_match'].sum()),
+         'perfect_match_rate': df_results['exact_match'].mean()
+     }
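+     # perfect_match_rate duplicates exact_match_avg (both are the mean of the 0/1
+     # exact-match column); it is kept under its own key in the saved summary.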
+
+     # Error analysis summary
+     error_analysis = df_results['error_type'].value_counts().to_dict()
+     summary_metrics['error_breakdown'] = error_analysis
+
+     # Performance by length buckets
+     df_results['target_length_bucket'] = pd.cut(
+         df_results['target_length'],
+         bins=[0, 10, 25, 50, 100, float('inf')],
+         labels=['very_short', 'short', 'medium', 'long', 'very_long']
+     )
+
+     length_performance = df_results.groupby('target_length_bucket')[['exact_match', 'f1_score', 'rouge_l']].mean().to_dict()
+     summary_metrics['performance_by_length'] = length_performance
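+     # With the current pandas default for categorical groupby (observed=False),
+     # empty length buckets show up as NaN means in performance_by_length.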
+
+     print("\n📊 EVALUATION RESULTS:")
+     print(f"Total Samples: {summary_metrics['total_samples']}")
+     print(f"Exact Match: {summary_metrics['exact_match_avg']:.4f} ± {summary_metrics['exact_match_std']:.4f}")
+     print(f"F1 Score: {summary_metrics['f1_score_avg']:.4f} ± {summary_metrics['f1_score_std']:.4f}")
+     print(f"ROUGE-L: {summary_metrics['rouge_l_avg']:.4f} ± {summary_metrics['rouge_l_std']:.4f}")
+     print(f"Perfect Matches: {summary_metrics['perfect_matches']}/{summary_metrics['total_samples']} ({summary_metrics['perfect_match_rate']:.2%})")
+
+     print("\n🔍 Error Breakdown:")
+     for error_type, count in error_analysis.items():
+         print(f"  {error_type}: {count} ({count/len(results):.2%})")
+
+     # Save results locally first
+     os.makedirs('eval_results', exist_ok=True)
+
+     # Save detailed results
+     df_results.to_csv('eval_results/detailed_results.csv', index=False)
+
+     # Save summary metrics
+     with open('eval_results/summary_metrics.json', 'w') as f:
+         json.dump(summary_metrics, f, indent=2, default=str)
+
+     # Save top errors for analysis
+     worst_samples = df_results.nsmallest(10, 'f1_score')[['sample_id', 'input', 'target', 'prediction', 'f1_score', 'error_type']]
+     worst_samples.to_csv('eval_results/worst_predictions.csv', index=False)
+
+     # Save best samples
+     best_samples = df_results.nlargest(10, 'f1_score')[['sample_id', 'input', 'target', 'prediction', 'f1_score', 'error_type']]
+     best_samples.to_csv('eval_results/best_predictions.csv', index=False)
+
+     print("\n💾 Results saved locally to eval_results/")
+
+     # Upload results to model repository
+     try:
+         print("🚀 Uploading results to model repository...")
+         api = HfApi()
+
+         # Upload all result files
+         files_to_upload = [
+             ('eval_results/summary_metrics.json', 'eval_results/summary_metrics.json'),
+             ('eval_results/detailed_results.csv', 'eval_results/detailed_results.csv'),
+             ('eval_results/worst_predictions.csv', 'eval_results/worst_predictions.csv'),
+             ('eval_results/best_predictions.csv', 'eval_results/best_predictions.csv')
+         ]
+
+         for local_path, repo_path in files_to_upload:
+             api.upload_file(
+                 path_or_fileobj=local_path,
+                 path_in_repo=repo_path,
+                 repo_id=model_name,
+                 commit_message=f"Add evaluation results - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
+                 token=os.getenv('HF_TOKEN')
+             )
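+             # If HF_TOKEN is unset, token=None is passed above and huggingface_hub
+             # falls back to any cached login credentials; without either, the
+             # upload raises an authentication error caught by the except below.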
+             print(f"✅ Uploaded {repo_path}")
+
+         print(f"✅ All evaluation results uploaded to {model_name}")
+
+         # Log to Trackio
+         trackio.log({
+             "exact_match": summary_metrics['exact_match_avg'],
+             "f1_score": summary_metrics['f1_score_avg'],
+             "rouge_l": summary_metrics['rouge_l_avg'],
+             "perfect_match_rate": summary_metrics['perfect_match_rate'],
+             "total_samples": summary_metrics['total_samples']
+         })
+
+     except Exception as e:
+         print(f"⚠️ Warning: Could not upload to repository: {e}")
+         print("💾 Results are saved locally in eval_results/ directory")
+
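+     # Close the Trackio run so buffered metrics are flushed; trackio mirrors the
+     # wandb-style init()/log()/finish() API (finish() assumed available in the
+     # installed version).
+     trackio.finish()
+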
+     print("\n🎉 Evaluation completed successfully!")
+     return summary_metrics
+
+ if __name__ == "__main__":
+     evaluate_model()