Update src/train.py
src/train.py  CHANGED  (+43, -15)
@@ -1,5 +1,6 @@
+# src/train.py
 import os
-from typing import
+from typing import Dict, List
 
 import torch
 from datasets import load_dataset
@@ -14,10 +15,32 @@ from transformers import (
 from peft import LoraConfig, TaskType, get_peft_model
 
 
+def _format_as_chat(tokenizer, ex: Dict) -> str:
+    system = (ex.get("system") or "").strip()
+    user = (ex.get("user") or "").strip()
+    assistant = (ex.get("assistant") or "").strip()
+
+    # Preferred: model-native chat template (Llama/Qwen/Mistral Instruct, etc.)
+    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
+        messages: List[Dict[str, str]] = []
+        if system:
+            messages.append({"role": "system", "content": system})
+        messages.append({"role": "user", "content": user})
+        messages.append({"role": "assistant", "content": assistant})
+        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+
+    # Fallback: simple transcript
+    parts = []
+    if system:
+        parts.append(f"### System:\n{system}")
+    parts.append(f"### User:\n{user}")
+    parts.append(f"### Assistant:\n{assistant}")
+    return "\n\n".join(parts)
+
+
 def finetune_lora(
     base_model: str,
     dataset_id: str,
-    text_column: str,
     output_dir: str,
     max_train_samples: int = 2000,
     max_steps: int = 100,
@@ -28,8 +51,11 @@ def finetune_lora(
     lora_dropout: float = 0.05,
 ) -> str:
     ds = load_dataset(dataset_id, split="train")
-
-
+
+    needed = {"system", "user", "assistant"}
+    missing = needed.difference(set(ds.column_names))
+    if missing:
+        return f"ERROR: dataset missing columns {sorted(missing)}. Found: {ds.column_names}"
 
     if max_train_samples and max_train_samples > 0:
         ds = ds.select(range(min(len(ds), int(max_train_samples))))
@@ -39,15 +65,23 @@ def finetune_lora(
     tokenizer.pad_token = tokenizer.eos_token
 
     def tok(batch):
-
+        texts = [_format_as_chat(tokenizer, ex) for ex in batch]
+        return tokenizer(texts, truncation=True, max_length=1024)
+
+    # map with batched=True expects a dict-of-lists; easiest is to build list of dicts per batch
+    def batched_map(batch):
+        # Convert dict-of-lists to list-of-dicts
+        exs = [dict(zip(batch.keys(), vals)) for vals in zip(*batch.values())]
+        return tok(exs)
 
-    tokenized = ds.map(
+    tokenized = ds.map(batched_map, batched=True, remove_columns=ds.column_names)
 
     model = AutoModelForCausalLM.from_pretrained(base_model)
     model.config.pad_token_id = tokenizer.pad_token_id
 
-    #
-    #
+    # NOTE: target_modules depends on model architecture.
+    # GPT-2 uses c_attn/c_proj; Llama uses q_proj/k_proj/v_proj/o_proj; Qwen varies.
+    # Keep GPT-2 defaults here and change if you swap base_model.
     lora_cfg = LoraConfig(
         task_type=TaskType.CAUSAL_LM,
         r=int(lora_r),
@@ -72,13 +106,7 @@ def finetune_lora(
         fp16=fp16,
     )
 
-    trainer = Trainer(
-        model=model,
-        args=args,
-        train_dataset=tokenized,
-        data_collator=collator,
-    )
-
+    trainer = Trainer(model=model, args=args, train_dataset=tokenized, data_collator=collator)
    trainer.train()
 
     adapter_dir = os.path.join(output_dir, "adapter")
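
For reference, the conversion inside batched_map is what turns the dict-of-lists that datasets passes with batched=True into one dict per example for tok(). A minimal sketch with made-up rows (not data from this repo):

    batch = {
        "system": ["Be terse.", ""],
        "user": ["Hi", "What is 2 + 2?"],
        "assistant": ["Hello!", "4"],
    }
    exs = [dict(zip(batch.keys(), vals)) for vals in zip(*batch.values())]
    # exs == [
    #     {"system": "Be terse.", "user": "Hi", "assistant": "Hello!"},
    #     {"system": "", "user": "What is 2 + 2?", "assistant": "4"},
    # ]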
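The NOTE about target_modules is worth spelling out: the LoraConfig in this file keeps GPT-2-style defaults, so swapping base_model to another family means changing the targeted modules. A sketch with illustrative hyperparameters, assuming a Llama-style checkpoint (module names vary by architecture):

    from peft import LoraConfig, TaskType

    # Hypothetical config for a Llama-family base model, not part of this commit.
    llama_lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )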
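Since text_column is gone from the signature, callers now pass a dataset that already exposes system / user / assistant columns. A hypothetical call site (model and dataset names are placeholders, not from this Space):

    from src.train import finetune_lora  # import path assumes the Space's src/ layout

    result = finetune_lora(
        base_model="gpt2",                       # placeholder; matches the GPT-2-oriented LoRA defaults
        dataset_id="your-org/chat-sft-dataset",  # placeholder; must have system/user/assistant columns
        output_dir="outputs/run1",
        max_train_samples=2000,
        max_steps=100,
    )
    # result is a string: an "ERROR: ..." message if columns are missing, otherwise presumably
    # the adapter path (the diff is truncated before the final return).
    print(result)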