airevartis committed on
Commit eb6e3d8 · verified · 1 Parent(s): aa965c5

Upload finetune_model.py with huggingface_hub

Files changed (1)
  1. finetune_model.py +331 -0
finetune_model.py ADDED
@@ -0,0 +1,331 @@
+ #!/usr/bin/env python3
+ """
+ Fine-tuning script for medical models on Hugging Face infrastructure
+ """
+ import torch
+ import json
+ import os
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling
+ )
+ from datasets import load_dataset, DatasetDict
+ from peft import LoraConfig, get_peft_model, TaskType
+ import numpy as np
+ from typing import Dict, List
+ import logging
+ from pathlib import Path
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ class HFFineTuner:
+     def __init__(self, model_name: str):
+         self.model_name = model_name
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Fine-tuning {model_name} on device: {self.device}")
+
+         # Model configurations
+         self.models = {
+             "biomistral_7b": "BioMistral/BioMistral-7B",
+             "qwen3_7b": "Qwen/Qwen2.5-7B-Instruct",
+             "meditron_7b": "epfl-llm/meditron-7b",
+             "internist_7b": "internistai/internist-7b"
+         }
+
+         # LoRA configuration
+         self.lora_config = LoraConfig(
+             task_type=TaskType.CAUSAL_LM,
+             r=16,
+             lora_alpha=32,
+             lora_dropout=0.1,
+             target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+         )
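+         # A note on the LoRA choices above: r is the adapter rank, and
+         # lora_alpha / r (here 32 / 16 = 2) scales the adapter update; only
+         # the small low-rank matrices train while base weights stay frozen.
+         # The target_modules names match Llama/Mistral/Qwen2-style attention
+         # and MLP projections; other architectures may name them differently.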
+
+     def load_model_and_tokenizer(self):
+         """Load model and tokenizer for fine-tuning"""
+         model_path = self.models[self.model_name]
+         logger.info(f"Loading {model_path}")
+
+         # Load tokenizer
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_path,
+             trust_remote_code=True
+         )
+
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
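+             # Many causal-LM tokenizers ship without a pad token; reusing EOS
+             # lets the collator pad batches without resizing the embeddings.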
+
+         # Load model
+         model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             device_map="auto" if self.device == "cuda" else None,
+             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+             trust_remote_code=True
+         )
+
+         # Apply LoRA
+         model = get_peft_model(model, self.lora_config)
+         model.print_trainable_parameters()
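+         # get_peft_model wraps the base model and freezes its weights; the
+         # printout typically reports well under 1% trainable parameters for
+         # a 7B model at r=16.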
+
+         return model, tokenizer
+
+     def load_and_process_dataset(self):
+         """Load and process MedQA dataset for training"""
+         logger.info("Loading MedQA dataset...")
+
+         # Load dataset
+         try:
+             dataset = load_dataset("bigbio/med_qa")
+         except Exception:
+             try:
+                 dataset = load_dataset("medqa")
+             except Exception:
+                 logger.error("Could not load MedQA dataset")
+                 return None
+
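+         # MedQA copies on the Hub expose slightly different column names, so
+         # normalize every example to a question/options/answer schema first.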
+         def process_example(example):
+             # Handle different dataset formats
+             if 'question' in example:
+                 question = example['question']
+             elif 'text' in example:
+                 question = example['text']
+             else:
+                 question = example['input']
+
+             # Handle multiple choice options
+             if 'options' in example:
+                 options = example['options']
+             elif 'choices' in example:
+                 options = example['choices']
+             else:
+                 options = []
+                 for i in range(5):
+                     key = f'option_{i}' if f'option_{i}' in example else f'choice_{i}'
+                     if key in example:
+                         options.append(example[key])
+
+             # Get answer
+             if 'answer' in example:
+                 answer = example['answer']
+             elif 'label' in example:
+                 answer = example['label']
+             else:
+                 answer = example['output']
+
+             return {
+                 'question': question,
+                 'options': options,
+                 'answer': answer
+             }
+
+         # Process dataset
+         processed_dataset = dataset.map(process_example)
+
+         # Create training prompts
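+         # Each base model was instruction-tuned with its own chat format
+         # (ChatML-style <|im_start|> tags for Qwen, [INST] tags for Mistral
+         # derivatives), so the prompts below mirror those templates. Where
+         # the tokenizer ships a chat template, tokenizer.apply_chat_template
+         # would be the more robust alternative.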
+         def create_prompt(example):
+             question = example['question']
+             options = example['options']
+             answer = example['answer']
+
+             options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
+
+             if "qwen" in self.model_name.lower():
+                 prompt = f"""<|im_start|>user
+ {question}
+
+ {options_text}
+
+ Please select the correct answer (A, B, C, D, or E).<|im_end|>
+ <|im_start|>assistant
+ The correct answer is {answer}.<|im_end|>"""
+             elif "mistral" in self.model_name.lower() or "biomistral" in self.model_name.lower():
+                 prompt = f"""<s>[INST] {question}
+
+ {options_text}
+
+ Please select the correct answer (A, B, C, D, or E). [/INST] The correct answer is {answer}.</s>"""
+             else:
+                 # Generic format
+                 prompt = f"""Question: {question}
+
+ {options_text}
+
+ Answer: {answer}"""
+
+             return {"text": prompt}
+
+         # Format for training
+         formatted_dataset = processed_dataset.map(create_prompt)
+
+         # Split into train/validation
+         train_val_split = formatted_dataset['train'].train_test_split(test_size=0.2, seed=42)
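+         # The fixed seed keeps the 80/20 train/validation split reproducible
+         # across runs; the dataset's own test split is held out untouched.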
+
+         return DatasetDict({
+             'train': train_val_split['train'],
+             'validation': train_val_split['test'],
+             'test': formatted_dataset['test']
+         })
+
+     def tokenize_dataset(self, dataset, tokenizer):
+         """Tokenize dataset for training"""
+         def tokenize_function(examples):
+             tokenized = tokenizer(
+                 examples['text'],
+                 truncation=True,
+                 padding=False,
+                 max_length=2048,
+                 return_tensors=None
+             )
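+             # For causal-LM training the labels are the input ids themselves;
+             # the model shifts them internally when computing the loss.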
+             tokenized['labels'] = tokenized['input_ids'].copy()
+             return tokenized
+
+         tokenized_dataset = dataset.map(
+             tokenize_function,
+             batched=True,
+             remove_columns=dataset['train'].column_names
+         )
+
+         return tokenized_dataset
+
+     def fine_tune(self):
+         """Main fine-tuning function"""
+         logger.info(f"Starting fine-tuning for {self.model_name}")
+
+         # Load model and tokenizer
+         model, tokenizer = self.load_model_and_tokenizer()
+
+         # Load and process dataset
+         dataset = self.load_and_process_dataset()
+         if dataset is None:
+             return
+
+         # Tokenize dataset
+         tokenized_dataset = self.tokenize_dataset(dataset, tokenizer)
+
+         # Training arguments
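+         # With per_device_train_batch_size=4 and gradient_accumulation_steps=4,
+         # the effective batch size is 4 x 4 = 16 sequences per device per
+         # optimizer step.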
+         training_args = TrainingArguments(
+             output_dir=f"/tmp/{self.model_name}_finetuned",
+             num_train_epochs=3,
+             per_device_train_batch_size=4,
+             per_device_eval_batch_size=8,
+             gradient_accumulation_steps=4,
+             learning_rate=2e-5,
+             weight_decay=0.01,
+             warmup_ratio=0.1,
+             logging_steps=10,
+             eval_steps=100,
+             save_steps=500,
+             save_total_limit=2,
+             load_best_model_at_end=True,
+             metric_for_best_model="eval_loss",
+             greater_is_better=False,
+             fp16=(self.device == "cuda"),  # fp16 requires a GPU; disable on CPU
+             evaluation_strategy="steps",
+             save_strategy="steps",
+             report_to="none",
+             remove_unused_columns=False,
+         )
+
+         # Data collator
+         data_collator = DataCollatorForLanguageModeling(
+             tokenizer=tokenizer,
+             mlm=False,
+         )
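+         # mlm=False selects plain causal-LM collation: the collator pads each
+         # batch to its longest sequence and sets padded label positions to
+         # -100 so they are ignored by the loss.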
+
+         # Trainer
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=tokenized_dataset['train'],
+             eval_dataset=tokenized_dataset['validation'],
+             data_collator=data_collator,
+         )
+
+         # Train
+         logger.info("Starting training...")
+         trainer.train()
+
+         # Save model
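+         # (with a PEFT-wrapped model this writes the LoRA adapter weights
+         # and config, not the full base model)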
+         output_dir = f"/tmp/{self.model_name}_finetuned"
+         trainer.save_model(output_dir)
+         tokenizer.save_pretrained(output_dir)
+
+         # Save training metrics
+         training_metrics = trainer.evaluate()
+         with open(f"{output_dir}/training_metrics.json", 'w') as f:
+             json.dump(training_metrics, f, indent=2)
+
+         logger.info(f"Fine-tuning completed for {self.model_name}")
+         logger.info(f"Model saved to: {output_dir}")
+
+         # Upload to HF Hub
+         try:
+             from huggingface_hub import HfApi
+             api = HfApi()
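+             # HfApi picks up credentials from the HF_TOKEN environment
+             # variable or a token cached by `huggingface-cli login`.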
+
+             # Create repository for the fine-tuned model; exist_ok=True makes
+             # re-runs safe, and the returned URL carries the fully qualified
+             # "user/name" repo id
+             repo_name = f"medical-{self.model_name}-finetuned"
+             repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
+
+             # Upload model files
+             api.upload_folder(
+                 folder_path=output_dir,
+                 repo_id=repo_id,
+                 repo_type="model"
+             )
+
+             logger.info(f"Fine-tuned model uploaded to {repo_id}")
+
+             # Upload training metrics
+             api.upload_file(
+                 path_or_fileobj=f"{output_dir}/training_metrics.json",
+                 path_in_repo="training_metrics.json",
+                 repo_id=repo_id,
+                 repo_type="model"
+             )
+
+         except Exception as e:
+             logger.warning(f"Could not upload model to HF Hub: {e}")
+
+         return output_dir
+
+
+ def main():
+     """Main function for HF fine-tuning job"""
+     import sys
+
+     if len(sys.argv) != 2:
+         print("Usage: python finetune_model.py <model_name>")
+         print("Available models: biomistral_7b, qwen3_7b, meditron_7b, internist_7b")
+         sys.exit(1)
+
+     model_name = sys.argv[1]
+
+     if model_name not in ["biomistral_7b", "qwen3_7b", "meditron_7b", "internist_7b"]:
+         print(f"Unknown model: {model_name}")
+         sys.exit(1)
+
+     logger.info(f"Starting fine-tuning job for {model_name}")
+
+     fine_tuner = HFFineTuner(model_name)
+     output_dir = fine_tuner.fine_tune()
+
+     if output_dir:
+         logger.info(f"Fine-tuning job completed successfully for {model_name}")
+         print(f"Model saved to: {output_dir}")
+     else:
+         logger.error(f"Fine-tuning job failed for {model_name}")
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
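+
+ # Example invocation (a sketch: assumes a CUDA GPU with enough memory for a
+ # 7B model in fp16 plus the LoRA adapters, and a logged-in Hub token for the
+ # upload step):
+ #
+ #   python finetune_model.py biomistral_7b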