# medical-benchmark-scripts / finetune_model.py
# (Hugging Face upload-page residue below — not part of the script:
#  uploaded by airevartis, commit eb6e3d8 verified, via huggingface_hub)
#!/usr/bin/env python3
"""
Fine-tuning script for medical models on Hugging Face infrastructure
"""
import torch
import json
import os
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
import numpy as np
from typing import Dict, List
import logging
from pathlib import Path
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class HFFineTuner:
    """LoRA fine-tuning of medical LLMs on the MedQA dataset.

    Workflow: load a known base model + tokenizer, wrap the model with a
    LoRA adapter, normalize MedQA examples into chat-formatted training
    prompts, train with ``transformers.Trainer``, then save (and
    best-effort upload to the Hugging Face Hub) the result.
    """

    def __init__(self, model_name: str):
        """
        Args:
            model_name: short key into ``self.models``
                (e.g. ``"biomistral_7b"``).
        """
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Fine-tuning {model_name} on device: {self.device}")
        # Short name -> Hugging Face Hub repository id.
        self.models = {
            "biomistral_7b": "BioMistral/BioMistral-7B",
            "qwen3_7b": "Qwen/Qwen2.5-7B-Instruct",
            "meditron_7b": "epfl-llm/meditron-7b",
            "internist_7b": "internistai/internist-7b"
        }
        # Rank-16 LoRA adapters on all attention and MLP projections.
        self.lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"]
        )

    def load_model_and_tokenizer(self):
        """Load the base model and tokenizer and attach the LoRA adapter.

        Returns:
            ``(model, tokenizer)`` tuple; ``model`` is PEFT-wrapped with
            only the LoRA weights trainable.
        """
        model_path = self.models[self.model_name]
        logger.info(f"Loading {model_path}")
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True
        )
        # Causal-LM tokenizers often ship without a pad token; reuse EOS
        # so the data collator can pad batches.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto" if self.device == "cuda" else None,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            trust_remote_code=True
        )
        model = get_peft_model(model, self.lora_config)
        model.print_trainable_parameters()
        return model, tokenizer

    def load_and_process_dataset(self):
        """Load MedQA, normalize its fields, and build training prompts.

        Returns:
            dict with ``train``/``validation``/``test`` splits of
            ``{"text": prompt}`` examples, or ``None`` if no MedQA source
            could be loaded.
        """
        logger.info("Loading MedQA dataset...")
        # Try the known MedQA sources in order of preference.
        try:
            dataset = load_dataset("bigbio/med_qa")
        except Exception:  # narrowed from bare except; fall back to mirror
            try:
                dataset = load_dataset("medqa")
            except Exception:
                logger.error("Could not load MedQA dataset")
                return None

        def process_example(example):
            """Map the varying MedQA field names onto a fixed schema."""
            if 'question' in example:
                question = example['question']
            elif 'text' in example:
                question = example['text']
            else:
                question = example['input']
            # Options may come as a list or as individual numbered keys.
            if 'options' in example:
                options = example['options']
            elif 'choices' in example:
                options = example['choices']
            else:
                options = []
                for i in range(5):
                    key = f'option_{i}' if f'option_{i}' in example else f'choice_{i}'
                    if key in example:
                        options.append(example[key])
            # The gold answer key likewise varies by release.
            if 'answer' in example:
                answer = example['answer']
            elif 'label' in example:
                answer = example['label']
            else:
                answer = example['output']
            return {
                'question': question,
                'options': options,
                'answer': answer
            }

        processed_dataset = dataset.map(process_example)

        def create_prompt(example):
            """Render one example in the chat template of the target model."""
            question = example['question']
            options = example['options']
            answer = example['answer']
            options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
            if "qwen" in self.model_name.lower():
                prompt = f"""<|im_start|>user
{question}
{options_text}
Please select the correct answer (A, B, C, D, or E).<|im_end|>
<|im_start|>assistant
The correct answer is {answer}.<|im_end|>"""
            elif "mistral" in self.model_name.lower() or "biomistral" in self.model_name.lower():
                prompt = f"""<s>[INST] {question}
{options_text}
Please select the correct answer (A, B, C, D, or E). [/INST] The correct answer is {answer}.</s>"""
            else:
                # Generic format for models without a known chat template.
                prompt = f"""Question: {question}
{options_text}
Answer: {answer}"""
            return {"text": prompt}

        formatted_dataset = processed_dataset.map(create_prompt)
        # Carve a validation set out of the train split; keep the original
        # test split untouched.
        train_val_split = formatted_dataset['train'].train_test_split(test_size=0.2, seed=42)
        return {
            'train': train_val_split['train'],
            'validation': train_val_split['test'],
            'test': formatted_dataset['test']
        }

    def tokenize_dataset(self, dataset, tokenizer):
        """Tokenize every split of ``dataset`` for causal-LM training.

        Args:
            dataset: mapping of split name -> dataset of ``{"text": ...}``
                rows, as returned by ``load_and_process_dataset``.
            tokenizer: tokenizer matching the model being fine-tuned.

        Returns:
            dict mapping split name -> tokenized dataset.
        """
        def tokenize_function(examples):
            tokenized = tokenizer(
                examples['text'],
                truncation=True,
                padding=False,  # dynamic padding is left to the collator
                max_length=2048,
                return_tensors=None
            )
            # Causal LM: labels mirror the inputs (the model shifts them).
            tokenized['labels'] = tokenized['input_ids'].copy()
            return tokenized

        # BUG FIX: ``dataset`` is a plain dict (not a DatasetDict), so it
        # has no ``.map`` of its own — tokenize each split individually.
        return {
            split: split_ds.map(
                tokenize_function,
                batched=True,
                remove_columns=split_ds.column_names
            )
            for split, split_ds in dataset.items()
        }

    def fine_tune(self):
        """Run the full fine-tuning pipeline.

        Returns:
            Path to the saved model directory, or ``None`` if the dataset
            could not be loaded.
        """
        logger.info(f"Starting fine-tuning for {self.model_name}")
        model, tokenizer = self.load_model_and_tokenizer()
        dataset = self.load_and_process_dataset()
        if dataset is None:
            return
        tokenized_dataset = self.tokenize_dataset(dataset, tokenizer)

        output_dir = f"/tmp/{self.model_name}_finetuned"
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=8,
            gradient_accumulation_steps=4,
            learning_rate=2e-5,
            weight_decay=0.01,
            warmup_ratio=0.1,
            logging_steps=10,
            eval_steps=100,
            save_steps=500,
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            # BUG FIX: fp16 training is unsupported on CPU; enable it only
            # when running on CUDA (was unconditionally True).
            fp16=(self.device == "cuda"),
            # NOTE(review): renamed to ``eval_strategy`` in newer
            # transformers releases — confirm against the pinned version.
            evaluation_strategy="steps",
            save_strategy="steps",
            report_to="none",
            remove_unused_columns=False,
        )
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,  # causal LM, not masked LM
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset['train'],
            eval_dataset=tokenized_dataset['validation'],
            data_collator=data_collator,
        )
        logger.info("Starting training...")
        trainer.train()

        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)
        # Persist final eval metrics next to the model weights.
        training_metrics = trainer.evaluate()
        with open(f"{output_dir}/training_metrics.json", 'w') as f:
            json.dump(training_metrics, f, indent=2)
        logger.info(f"Fine-tuning completed for {self.model_name}")
        logger.info(f"Model saved to: {output_dir}")

        # Best-effort Hub upload; failure must not fail the training job.
        try:
            from huggingface_hub import HfApi
            api = HfApi()
            repo_name = f"medical-{self.model_name}-finetuned"
            try:
                api.create_repo(repo_name, exist_ok=True)
            except Exception:
                # Repo may already exist or creation may be unauthorized;
                # proceed and let the upload surface any real error.
                pass
            api.upload_folder(
                folder_path=output_dir,
                repo_id=repo_name,
                repo_type="model"
            )
            logger.info(f"Fine-tuned model uploaded to {repo_name}")
            api.upload_file(
                path_or_fileobj=f"{output_dir}/training_metrics.json",
                path_in_repo="training_metrics.json",
                repo_id=repo_name,
                repo_type="model"
            )
        except Exception as e:
            logger.warning(f"Could not upload model to HF Hub: {e}")
        return output_dir
def main():
    """CLI entry point: fine-tune the model named on the command line."""
    import sys

    known_models = ("biomistral_7b", "qwen3_7b", "meditron_7b", "internist_7b")
    args = sys.argv[1:]

    # Exactly one positional argument is expected: the model key.
    if len(args) != 1:
        print("Usage: python finetune_model.py <model_name>")
        print("Available models: biomistral_7b, qwen3_7b, meditron_7b, internist_7b")
        sys.exit(1)

    model_name = args[0]
    if model_name not in known_models:
        print(f"Unknown model: {model_name}")
        sys.exit(1)

    logger.info(f"Starting fine-tuning job for {model_name}")
    tuner = HFFineTuner(model_name)
    result_dir = tuner.fine_tune()

    if not result_dir:
        logger.error(f"Fine-tuning job failed for {model_name}")
        sys.exit(1)
    logger.info(f"Fine-tuning job completed successfully for {model_name}")
    print(f"Model saved to: {result_dir}")


if __name__ == "__main__":
    main()