# src/train.py
import os
from typing import Dict, List

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, TaskType, get_peft_model


def _format_as_chat(tokenizer, ex: Dict) -> str:
    """Render one {system, user, assistant} example as a single training string."""
    system = (ex.get("system") or "").strip()
    user = (ex.get("user") or "").strip()
    assistant = (ex.get("assistant") or "").strip()

    # Preferred: model-native chat template (Llama/Qwen/Mistral Instruct, etc.)
    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
        messages: List[Dict[str, str]] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

    # Fallback: simple transcript
    parts = []
    if system:
        parts.append(f"### System:\n{system}")
    parts.append(f"### User:\n{user}")
    parts.append(f"### Assistant:\n{assistant}")
    return "\n\n".join(parts)


def finetune_lora(
    base_model: str,
    dataset_id: str,
    output_dir: str,
    max_train_samples: int = 2000,
    max_steps: int = 100,
    learning_rate: float = 2e-4,
    batch_size: int = 2,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
) -> str:
    """Fine-tune `base_model` with a LoRA adapter on a chat-style dataset.

    Expects `dataset_id` to provide `system`, `user`, and `assistant` columns.
    Returns a status message with the adapter path, or an error string.
    """
    ds = load_dataset(dataset_id, split="train")

    needed = {"system", "user", "assistant"}
    missing = needed.difference(set(ds.column_names))
    if missing:
        return f"ERROR: dataset missing columns {sorted(missing)}. Found: {ds.column_names}"

    if max_train_samples and max_train_samples > 0:
        ds = ds.select(range(min(len(ds), int(max_train_samples))))

    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tok(exs: List[Dict]) -> Dict:
        texts = [_format_as_chat(tokenizer, ex) for ex in exs]
        return tokenizer(texts, truncation=True, max_length=1024)

    # `map` with batched=True passes a dict-of-lists; convert it to a list of
    # dicts so each example can be formatted individually.
    def batched_map(batch):
        exs = [dict(zip(batch.keys(), vals)) for vals in zip(*batch.values())]
        return tok(exs)

    tokenized = ds.map(batched_map, batched=True, remove_columns=ds.column_names)

    model = AutoModelForCausalLM.from_pretrained(base_model)
    model.config.pad_token_id = tokenizer.pad_token_id

    # NOTE: target_modules depends on model architecture.
    # GPT-2 uses c_attn/c_proj; Llama uses q_proj/k_proj/v_proj/o_proj; Qwen varies.
    # Keep GPT-2 defaults here and change if you swap base_model.
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=int(lora_r),
        lora_alpha=int(lora_alpha),
        lora_dropout=float(lora_dropout),
        bias="none",
        target_modules=["c_attn", "c_proj"],
    )
    model = get_peft_model(model, lora_cfg)

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    fp16 = torch.cuda.is_available()
    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=int(batch_size),
        learning_rate=float(learning_rate),
        max_steps=int(max_steps),
        logging_steps=10,
        save_strategy="no",  # no intermediate checkpoints; the adapter is saved explicitly below
        report_to=[],
        fp16=fp16,
    )

    trainer = Trainer(model=model, args=args, train_dataset=tokenized, data_collator=collator)
    trainer.train()

    adapter_dir = os.path.join(output_dir, "adapter")
    model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)
    return f"Saved LoRA adapter + tokenizer to {adapter_dir}"
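

# ---------------------------------------------------------------------------
# Example entry point (illustrative sketch, not part of the training utility
# above). The base model, dataset id, and output directory are assumptions;
# substitute your own, and change `target_modules` in the LoRA config if the
# base model is not a GPT-2-style architecture.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    status = finetune_lora(
        base_model="gpt2",                        # assumption: matches the c_attn/c_proj targets above
        dataset_id="your-org/your-chat-dataset",  # hypothetical dataset with system/user/assistant columns
        output_dir="outputs/lora-run",            # hypothetical output directory
        max_train_samples=500,
        max_steps=50,
    )
    print(status)

    # To reuse the adapter later (sketch): attach the saved LoRA weights to a
    # freshly loaded base model with peft.PeftModel.from_pretrained, e.g.
    #   from peft import PeftModel
    #   base = AutoModelForCausalLM.from_pretrained("gpt2")
    #   model = PeftModel.from_pretrained(base, "outputs/lora-run/adapter")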