| |
| """ |
| Fine-tuning script for medical models on Hugging Face infrastructure |
| """ |
| import torch |
| import json |
| import os |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForCausalLM, |
| TrainingArguments, |
| Trainer, |
| DataCollatorForLanguageModeling |
| ) |
| from datasets import load_dataset |
| from peft import LoraConfig, get_peft_model, TaskType |
| import numpy as np |
| from typing import Dict, List |
| import logging |
| from pathlib import Path |
|
|
| |
# Module-level logging: emit INFO and above; each module logs through a
# logger named after itself so records are attributable.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
class HFFineTuner:
    """Fine-tune a medical causal language model from the Hugging Face Hub with LoRA.

    Pipeline: load base model + tokenizer, wrap the model with LoRA adapters,
    load and normalize the MedQA dataset into chat-style prompts, tokenize,
    train with ``Trainer``, save locally, and (best effort) upload the result
    to the Hugging Face Hub.
    """

    def __init__(self, model_name: str):
        """Record the target model, pick a device, and build the LoRA config.

        Args:
            model_name: One of the short keys in ``self.models``
                (e.g. ``"biomistral_7b"``).
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Fine-tuning {model_name} on device: {self.device}")

        # Short name -> Hub repo id.
        # NOTE(review): "qwen3_7b" maps to a Qwen2.5 checkpoint — confirm
        # this is the intended model.
        self.models = {
            "biomistral_7b": "BioMistral/BioMistral-7B",
            "qwen3_7b": "Qwen/Qwen2.5-7B-Instruct",
            "meditron_7b": "epfl-llm/meditron-7b",
            "internist_7b": "internistai/internist-7b"
        }

        # LoRA over the attention and MLP projection layers (standard
        # Llama/Mistral-family module names); r=16 / alpha=32 is a common
        # starting point for 7B models.
        self.lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        )

    def load_model_and_tokenizer(self):
        """Load the base model and tokenizer, then wrap the model with LoRA.

        Returns:
            Tuple of ``(peft_model, tokenizer)``.

        Raises:
            KeyError: If ``self.model_name`` is not a key of ``self.models``.
        """
        model_path = self.models[self.model_name]
        logger.info(f"Loading {model_path}")

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True
        )

        # Some causal-LM tokenizers ship without a pad token; reuse EOS so
        # padded batching works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto" if self.device == "cuda" else None,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            trust_remote_code=True
        )

        # Attach LoRA adapters; only the adapter weights remain trainable.
        model = get_peft_model(model, self.lora_config)
        model.print_trainable_parameters()

        return model, tokenizer

    def load_and_process_dataset(self):
        """Load MedQA, normalize its fields, and render model-specific prompts.

        Returns:
            Dict with ``'train'`` / ``'validation'`` / ``'test'`` splits whose
            examples carry a ``'text'`` field, or ``None`` if the dataset
            cannot be downloaded.
        """
        logger.info("Loading MedQA dataset...")

        # Try the BigBIO packaging first, then the plain hub name.
        # Narrowed from bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.
        try:
            dataset = load_dataset("bigbio/med_qa")
        except Exception as first_err:
            logger.warning(f"bigbio/med_qa unavailable ({first_err}); trying 'medqa'")
            try:
                dataset = load_dataset("medqa")
            except Exception:
                logger.error("Could not load MedQA dataset")
                return None

        def process_example(example):
            # Normalize field names across dataset variants; the final
            # fallbacks (example['input'] / example['output']) raise KeyError
            # on truly unknown schemas, which is the desired loud failure.
            if 'question' in example:
                question = example['question']
            elif 'text' in example:
                question = example['text']
            else:
                question = example['input']

            if 'options' in example:
                options = example['options']
            elif 'choices' in example:
                options = example['choices']
            else:
                # Collect up to five enumerated option_/choice_ fields.
                options = []
                for i in range(5):
                    key = f'option_{i}' if f'option_{i}' in example else f'choice_{i}'
                    if key in example:
                        options.append(example[key])

            if 'answer' in example:
                answer = example['answer']
            elif 'label' in example:
                answer = example['label']
            else:
                answer = example['output']

            return {
                'question': question,
                'options': options,
                'answer': answer
            }

        processed_dataset = dataset.map(process_example)

        def create_prompt(example):
            """Render one training example in the target model's chat format."""
            question = example['question']
            options = example['options']
            answer = example['answer']

            # Label options A., B., C., ...
            options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])

            if "qwen" in self.model_name.lower():
                # Qwen ChatML format.
                prompt = f"""<|im_start|>user
{question}

{options_text}

Please select the correct answer (A, B, C, D, or E).<|im_end|>
<|im_start|>assistant
The correct answer is {answer}.<|im_end|>"""
            elif "mistral" in self.model_name.lower() or "biomistral" in self.model_name.lower():
                # Mistral [INST] format (the "biomistral" test is subsumed by
                # "mistral" but kept for explicitness).
                prompt = f"""<s>[INST] {question}

{options_text}

Please select the correct answer (A, B, C, D, or E). [/INST] The correct answer is {answer}.</s>"""
            else:
                # Plain completion format for models without a chat template.
                prompt = f"""Question: {question}

{options_text}

Answer: {answer}"""

            return {"text": prompt}

        formatted_dataset = processed_dataset.map(create_prompt)

        # Carve a fixed-seed validation split out of the training data.
        train_val_split = formatted_dataset['train'].train_test_split(test_size=0.2, seed=42)

        # Fall back to the held-out slice when the dataset ships without its
        # own test split (previously an unguarded KeyError).
        test_split = formatted_dataset['test'] if 'test' in formatted_dataset else train_val_split['test']

        return {
            'train': train_val_split['train'],
            'validation': train_val_split['test'],
            'test': test_split
        }

    def tokenize_dataset(self, dataset, tokenizer):
        """Tokenize the ``'text'`` field of every split for causal-LM training.

        Args:
            dataset: Dict of splits as produced by ``load_and_process_dataset``.
            tokenizer: The tokenizer returned by ``load_model_and_tokenizer``.

        Returns:
            The tokenized dataset with ``input_ids`` / ``labels`` columns.
        """
        def tokenize_function(examples):
            tokenized = tokenizer(
                examples['text'],
                truncation=True,
                padding=False,  # dynamic padding is done by the data collator
                max_length=2048,
                return_tensors=None
            )
            # Causal LM: labels mirror the inputs; the collator masks padding.
            tokenized['labels'] = tokenized['input_ids'].copy()
            return tokenized

        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset['train'].column_names
        )

        return tokenized_dataset

    def fine_tune(self):
        """Run the full fine-tuning pipeline, save, and best-effort upload.

        Returns:
            The local output directory on success, or ``None`` when the
            dataset could not be loaded.
        """
        logger.info(f"Starting fine-tuning for {self.model_name}")

        model, tokenizer = self.load_model_and_tokenizer()

        dataset = self.load_and_process_dataset()
        if dataset is None:
            return None

        tokenized_dataset = self.tokenize_dataset(dataset, tokenizer)

        # Single source of truth for the output path (was computed twice).
        output_dir = f"/tmp/{self.model_name}_finetuned"

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=8,
            gradient_accumulation_steps=4,  # effective train batch = 16
            learning_rate=2e-5,
            weight_decay=0.01,
            warmup_ratio=0.1,
            logging_steps=10,
            eval_steps=100,
            save_steps=500,  # multiple of eval_steps, required by load_best_model_at_end
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            # fp16 on CPU raises at runtime; enable only on CUDA
            # (was unconditionally True).
            fp16=self.device == "cuda",
            # NOTE(review): renamed to `eval_strategy` in transformers >= 4.46
            # — confirm the pinned transformers version.
            evaluation_strategy="steps",
            save_strategy="steps",
            report_to="none",
            remove_unused_columns=False,
        )

        # mlm=False -> plain causal-LM collation with -100 padding on labels.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset['train'],
            eval_dataset=tokenized_dataset['validation'],
            data_collator=data_collator,
        )

        logger.info("Starting training...")
        trainer.train()

        # Persist adapter weights, tokenizer, and final eval metrics.
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        training_metrics = trainer.evaluate()
        with open(f"{output_dir}/training_metrics.json", 'w') as f:
            json.dump(training_metrics, f, indent=2)

        logger.info(f"Fine-tuning completed for {self.model_name}")
        logger.info(f"Model saved to: {output_dir}")

        # Best-effort Hub upload: any failure is logged, never fatal.
        try:
            from huggingface_hub import HfApi
            api = HfApi()

            repo_name = f"medical-{self.model_name}-finetuned"
            try:
                api.create_repo(repo_name, exist_ok=True)
            except Exception as repo_err:
                # exist_ok=True should make this a no-op; log instead of
                # silently passing (was a bare `except: pass`).
                logger.warning(f"create_repo failed for {repo_name}: {repo_err}")

            api.upload_folder(
                folder_path=output_dir,
                repo_id=repo_name,
                repo_type="model"
            )

            logger.info(f"Fine-tuned model uploaded to {repo_name}")

            api.upload_file(
                path_or_fileobj=f"{output_dir}/training_metrics.json",
                path_in_repo="training_metrics.json",
                repo_id=repo_name,
                repo_type="model"
            )

        except Exception as e:
            logger.warning(f"Could not upload model to HF Hub: {e}")

        return output_dir
|
|
|
|
def main():
    """Command-line entry point: validate the model-name argument and run the job.

    Exits with status 1 on bad usage, an unknown model, or a failed run.
    """
    import sys

    known_models = ("biomistral_7b", "qwen3_7b", "meditron_7b", "internist_7b")

    if len(sys.argv) != 2:
        print("Usage: python finetune_model.py <model_name>")
        print("Available models: biomistral_7b, qwen3_7b, meditron_7b, internist_7b")
        sys.exit(1)

    requested = sys.argv[1]

    if requested not in known_models:
        print(f"Unknown model: {requested}")
        sys.exit(1)

    logger.info(f"Starting fine-tuning job for {requested}")

    result_dir = HFFineTuner(requested).fine_tune()

    if not result_dir:
        logger.error(f"Fine-tuning job failed for {requested}")
        sys.exit(1)

    logger.info(f"Fine-tuning job completed successfully for {requested}")
    print(f"Model saved to: {result_dir}")
|
|
|
|
# Allow direct invocation as a script without running on import.
if __name__ == "__main__":
    main()
|
|