# src/train.py
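"""Fine-tune a causal LM with a LoRA adapter using the Hugging Face Trainer + PEFT.

Loads a (system, user, assistant) instruction dataset, renders each example with
the tokenizer's chat template when one is available, and trains a LoRA adapter
that is saved next to the tokenizer under <output_dir>/adapter.
"""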
import os
from typing import Dict, List
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, TaskType, get_peft_model
def _format_as_chat(tokenizer, ex: Dict) -> str:
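    """Render one (system, user, assistant) example as a single training string."""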
    system = (ex.get("system") or "").strip()
    user = (ex.get("user") or "").strip()
    assistant = (ex.get("assistant") or "").strip()
    # Preferred: model-native chat template (Llama/Qwen/Mistral Instruct, etc.)
    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
        messages: List[Dict[str, str]] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    # Fallback: simple transcript
    parts = []
    if system:
        parts.append(f"### System:\n{system}")
    parts.append(f"### User:\n{user}")
    parts.append(f"### Assistant:\n{assistant}")
    return "\n\n".join(parts)

def finetune_lora(
    base_model: str,
    dataset_id: str,
    output_dir: str,
    max_train_samples: int = 2000,
    max_steps: int = 100,
    learning_rate: float = 2e-4,
    batch_size: int = 2,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
) -> str:
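    """Train a LoRA adapter for `base_model` on `dataset_id`.

    The dataset must expose `system`, `user`, and `assistant` columns.
    Returns a status message; the adapter and tokenizer are written to
    `<output_dir>/adapter`.
    """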
    ds = load_dataset(dataset_id, split="train")
    needed = {"system", "user", "assistant"}
    missing = needed.difference(set(ds.column_names))
    if missing:
        return f"ERROR: dataset missing columns {sorted(missing)}. Found: {ds.column_names}"
    if max_train_samples and max_train_samples > 0:
        ds = ds.select(range(min(len(ds), int(max_train_samples))))
    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tok(batch):
        texts = [_format_as_chat(tokenizer, ex) for ex in batch]
        return tokenizer(texts, truncation=True, max_length=1024)
    # datasets.map with batched=True passes a dict-of-lists; rebuild per-example
    # dicts so each one can go through _format_as_chat.
    def batched_map(batch):
        # Convert dict-of-lists to list-of-dicts
        exs = [dict(zip(batch.keys(), vals)) for vals in zip(*batch.values())]
        return tok(exs)

    tokenized = ds.map(batched_map, batched=True, remove_columns=ds.column_names)
    model = AutoModelForCausalLM.from_pretrained(base_model)
    model.config.pad_token_id = tokenizer.pad_token_id

    # NOTE: target_modules depends on model architecture.
    # GPT-2 uses c_attn/c_proj; Llama uses q_proj/k_proj/v_proj/o_proj; Qwen varies.
    # Keep GPT-2 defaults here and change if you swap base_model.
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=int(lora_r),
        lora_alpha=int(lora_alpha),
        lora_dropout=float(lora_dropout),
        bias="none",
        target_modules=["c_attn", "c_proj"],
    )
    model = get_peft_model(model, lora_cfg)

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    fp16 = torch.cuda.is_available()
    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=int(batch_size),
        learning_rate=float(learning_rate),
        max_steps=int(max_steps),
        logging_steps=10,
        save_strategy="no",  # skip intermediate checkpoints; the adapter is saved explicitly below
        report_to=[],
        fp16=fp16,
    )
    trainer = Trainer(model=model, args=args, train_dataset=tokenized, data_collator=collator)
    trainer.train()

    adapter_dir = os.path.join(output_dir, "adapter")
    model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)
    return f"Saved LoRA adapter + tokenizer to {adapter_dir}"