import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig
from trl import SFTTrainer

# --- CONFIGURATION ---
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
DATASET_NAME = "ceperaltab/elixir-golden-dataset"
OUTPUT_DIR = "elixir-model-qwen"


def main():
    """QLoRA-fine-tune MODEL_NAME on DATASET_NAME and push the result to the Hub.

    Pipeline: load chat-formatted dataset -> load 4-bit-quantized base model ->
    attach a LoRA adapter -> supervised fine-tuning via TRL's SFTTrainer
    (v0.8.6 API) -> save/push to OUTPUT_DIR.
    """
    print(f"Loading dataset from {DATASET_NAME}...")

    # 1. Load Dataset
    try:
        dataset = load_dataset(DATASET_NAME, split="train")
    except Exception as e:
        # Best-effort: abort cleanly if the Hub dataset is unreachable.
        print(f"Error loading dataset: {e}")
        return

    # 2. Quantization Config (4-bit NF4 for memory efficiency on small GPUs)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    print(f"Loading base model: {MODEL_NAME}...")

    # 3. Load Model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    # Gradient checkpointing is incompatible with the KV cache; Transformers
    # would otherwise warn and disable it at runtime — do it explicitly here.
    model.config.use_cache = False

    # 4. Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token  # Qwen-style: reuse EOS for padding
    tokenizer.padding_side = "right"  # right-padding required for fp16 training

    # 5. LoRA Config — adapt all attention + MLP projection matrices
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )

    # 6. Formatting Function for Chat Dataset (TRL v0.8.6 API)
    def formatting_prompts_func(examples):
        """Render each batched `messages` list through the model's chat template.

        Returns one rendered string per example, as SFTTrainer expects.
        Assumes each dataset row has a `messages` column of chat turns —
        verify against the dataset schema.
        """
        output_texts = []
        for messages in examples['messages']:
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            output_texts.append(text)
        return output_texts

    print("Starting SFTTrainer setup...")

    # 7. Training Arguments (TRL v0.8.6 uses TrainingArguments from transformers)
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,  # Compensate for smaller batch
        learning_rate=2e-4,
        logging_steps=10,
        num_train_epochs=1,
        optim="paged_adamw_32bit",
        fp16=True,
        group_by_length=True,
        gradient_checkpointing=True,  # Save memory
        # Non-reentrant checkpointing is the variant documented to work with
        # PEFT adapters (avoids "no grad" issues with frozen base weights).
        gradient_checkpointing_kwargs={"use_reentrant": False},
        save_strategy="epoch",
        report_to="none",
        push_to_hub=True,
        hub_model_id=f"ceperaltab/{OUTPUT_DIR}",
    )

    # 8. SFTTrainer (TRL v0.8.6 API)
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        formatting_func=formatting_prompts_func,
        max_seq_length=1024,  # Reduced for T4 memory
        tokenizer=tokenizer,
        args=training_args,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)
    print("Done!")


if __name__ == "__main__":
    main()