airevartis committed on
Commit
f55049d
·
verified ·
1 Parent(s): eb6e3d8

Upload post_finetune_evaluation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. post_finetune_evaluation.py +412 -0
post_finetune_evaluation.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Post fine-tuning evaluation on Hugging Face infrastructure
4
+ """
5
+ import torch
6
+ import json
7
+ import os
8
+ from transformers import (
9
+ AutoTokenizer,
10
+ AutoModelForCausalLM,
11
+ pipeline,
12
+ BitsAndBytesConfig
13
+ )
14
+ from datasets import load_dataset
15
+ import numpy as np
16
+ from typing import Dict, List, Tuple
17
+ import logging
18
+ import re
19
+ from pathlib import Path
20
+
21
+ # Set up logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class HFPostFineTuneEvaluator:
    """Evaluate causal language models on a multiple-choice QA test set.

    For every configured model the evaluator first tries a fine-tuned
    checkpoint named ``medical-<model>-finetuned`` and falls back to the base
    checkpoint listed in ``self.models`` when that cannot be loaded. Each test
    question is turned into a prompt, answered with greedy decoding, and the
    extracted option letter is compared with the gold answer.
    """

    # Letters used to label options in prompts and to report predictions.
    _OPTION_LETTERS = "ABCDE"

    # Tried in order by extract_answer(); each captures exactly one option
    # letter. Word boundaries keep capitals inside ordinary words ("Because",
    # "Dr.") from being mistaken for an answer.
    _ANSWER_PATTERNS = (
        # Every prompt ends on an answer cue, so the completion normally
        # starts with the letter itself: " B", "B. Metformin", "(C) ...".
        re.compile(r"^\s*\(?([A-E])\b"),
        # Explicit statements: "The correct answer is B", "Answer: (C)".
        re.compile(r"(?i:answer)(?:\s+is)?\s*:?\s*\(?([A-E])\b"),
        # A standalone letter used as a list label: "B." or "B)".
        re.compile(r"\b([A-E])[.)]"),
        # Last resort: any standalone capital letter A-E.
        re.compile(r"\b([A-E])\b"),
    )

    def __init__(self, finetuned_namespace: str = ""):
        """Pick the compute device and set up model/quantization configs.

        Args:
            finetuned_namespace: Optional Hub user or organisation that owns
                the ``medical-<model>-finetuned`` repositories. When empty
                (the default) the bare repository name is used, as before.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        self.finetuned_namespace = finetuned_namespace

        # Short model key -> base checkpoint on the Hugging Face Hub.
        # NOTE(review): the key "qwen3_7b" points at a Qwen2.5 checkpoint;
        # the key is kept because it is also used to look up baseline results.
        self.models = {
            "biomistral_7b": "BioMistral/BioMistral-7B",
            "qwen3_7b": "Qwen/Qwen2.5-7B-Instruct",
            "meditron_7b": "epfl-llm/meditron-7b",
            "internist_7b": "internistai/internist-7b",
        }

        # 4-bit NF4 quantization; only applied when running on CUDA.
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    def _load_model_and_tokenizer(self, repo_id: str) -> Tuple:
        """Load ``(model, tokenizer)`` for *repo_id* with device-dependent settings."""
        on_gpu = self.device == "cuda"

        # NOTE(review): trust_remote_code=True executes code shipped with the
        # repository; only point this at repositories you trust.
        tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            quantization_config=self.quantization_config if on_gpu else None,
            device_map="auto" if on_gpu else None,
            torch_dtype=torch.float16 if on_gpu else torch.float32,
            trust_remote_code=True,
        )
        return model, tokenizer

    def load_finetuned_model(self, model_name: str) -> Tuple:
        """Load the fine-tuned checkpoint for *model_name*, or its base model.

        Args:
            model_name: Key into ``self.models``.

        Returns:
            ``(model, tokenizer, is_finetuned)``. ``is_finetuned`` is True only
            when the fine-tuned checkpoint was loaded. If neither checkpoint
            can be loaded (including an unknown *model_name*) the result is
            ``(None, None, False)``.
        """
        logger.info(f"Loading fine-tuned model: {model_name}")

        repo_name = f"medical-{model_name}-finetuned"
        if self.finetuned_namespace:
            finetuned_repo = f"{self.finetuned_namespace}/{repo_name}"
        else:
            finetuned_repo = repo_name

        try:
            model, tokenizer = self._load_model_and_tokenizer(finetuned_repo)
        except Exception as e:
            logger.warning(f"Could not load fine-tuned model from HF Hub: {e}")
            logger.info(f"Loading base model {model_name} instead")
        else:
            logger.info(f"Successfully loaded fine-tuned {model_name}")
            return model, tokenizer, True

        try:
            model, tokenizer = self._load_model_and_tokenizer(self.models[model_name])
        except Exception as e:
            logger.error(f"Failed to load {model_name}: {e}")
            return None, None, False

        return model, tokenizer, False

    @staticmethod
    def _option_texts(options) -> List[str]:
        """Return the option texts of *options* as a flat list of strings.

        Accepts a plain sequence of texts, a sequence of records that carry
        the text under a ``"value"`` key, or a mapping whose values are the
        texts (ordered by sorted key).
        """
        if isinstance(options, dict):
            return [str(options[key]) for key in sorted(options)]

        texts = []
        for option in options:
            if isinstance(option, dict) and "value" in option:
                texts.append(str(option["value"]))
            else:
                texts.append(str(option))
        return texts

    @classmethod
    def _answer_letter(cls, answer, option_texts: List[str]) -> str:
        """Convert a gold answer into an option letter when that is possible.

        A single letter is upper-cased, an answer equal to one of
        *option_texts* becomes that option's letter, and an integer is treated
        as an index into the options (assumed 0-based -- TODO confirm against
        the dataset in use). Anything else is returned as a stripped string.
        """
        letters = cls._OPTION_LETTERS

        if isinstance(answer, int) and not isinstance(answer, bool):
            return letters[answer] if 0 <= answer < len(letters) else str(answer)

        text = str(answer).strip()
        if len(text) == 1 and text.upper() in letters:
            return text.upper()

        stripped_options = [option.strip() for option in option_texts]
        if text in stripped_options:
            index = stripped_options.index(text)
            if index < len(letters):
                return letters[index]

        return text

    @classmethod
    def _option_accuracies(cls, results) -> Dict:
        """Compute accuracy per gold option letter in one pass over *results*.

        Records without a ``true_answer`` (failed examples) are ignored.
        Letters that never occur as a gold answer get ``0``.
        """
        totals = dict.fromkeys(cls._OPTION_LETTERS, 0)
        hits = dict.fromkeys(cls._OPTION_LETTERS, 0)

        for record in results:
            letter = record.get('true_answer')
            if letter in totals:
                totals[letter] += 1
                if record.get('is_correct', False):
                    hits[letter] += 1

        return {
            letter: hits[letter] / totals[letter] if totals[letter] > 0 else 0
            for letter in cls._OPTION_LETTERS
        }

    def create_prompt(self, question: str, options: List[str], model_name: str) -> str:
        """Build the model-specific prompt for one multiple-choice question.

        Args:
            question: Question text.
            options: Answer options; see ``_option_texts`` for accepted shapes.
            model_name: Model key; selects the chat template (Qwen ChatML,
                Mistral ``[INST]``, or a generic ``Question/Answer`` layout).

        Returns:
            The prompt string, ending on an answer cue.
        """
        options_text = "\n".join(
            f"{chr(65 + i)}. {text}"
            for i, text in enumerate(self._option_texts(options))
        )
        body = f"{question}\n\n{options_text}\n\n"
        instruction = "Please select the correct answer (A, B, C, D, or E)."
        lowered = model_name.lower()

        if "qwen" in lowered:
            return (
                f"<|im_start|>user\n{body}{instruction}<|im_end|>\n"
                "<|im_start|>assistant\nThe correct answer is"
            )

        # "biomistral" contains "mistral", so one substring test covers both.
        if "mistral" in lowered:
            return f"<s>[INST] {body}{instruction} [/INST] The correct answer is"

        # Generic format
        return f"Question: {body}Answer:"

    def extract_answer(self, text: str) -> str:
        """Extract the chosen option letter from generated text.

        Returns:
            One of ``"A"``-``"E"``, or ``""`` when no answer can be found. An
            unparseable output is deliberately not mapped to a default letter:
            guessing would count as a correct prediction whenever the gold
            answer happens to be that letter.
        """
        if not text:
            return ""

        for pattern in self._ANSWER_PATTERNS:
            match = pattern.search(text)
            if match:
                return match.group(1)

        return ""

    def evaluate_model(self, model_name: str, test_dataset) -> Dict:
        """Evaluate a single model on the test dataset.

        Args:
            model_name: Key into ``self.models``.
            test_dataset: Sized iterable of examples with ``question``,
                ``options`` and ``answer`` entries.

        Returns:
            A metrics dict with ``model_name``, ``is_finetuned``,
            ``total_examples``, ``correct_predictions``, ``accuracy`` and
            ``option_accuracies``; or ``{"error": ...}`` if the model could
            not be loaded. Examples that raise are logged and counted as
            incorrect.
        """
        logger.info(f"Evaluating {model_name}")

        model, tokenizer, is_finetuned = self.load_finetuned_model(model_name)
        if model is None or tokenizer is None:
            return {"error": f"Failed to load {model_name}"}

        # Greedy decoding (do_sample=False). No temperature is passed because
        # it has no effect when sampling is disabled.
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=50,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

        results = []
        correct = 0
        total = len(test_dataset)

        logger.info(f"Running evaluation on {total} examples")

        for i, example in enumerate(test_dataset):
            try:
                prompt = self.create_prompt(
                    example['question'],
                    example['options'],
                    model_name,
                )

                # return_full_text=False: keep only the completion, so the
                # prompt's own option labels are not parsed as the answer.
                response = generator(prompt, return_full_text=False)
                generated_text = response[0]['generated_text']

                predicted_answer = self.extract_answer(generated_text)
                # The prediction is a letter, so the gold answer must be
                # brought into the same form before comparing.
                true_answer = self._answer_letter(
                    example['answer'],
                    self._option_texts(example['options']),
                )

                is_correct = predicted_answer == true_answer
                if is_correct:
                    correct += 1

                results.append({
                    'question_id': i,
                    'question': example['question'],
                    'options': example['options'],
                    'true_answer': true_answer,
                    'predicted_answer': predicted_answer,
                    'generated_text': generated_text,
                    'is_correct': is_correct,
                })

            except Exception as e:
                logger.error(f"Error processing example {i}: {e}")
                results.append({
                    'question_id': i,
                    'error': str(e),
                    'is_correct': False,
                })

        accuracy = correct / total if total > 0 else 0

        metrics = {
            'model_name': f"{model_name}_finetuned" if is_finetuned else f"{model_name}_base",
            'is_finetuned': is_finetuned,
            'total_examples': total,
            'correct_predictions': correct,
            'accuracy': accuracy,
            'option_accuracies': self._option_accuracies(results),
        }

        logger.info(f"{model_name} ({'finetuned' if is_finetuned else 'base'}) - Accuracy: {accuracy:.4f}")

        # Drop the references and release cached GPU memory before the next
        # model is loaded.
        del model, tokenizer, generator
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return metrics

    def run_evaluation(self, test_dataset) -> Dict:
        """Evaluate every model in ``self.models``; returns key -> metrics."""
        results = {}

        for model_name in self.models:
            logger.info(f"Starting evaluation for {model_name}")
            results[model_name] = self.evaluate_model(model_name, test_dataset)

        return results

    def compare_with_baseline(self, post_results: Dict, baseline_file: str = "/tmp/zero_shot_results.json") -> Dict:
        """Compare post-fine-tuning accuracy with stored baseline results.

        Args:
            post_results: Mapping of model key -> metrics from this run.
            baseline_file: JSON file mapping model key -> metrics containing
                an ``accuracy`` entry.

        Returns:
            Mapping of model key -> ``baseline_accuracy``, ``post_accuracy``,
            ``improvement`` (absolute) and ``relative_improvement_pct``.
            Models that errored in either run, or have no usable baseline
            entry, are omitted. Returns ``{}`` when the baseline file is
            missing or cannot be parsed.
        """
        try:
            with open(baseline_file, 'r', encoding='utf-8') as f:
                baseline_results = json.load(f)
        except FileNotFoundError:
            logger.warning("Baseline results not found, skipping comparison")
            return {}
        except json.JSONDecodeError as e:
            logger.warning(f"Baseline results in {baseline_file} are not valid JSON, skipping comparison: {e}")
            return {}

        if not isinstance(baseline_results, dict):
            logger.warning(f"Baseline results in {baseline_file} have an unexpected layout, skipping comparison")
            return {}

        comparison = {}

        for model_name, post_result in post_results.items():
            if 'error' in post_result:
                continue

            baseline_key = model_name.replace('_finetuned', '')
            baseline_entry = baseline_results.get(baseline_key)
            if not isinstance(baseline_entry, dict):
                continue
            if 'error' in baseline_entry or 'accuracy' not in baseline_entry:
                continue

            baseline_accuracy = baseline_entry['accuracy']
            post_accuracy = post_result['accuracy']

            improvement = post_accuracy - baseline_accuracy
            # A zero baseline has no meaningful relative change; report 0.
            relative_improvement = (improvement / baseline_accuracy * 100) if baseline_accuracy > 0 else 0

            comparison[model_name] = {
                'baseline_accuracy': baseline_accuracy,
                'post_accuracy': post_accuracy,
                'improvement': improvement,
                'relative_improvement_pct': relative_improvement,
            }

        return comparison

    def save_results(self, results: Dict, comparison: Dict, output_path: str = "/tmp/post_finetune_results.json"):
        """Write the metrics and the baseline comparison to *output_path* as JSON.

        Models whose result carries an ``error`` are left out. The parent
        directory is created when missing.

        Returns:
            *output_path*, unchanged.
        """
        metric_keys = (
            'model_name',
            'is_finetuned',
            'total_examples',
            'correct_predictions',
            'accuracy',
            'option_accuracies',
        )
        serializable_results = {
            model_name: {key: result[key] for key in metric_keys}
            for model_name, result in results.items()
            if 'error' not in result
        }

        output_data = {
            'post_finetune_results': serializable_results,
            'comparison_with_baseline': comparison,
        }

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2)

        logger.info(f"Results saved to {output_path}")
        return output_path
304
+
305
+
306
def _normalize_example(example):
    """Map one raw dataset row onto ``question`` / ``options`` / ``answer``.

    Several column layouts are accepted:

    * question text under ``question``, ``text`` or ``input``;
    * options under ``options`` or ``choices``, or spread over
      ``option_0..4`` / ``choice_0..4``;
    * the gold answer under ``answer_idx`` (preferred), ``answer``,
      ``label`` or ``output``.

    Options are flattened to a list of strings: records carrying the text
    under a ``"value"`` key (as in the bigbio/med_qa source schema) and
    letter -> text mappings are both unpacked. The answer is returned as an
    option letter whenever it can be derived, because the evaluator compares
    it with a predicted letter.
    """
    letters = "ABCDE"

    if 'question' in example:
        question = example['question']
    elif 'text' in example:
        question = example['text']
    else:
        question = example['input']

    if 'options' in example:
        raw_options = example['options']
    elif 'choices' in example:
        raw_options = example['choices']
    else:
        raw_options = []
        for i in range(5):
            key = f'option_{i}' if f'option_{i}' in example else f'choice_{i}'
            if key in example:
                raw_options.append(example[key])

    if isinstance(raw_options, dict):
        options = [str(raw_options[key]) for key in sorted(raw_options)]
    else:
        options = [
            str(option['value']) if isinstance(option, dict) and 'value' in option else str(option)
            for option in raw_options
        ]

    # Prefer an explicit answer letter: where "answer_idx" exists, "answer"
    # holds the option text rather than the letter.
    if 'answer_idx' in example and example['answer_idx'] not in (None, ''):
        answer = example['answer_idx']
    elif 'answer' in example:
        answer = example['answer']
    elif 'label' in example:
        answer = example['label']
    else:
        answer = example['output']

    # Some layouts wrap a single answer in a list.
    if isinstance(answer, (list, tuple)) and len(answer) == 1:
        answer = answer[0]

    if isinstance(answer, int) and not isinstance(answer, bool):
        # Integer labels are assumed to be 0-based option indices -- TODO
        # confirm against the dataset in use.
        answer = letters[answer] if 0 <= answer < len(letters) else str(answer)
    else:
        answer = str(answer).strip()
        stripped_options = [option.strip() for option in options]
        if len(answer) == 1 and answer.upper() in letters:
            answer = answer.upper()
        elif answer in stripped_options and stripped_options.index(answer) < len(letters):
            answer = letters[stripped_options.index(answer)]

    return {
        'question': question,
        'options': options,
        'answer': answer,
    }


def main():
    """Main function for HF post-fine-tuning evaluation job.

    Loads the MedQA test split, evaluates every configured model, compares
    the results with the stored baseline, writes them to disk, prints a
    summary, and finally tries (best effort) to upload the results file to a
    dataset repository on the Hugging Face Hub.
    """
    logger.info("Starting post fine-tuning evaluation on Hugging Face infrastructure")

    # Load MedQA dataset, trying the known names in order.
    logger.info("Loading MedQA dataset...")
    dataset = None
    for dataset_name in ("bigbio/med_qa", "medqa"):
        try:
            dataset = load_dataset(dataset_name)
            break
        except Exception as e:
            logger.warning(f"Could not load dataset {dataset_name}: {e}")

    if dataset is None:
        logger.error("Could not load MedQA dataset")
        return

    if 'test' not in dataset:
        logger.error("MedQA dataset has no 'test' split")
        return

    # Drop the raw columns so only the three normalized ones remain; this
    # also lets 'options' and 'answer' change type without clashing with the
    # original column schema.
    raw_test = dataset['test']
    test_dataset = raw_test.map(_normalize_example, remove_columns=raw_test.column_names)
    logger.info(f"Processed {len(test_dataset)} test examples")

    # Initialize evaluator
    evaluator = HFPostFineTuneEvaluator()

    # Run evaluation
    logger.info("Starting post fine-tuning evaluation...")
    results = evaluator.run_evaluation(test_dataset)

    # Compare with baseline
    comparison = evaluator.compare_with_baseline(results)

    # Save results
    output_path = evaluator.save_results(results, comparison)

    # Print summary
    print("\n" + "="*60)
    print("POST FINE-TUNING EVALUATION RESULTS")
    print("="*60)

    for model_name, result in results.items():
        if 'error' not in result:
            status = "finetuned" if result['is_finetuned'] else "base"
            print(f"{model_name} ({status}): {result['accuracy']:.4f} accuracy")

    if comparison:
        print("\n" + "="*60)
        print("IMPROVEMENT ANALYSIS")
        print("="*60)
        for model_name, comp in comparison.items():
            print(f"{model_name}: {comp['baseline_accuracy']:.4f} → {comp['post_accuracy']:.4f} ({comp['relative_improvement_pct']:+.2f}%)")

    # Upload results to HF Hub (best effort: a failure here must not fail
    # the evaluation, whose results are already on disk).
    try:
        from huggingface_hub import HfApi
        api = HfApi()

        repo_name = "medical-benchmark-results"
        repo_id = repo_name
        try:
            # The repo must be created as a dataset repo, matching the
            # upload below; the returned id includes the owner namespace.
            repo_id = api.create_repo(repo_name, repo_type="dataset", exist_ok=True).repo_id
        except Exception as e:
            logger.warning(f"Could not create or access repo {repo_name}: {e}")

        api.upload_file(
            path_or_fileobj=output_path,
            path_in_repo="post_finetune_evaluation.json",
            repo_id=repo_id,
            repo_type="dataset",
        )
        logger.info(f"Results uploaded to {repo_id}/post_finetune_evaluation.json")

    except Exception as e:
        logger.warning(f"Could not upload results to HF Hub: {e}")

    logger.info("Post fine-tuning evaluation completed!")
409
+
410
+
411
+ if __name__ == "__main__":
412
+ main()