| """ | |
| Fine-tuning Qwen2.5-3B-Instruct for football prediction extraction | |
| Fixes from original: target_modules, validation split, scheduler, checkpoint saving | |
| """ | |
| from datasets import load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig | |
| from peft import LoraConfig | |
| from trl import SFTTrainer, DataCollatorForCompletionOnlyLM | |
| import torch | |
# ------------------------------------------------------------------
# Configuration
# ------------------------------------------------------------------
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
OUTPUT_DIR = "./football-extractor"
TRAIN_FILE = "train_dataset.jsonl"
VAL_FILE = "val_dataset.jsonl"

# ------------------------------------------------------------------
# Data: one JSONL file per split, each row carrying a "messages" list
# ------------------------------------------------------------------
_data_files = {"train": TRAIN_FILE, "validation": VAL_FILE}
dataset = load_dataset("json", data_files=_data_files)
print(f"Train: {len(dataset['train'])} | Val: {len(dataset['validation'])}")
# ------------------------------------------------------------------
# Tokenizer
# ------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Right-padding is the safe choice for causal-LM training; reuse EOS
# as the padding token so no new special token is introduced.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
# ------------------------------------------------------------------
# Base model, quantized to 4-bit for QLoRA
# ------------------------------------------------------------------
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,          # nested quantization: a bit more VRAM saved
    bnb_4bit_compute_dtype=torch.bfloat16,   # bf16 compute is numerically stabler than fp16
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quant_config,
    device_map="auto",
    attn_implementation="eager",  # sidesteps flash-attention install issues (e.g. on Colab)
)
# The KV cache is incompatible with gradient checkpointing, so turn it
# off for training; re-enable it for inference.
model.config.use_cache = False
# ------------------------------------------------------------------
# LoRA adapter configuration
# ------------------------------------------------------------------
# Adapt both the attention projections and the MLP projections of the
# Qwen2.5 decoder layers, named explicitly rather than auto-detected.
_ATTN_PROJECTIONS = ["q_proj", "k_proj", "v_proj", "o_proj"]
_MLP_PROJECTIONS = ["gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,              # a low rank suffices for a narrow extraction task
    lora_alpha=16,    # effective scaling alpha/r = 2
    lora_dropout=0.05,
    bias="none",
    target_modules=_ATTN_PROJECTIONS + _MLP_PROJECTIONS,
)
# ------------------------------------------------------------------
# Prompt formatting
# ------------------------------------------------------------------
def format_example(example):
    """Render one dataset row with the Qwen2.5 chat template.

    The row's ``messages`` list is serialized to a plain string
    (``tokenize=False``); the final assistant turn is part of the
    training target, so no generation prompt is appended.
    """
    conversation = example["messages"]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
# ------------------------------------------------------------------
# Training hyperparameters
# ------------------------------------------------------------------
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    # Effective batch size = 1 * 4 = 4; checkpointing trades compute for VRAM.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    # Cosine decay with a short (5%) warmup.
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    # Evaluate and checkpoint once per epoch, keep the two most recent
    # checkpoints, and reload the lowest-eval-loss one at the end.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    logging_steps=10,
    # bf16 only — requires an Ampere-or-newer GPU.
    bf16=True,
    fp16=False,
    report_to="none",  # switch to "wandb" for experiment tracking
)
# ------------------------------------------------------------------
# Supervised fine-tuning
# ------------------------------------------------------------------
# NOTE(review): DataCollatorForCompletionOnlyLM is imported at the top of
# the file but never passed here, so the loss covers the full templated
# sequence (prompt + completion) — confirm that is intended.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    formatting_func=format_example,
    max_seq_length=512,  # extraction examples are short
)

trainer.train()

# Persist the LoRA adapter and the tokenizer side by side so the output
# directory is directly loadable for inference.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"β Adapter saved to {OUTPUT_DIR}")