import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

# --- CONFIGURATION ---
# Base model: a 7B instruct model loaded in 4-bit fits on most consumer GPUs.
# Some Hub models are gated; run `huggingface-cli login` first if needed.
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
DATASET_NAME = "ceperaltab/elixir-golden-dataset"
OUTPUT_DIR = "elixir-model-qwen"


def main():
    print(f"Loading dataset from {DATASET_NAME}...")

    # 1. Load the dataset directly from the Hugging Face Hub
    try:
        dataset = load_dataset(DATASET_NAME, split="train")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    # 2. Quantization config (4-bit NF4 for memory efficiency)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    print(f"Loading base model: {MODEL_NAME}...")

    # 3. Load the base model with the quantization config
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # 4. Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # right-padding avoids overflow issues in fp16 training

    # 5. LoRA config (parameter-efficient fine-tuning)
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )

    # 6. Formatting function for the chat dataset (batched).
    # Converts each {"messages": [...]} record into the model's chat format
    # (Qwen uses ChatML: <|im_start|>user ... <|im_end|>).
    # We don't tokenize here; SFTTrainer handles tokenization.
    def formatting_prompts_func(examples):
        output_texts = []
        for messages in examples["messages"]:
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            output_texts.append(text)
        return output_texts

    print("Starting SFTTrainer setup...")

    # 7. Trainer. Note: `dataset_text_field` is not set because `formatting_func`
    # already produces the training text.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        formatting_func=formatting_prompts_func,
        tokenizer=tokenizer,  # renamed to `processing_class` in newer TRL releases
        args=SFTConfig(
            output_dir=OUTPUT_DIR,
            max_seq_length=2048,  # lives on SFTConfig, not SFTTrainer
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            learning_rate=2e-4,
            logging_steps=10,
            num_train_epochs=1,
            optim="paged_adamw_32bit",
            fp16=True,
            group_by_length=True,
            save_strategy="epoch",
            report_to="none",
            push_to_hub=True,
            hub_model_id=f"ceperaltab/{OUTPUT_DIR}",
        ),
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)
    print("Done!")


if __name__ == "__main__":
    main()
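
# --- USAGE SKETCH (not part of the training run) ---
# A minimal, hedged example of loading the saved LoRA adapter for inference.
# Assumption: the adapter was saved to OUTPUT_DIR by `trainer.save_model` above;
# `PeftModel.from_pretrained` attaches it to a fresh copy of the base model.
# Run this in a separate session after training completes.
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
# from peft import PeftModel
#
# base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
# model = PeftModel.from_pretrained(base, OUTPUT_DIR)
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
#
# # Hypothetical prompt for illustration only.
# messages = [{"role": "user", "content": "Write an Elixir GenServer that counts requests."}]
# prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# output = model.generate(**inputs, max_new_tokens=256)
# print(tokenizer.decode(output[0], skip_special_tokens=True))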