| |
|
| |
|
| |
|
| |
|
| | from __future__ import annotations
|
| |
|
| | import argparse
|
| | from pathlib import Path
|
| | from typing import Dict, List, Union
|
| |
|
| | import torch
|
| | from datasets import load_dataset
|
| | from transformers import (
|
| | AutoTokenizer,
|
| | AutoModelForCausalLM,
|
| | BitsAndBytesConfig,
|
| | Trainer,
|
| | TrainingArguments,
|
| | )
|
| | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| |
|
| |
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Three string options are mandatory: ``--base`` (base model id or path),
    ``--data`` (JSONL file of chat messages) and ``--out`` (directory for the
    trained adapter). The remaining options are training / LoRA
    hyper-parameters with the defaults listed below, plus a ``--debug`` flag.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser()

    required_strings = (
        ("--base", "Base model id or path"),
        ("--data", "JSONL with chat messages"),
        ("--out", "Output dir for adapter"),
    )
    for flag, help_text in required_strings:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # (flag, converter, default); registration order is also the --help order.
    tunables = (
        ("--epochs", int, 2),
        ("--bsz", int, 8),
        ("--grad_accum", int, 1),
        ("--cutoff_len", int, 2048),
        ("--lr", float, 2e-4),
        ("--lora_r", int, 16),
        ("--lora_alpha", int, 32),
        ("--lora_dropout", float, 0.05),
    )
    for flag, converter, default in tunables:
        parser.add_argument(flag, type=converter, default=default)

    parser.add_argument("--debug", action="store_true")
    return parser.parse_args()
|
| |
|
| |
|
def device_supports_bf16() -> bool:
    """Report whether CUDA device 0 has compute capability major version >= 8.

    That capability threshold is what this script uses to decide between
    bf16 and fp16. Always ``False`` when CUDA is unavailable.
    """
    has_cuda = torch.cuda.is_available()
    # Short-circuit so the capability query only runs when CUDA exists.
    return has_cuda and torch.cuda.get_device_capability(0)[0] >= 8
|
| |
|
| |
|
def build_tokenizer(base_id: str):
    """Load the fast tokenizer for ``base_id`` and configure it for batching.

    Padding is placed on the right. If the checkpoint ships without a pad
    token, its EOS token is reused as the pad token.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_id, use_fast=True)
    tokenizer.padding_side = "right"
    # Batching needs a pad token; reuse EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
|
| |
|
| |
|
| | def _to_ids(x: Union[torch.Tensor, List[int], Dict[str, List[int]]]) -> List[int]:
|
| | if isinstance(x, torch.Tensor):
|
| | return x.detach().cpu().tolist()[0] if x.ndim == 2 else x.detach().cpu().tolist()
|
| | if isinstance(x, dict) and "input_ids" in x:
|
| | return x["input_ids"]
|
| | if isinstance(x, (list, tuple)):
|
| | return list(x)
|
| | raise TypeError(f"Unsupported chat template return type: {type(x)}")
|
| |
|
| |
|
def chat_to_ids(tokenizer: AutoTokenizer, messages: List[Dict], max_len: int):
    """Tokenise one chat conversation into ``input_ids`` and ``attention_mask``.

    If the tokenizer has a chat template, the conversation is rendered with
    ``apply_chat_template`` (no generation prompt) and truncated to
    ``max_len`` tokens. Otherwise the messages are joined into a plain
    "role:" / content transcript, tokenised with truncation to ``max_len``,
    and terminated with the tokenizer's EOS id. Without that EOS a
    fine-tuned model never sees where a conversation ends.

    Args:
        tokenizer: a Hugging Face tokenizer.
        messages: list of ``{"role": ..., "content": ...}`` dicts. A missing
            or ``None`` role becomes ``"user"``; a missing or ``None`` content
            becomes ``""``.
        max_len: maximum number of tokens in the result.

    Returns:
        dict with equal-length ``input_ids`` and ``attention_mask`` lists.
    """
    if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
        out = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
            max_length=max_len,
            truncation=True,
        )
        ids = _to_ids(out)
        return {"input_ids": ids, "attention_mask": [1] * len(ids)}

    lines = []
    for m in messages:
        role = m.get("role") or "user"
        # `or ""` avoids rendering a None content as the text "None".
        content = m.get("content") or ""
        lines.append(f"{role}:\n{content}\n")
    text = "\n".join(lines)
    enc = tokenizer(text, max_length=max_len, truncation=True)
    ids = list(enc["input_ids"])
    attn = list(enc["attention_mask"])

    eos_id = getattr(tokenizer, "eos_token_id", None)
    if eos_id is not None and max_len > 0 and (not ids or ids[-1] != eos_id):
        # Stay within max_len: drop the final token if needed to fit EOS.
        if len(ids) >= max_len:
            ids, attn = ids[: max_len - 1], attn[: max_len - 1]
        ids.append(eos_id)
        attn.append(1)
    return {"input_ids": ids, "attention_mask": attn}
|
| |
|
| |
|
class _PadCollator:
    """Right-pad tokenised rows into ``input_ids``/``attention_mask``/``labels`` tensors.

    Defined at module level (not as a closure) so that DataLoader worker
    processes started with the ``spawn`` method can pickle it.
    """

    def __init__(self, pad_id: int):
        self.pad_id = pad_id

    def __call__(self, batch: List[Dict[str, List[int]]]):
        longest = max(len(row["input_ids"]) for row in batch)
        input_ids, attn, labels = [], [], []
        for row in batch:
            ids = list(row["input_ids"])
            am = list(row["attention_mask"])
            pad_n = longest - len(ids)
            input_ids.append(ids + [self.pad_id] * pad_n)
            attn.append(am + [0] * pad_n)
            # Labels copy the real ids and use -100 only on padded slots. A
            # real token equal to pad_id (e.g. EOS reused as pad) is therefore
            # still trained on.
            labels.append(ids + [-100] * pad_n)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attn, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


def collate_pad(tokenizer: AutoTokenizer):
    """Return a picklable collate function that right-pads with the tokenizer's pad id.

    Each batch element must provide ``input_ids`` and ``attention_mask``
    lists. The returned callable produces ``torch.long`` tensors for
    ``input_ids``, ``attention_mask`` and ``labels``. Labels equal the input
    ids, with padding positions set to -100.

    Raises:
        ValueError: if the tokenizer has no ``pad_token_id``.
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        raise ValueError(
            "tokenizer has no pad_token_id; set tokenizer.pad_token before collating"
        )
    return _PadCollator(pad_id)
|
| |
|
| |
|
def guess_lora_targets(model: torch.nn.Module) -> List[str]:
    """Choose LoRA ``target_modules`` names for ``model``.

    PEFT matches list entries against the end of a module's dotted name (an
    exact match or a ``.<entry>`` suffix). Candidates are therefore compared
    with each module's leaf name exactly. A substring test would, for
    example, pick ``"wo"`` out of ``word_embeddings`` or ``"up_proj"`` out of
    ``gate_up_proj``, yielding targets that match nothing.

    Strategy:
      1. If any well-known projection names are leaf names of linear layers
         (``torch.nn.Linear`` subclasses or ``Conv1D``), return those.
      2. Otherwise return the leaf names of all such linear layers except the
         output head (``lm_head`` / ``get_output_embeddings()``).

    Raises:
        ValueError: if the model contains no linear layer LoRA could wrap.
    """
    prefs = {
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "qkv_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "gate_up_proj",
        "wi",
        "wo",
        "w1",
        "w2",
        "w3",
        "out_proj",
    }

    get_head = getattr(model, "get_output_embeddings", None)
    head = get_head() if callable(get_head) else None

    preferred, linear_leaves = set(), set()
    for name, module in model.named_modules():
        is_linear = isinstance(module, torch.nn.Linear) or type(module).__name__ == "Conv1D"
        if not name or not is_linear:
            continue
        leaf = name.rsplit(".", 1)[-1]
        if leaf in prefs:
            preferred.add(leaf)
        elif module is not head and leaf != "lm_head":
            # Never adapt the output head in the generic fallback.
            linear_leaves.add(leaf)

    if preferred:
        return sorted(preferred)
    if linear_leaves:
        return sorted(linear_leaves)
    raise ValueError("could not find any linear layers to target with LoRA")
|
| |
|
| |
|
def main():
    """Fine-tune a 4-bit quantised causal LM with LoRA and save the adapter.

    Steps:
      1. Parse the CLI options.
      2. Tokenise the JSONL chat data (the ``messages`` field of each row).
      3. Load ``--base`` with NF4 4-bit quantisation.
      4. Wrap the model with a LoRA adapter.
      5. Train with ``transformers.Trainer``.
      6. Write the adapter and tokenizer to ``--out``.
    """
    import inspect

    args = parse_args()
    base_id = args.base
    data_path = Path(args.data)
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = build_tokenizer(base_id)

    ds = load_dataset("json", data_files=str(data_path), split="train")

    def map_row(ex):
        return chat_to_ids(tokenizer, ex["messages"], args.cutoff_len)

    # Keep only the tokenised columns produced by chat_to_ids.
    ds = ds.map(map_row, remove_columns=ds.column_names)

    collate = collate_pad(tokenizer)

    use_bf16 = device_supports_bf16()
    torch_dtype = torch.bfloat16 if use_bf16 else torch.float16
    torch.backends.cuda.matmul.allow_tf32 = True

    quant = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        # Unset, the 4-bit matmuls compute in float32 (bitsandbytes'
        # default), which is much slower than the bf16/fp16 used elsewhere.
        bnb_4bit_compute_dtype=torch_dtype,
    )

    model = AutoModelForCausalLM.from_pretrained(
        base_id,
        device_map="auto",
        quantization_config=quant,
        torch_dtype=torch_dtype,
    )

    model = prepare_model_for_kbit_training(model)
    lconf = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=guess_lora_targets(model),
    )
    model = get_peft_model(model, lconf)

    train_args = TrainingArguments(
        output_dir=str(out_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.bsz,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        logging_steps=5,
        save_steps=100,
        bf16=use_bf16,
        fp16=not use_bf16,
        optim="paged_adamw_8bit",
        # The custom collator needs the raw input_ids/attention_mask columns.
        remove_unused_columns=False,
        dataloader_num_workers=2,
        report_to=[],
    )

    # Newer transformers renamed Trainer's `tokenizer` argument to
    # `processing_class`; pass the tokenizer under whichever name this
    # installed version accepts.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tok_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"

    tr = Trainer(
        model=model,
        args=train_args,
        train_dataset=ds,
        data_collator=collate,
        **{tok_kwarg: tokenizer},
    )

    if args.debug:
        batch = next(iter(tr.get_train_dataloader()))
        print("[debug] batch keys:", list(batch.keys()))
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                print(f"[debug] {k}: shape={tuple(v.shape)} dtype={v.dtype}")

    tr.train()

    model.save_pretrained(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))
    print("[ok] saved adapter to", out_dir.resolve())
|
| |
|
| |
|
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
|
| |
|