# Author: Yifei Wang
# Deploy HF Space demo (clean) — commit 707a2d1
from __future__ import annotations
import os
import glob
from pathlib import Path
from typing import Any, Dict
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
Trainer,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from numen_scriptorium.paths import ROOT
from numen_scriptorium.repro import set_global_seed
# Hardware presets: conservative settings for a ~16 GB T4 (short context,
# fp16, heavy gradient accumulation) vs. roomier settings for an A100
# (longer context, larger LoRA rank, bf16).
PRESET_CONFIGS = {
    "t4": {
        "max_seq_len": 1024,
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 16,
        "lora_r": 16,
        "lora_alpha": 32,
        "learning_rate": 1e-4,
        "fp16": True,
        "bf16": False,
    },
    "a100": {
        "max_seq_len": 2048,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 8,
        "lora_r": 32,
        "lora_alpha": 64,
        "learning_rate": 1e-4,
        "fp16": False,
        "bf16": True,
    },
}
# Attention and MLP projection submodule names commonly targeted for LoRA on
# Qwen-family models; get_lora_target_modules() filters this list against the
# actually-loaded model.
COMMON_QWEN_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "up_proj",
    "down_proj",
    "gate_proj",
]
def _resolve_path(path_like: str | Path) -> str:
    """Return *path_like* as an absolute path string.

    Relative paths are anchored at the project ROOT so the script works
    regardless of the current working directory.
    """
    candidate = Path(path_like)
    return str(candidate if candidate.is_absolute() else ROOT / candidate)
def encode_dataset(tokenizer, dataset, max_seq_len: int):
    """Tokenize an instruction dataset into fixed-length causal-LM rows.

    Each example is rendered as a Chinese instruction prompt followed by the
    answer. Prompt tokens are excluded from the loss via -100 labels; rows
    are truncated and right-padded to exactly *max_seq_len* tokens.

    Args:
        tokenizer: HF-style tokenizer (callable returning {"input_ids": [...]},
            with pad_token_id / eos_token_id attributes).
        dataset: dataset of dicts with "instruction", "input", "output" keys
            (must support .map(fn, remove_columns=...) and .column_names).
        max_seq_len: fixed output length per row.

    Returns:
        The mapped dataset with input_ids / attention_mask / labels columns.
    """
    # Fall back to EOS (then 0) so padding never injects None token ids when
    # the tokenizer ships without a pad token configured.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    def build_and_tokenize(example):
        instruction = (example.get("instruction") or "").strip()
        inp = (example.get("input") or "").strip()
        out = (example.get("output") or "").strip()
        # These prompt strings define the trained format — do not alter them.
        if inp:
            prompt = f"指令:{instruction}\n输入:{inp}\n回答:"
        else:
            prompt = f"指令:{instruction}\n回答:"
        prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
        answer_ids = tokenizer(out, add_special_tokens=False)["input_ids"]
        # Terminate the answer with EOS so generation learns to stop.
        if tokenizer.eos_token_id is not None:
            answer_ids = answer_ids + [tokenizer.eos_token_id]
        input_ids = (prompt_ids + answer_ids)[:max_seq_len]
        # Prompt positions contribute no loss (-100 is ignored by CE loss).
        labels = ([-100] * len(prompt_ids) + answer_ids)[:max_seq_len]
        attention_mask = [1] * len(input_ids)
        pad_len = max_seq_len - len(input_ids)
        if pad_len > 0:
            input_ids = input_ids + [pad_id] * pad_len
            attention_mask = attention_mask + [0] * pad_len
            labels = labels + [-100] * pad_len
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    return dataset.map(build_and_tokenize, remove_columns=dataset.column_names)
def get_latest_checkpoint(output_dir: str):
    """Return the ``checkpoint-*`` path in *output_dir* with the highest step.

    Returns None when no checkpoint directory exists. Names whose trailing
    segment is not an integer are treated as step -1 so any valid checkpoint
    outranks them.
    """
    candidates = glob.glob(os.path.join(output_dir, "checkpoint-*"))

    def _step(path):
        try:
            return int(os.path.basename(path).split("-")[-1])
        except Exception:
            return -1

    best = None
    best_step = None
    for path in candidates:
        step = _step(path)
        # >= keeps the later entry on ties, matching a stable sort's last item.
        if best_step is None or step >= best_step:
            best, best_step = path, step
    return best
def get_lora_target_modules(model):
    """Return the COMMON_QWEN_TARGET_MODULES actually present in *model*.

    Matches each candidate against the model's module names (exact name or
    trailing ``.candidate`` segment) and preserves the canonical ordering.

    Raises:
        ValueError: if none of the expected module names are found.
    """
    present = {
        candidate
        for name, _ in model.named_modules()
        for candidate in COMMON_QWEN_TARGET_MODULES
        if name == candidate or name.endswith(f".{candidate}")
    }
    ordered = [module for module in COMMON_QWEN_TARGET_MODULES if module in present]
    if not ordered:
        raise ValueError("No expected LoRA target modules found.")
    return ordered
def pick_compute_dtype(use_bf16: bool):
    """Return torch.bfloat16 only when requested AND supported by the current
    CUDA device; otherwise fall back to torch.float16."""
    bf16_ok = (
        use_bf16
        and torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()
    )
    return torch.bfloat16 if bf16_ok else torch.float16
def resolve_training_config(config: Dict[str, Any], max_seq_len_override: int | None):
    """Merge *config* over a hardware preset and apply safety clamps.

    Precedence (lowest to highest): preset defaults, non-None keys from
    *config*, low-VRAM clamps, then the explicit *max_seq_len_override*.

    Returns:
        A new dict containing the resolved settings, including a "preset" key.
    """
    # `or "a100"` also covers an explicit {"preset": None}: a plain
    # config.get("preset", "a100") would return None and break the lookup.
    preset = config.get("preset") or "a100"
    cfg = dict(PRESET_CONFIGS[preset])
    cfg.update({k: v for k, v in config.items() if v is not None})
    # Record the resolved preset: train_from_config reads cfg["preset"] for
    # its run name and config echo, and would otherwise KeyError whenever the
    # caller's config omitted "preset".
    cfg["preset"] = preset
    if torch.cuda.is_available():
        total_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        if total_vram_gb < 40:
            # Conservative clamps for sub-40 GB cards (e.g. T4 / V100-16GB).
            cfg["max_seq_len"] = min(int(cfg["max_seq_len"]), 1024)
            cfg["micro_batch_size"] = min(int(cfg["micro_batch_size"]), 1)
            cfg["gradient_accumulation_steps"] = max(int(cfg["gradient_accumulation_steps"]), 16)
    if max_seq_len_override is not None:
        cfg["max_seq_len"] = max_seq_len_override
    return cfg
def train_from_config(config: Dict[str, Any], resume: str | None = None, max_seq_len_override: int | None = None) -> None:
    """Run QLoRA fine-tuning of a causal LM described by *config*.

    Loads the base model in 4-bit NF4, attaches LoRA adapters, tokenizes the
    train/val JSONL files, trains with the HF Trainer, and saves the best
    checkpoint (lowest eval_loss) plus the tokenizer to <output_dir>/best.

    Args:
        config: user overrides merged over a hardware preset
            (see resolve_training_config / PRESET_CONFIGS).
        resume: None to train from scratch; "" or "latest" to resume from the
            newest checkpoint-* directory in output_dir; otherwise treated as
            an explicit checkpoint path.
        max_seq_len_override: hard override for the tokenized sequence length.
    """
    cfg = resolve_training_config(config, max_seq_len_override)
    seed = int(cfg.get("seed", 42))
    deterministic = bool(cfg.get("deterministic", False))
    set_global_seed(seed, deterministic=deterministic)
    base_model = cfg.get("base_model", "Qwen/Qwen2.5-7B-Instruct")
    output_dir = _resolve_path(cfg.get("output_dir", "outputs/qwen2_5_7b_boh_qlora"))
    train_file = _resolve_path(cfg.get("train_file", "data_split/train.jsonl"))
    val_file = _resolve_path(cfg.get("val_file", "data_split/val.jsonl"))
    report_to = cfg.get("report_to", "wandb")
    # bf16 only when both requested and actually supported by the GPU;
    # use_fp16 is the strict complement, so exactly one of the two is set.
    use_bf16 = bool(cfg.get("bf16", False) and torch.cuda.is_available() and torch.cuda.is_bf16_supported())
    use_fp16 = not use_bf16
    compute_dtype = pick_compute_dtype(use_bf16)
    # TrainingArguments expects [] (not the string "none") to disable reporting.
    report_to_value = [] if str(report_to).lower() == "none" else report_to
    # Derive a filesystem/W&B-friendly run name from the model id.
    model_name = base_model.split("/")[-1].lower().replace(".", "_").replace("-", "_")
    # NOTE(review): cfg['preset'] is only guaranteed present when the caller's
    # config includes it — verify resolve_training_config records the preset.
    run_name = cfg.get("run_name") or f"{model_name}_{cfg['preset']}_qlora"
    # Echo the effective config so runs are auditable from the logs.
    print("=" * 80)
    print("Effective training config")
    print(f"base_model: {base_model}")
    print(f"preset: {cfg['preset']}")
    print(f"max_seq_len: {cfg['max_seq_len']}")
    print(f"micro_batch_size: {cfg['micro_batch_size']}")
    print(f"gradient_accumulation_steps: {cfg['gradient_accumulation_steps']}")
    print(f"learning_rate: {cfg['learning_rate']}")
    print(f"lora(r/alpha): {cfg['lora_r']}/{cfg['lora_alpha']}")
    print(f"fp16: {use_fp16}, bf16: {use_bf16}")
    print(f"seed: {seed}, deterministic: {deterministic}")
    print(f"output_dir: {output_dir}")
    print("=" * 80)
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True, use_fast=True)
    # Some tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # 4-bit NF4 with double quantization — the standard QLoRA recipe.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    # Gradient checkpointing trades compute for memory; the KV cache is
    # useless during training and incompatible with checkpointing.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model = prepare_model_for_kbit_training(model)
    lora_target_modules = get_lora_target_modules(model)
    print(f"LoRA target modules: {lora_target_modules}")
    lora_config = LoraConfig(
        r=int(cfg["lora_r"]),
        lora_alpha=int(cfg["lora_alpha"]),
        lora_dropout=float(cfg.get("lora_dropout", 0.05)),
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=lora_target_modules,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    dataset = load_dataset("json", data_files={"train": train_file, "validation": val_file})
    tokenized_train = encode_dataset(tokenizer, dataset["train"], int(cfg["max_seq_len"]))
    tokenized_val = encode_dataset(tokenizer, dataset["validation"], int(cfg["max_seq_len"]))
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=int(cfg["micro_batch_size"]),
        gradient_accumulation_steps=int(cfg["gradient_accumulation_steps"]),
        num_train_epochs=int(cfg.get("num_epochs", 4)),
        learning_rate=float(cfg["learning_rate"]),
        lr_scheduler_type="cosine",
        warmup_ratio=float(cfg.get("warmup_ratio", 0.03)),
        max_grad_norm=float(cfg.get("max_grad_norm", 1.0)),
        do_eval=True,
        eval_strategy="steps",
        eval_steps=int(cfg.get("eval_steps", 200)),
        logging_strategy="steps",
        logging_steps=int(cfg.get("logging_steps", 10)),
        logging_first_step=True,
        save_strategy="steps",
        save_steps=int(cfg.get("save_steps", 200)),
        save_total_limit=int(cfg.get("save_total_limit", 4)),
        # Best checkpoint is reloaded at the end and saved below.
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=use_fp16,
        bf16=use_bf16,
        report_to=report_to_value,
        run_name=run_name,
        seed=seed,
        data_seed=seed,
        # Our datasets are already tokenized dicts; don't let Trainer drop columns.
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
    )
    # Resolve the resume target: "" / "latest" means newest checkpoint on disk.
    resume_from = None
    if resume is not None:
        if resume == "" or str(resume).lower() == "latest":
            resume_from = get_latest_checkpoint(output_dir)
        else:
            resume_from = resume
        if resume_from is None:
            print(f"[Resume] No checkpoint found in {output_dir}. Start from scratch.")
        else:
            print(f"[Resume] Resuming from: {resume_from}")
    trainer.train(resume_from_checkpoint=resume_from)
    # Persist the best (lowest eval_loss) adapter weights and the tokenizer.
    best_dir = os.path.join(output_dir, "best")
    trainer.save_model(best_dir)
    tokenizer.save_pretrained(best_dir)
    print("Best model saved to:", best_dir)
def smoke_test_from_config(config: Dict[str, Any], max_seq_len_override: int | None = None) -> None:
    """Cheap end-to-end sanity check without any training.

    Loads the tokenizer and the 4-bit quantized base model, tokenizes up to
    two samples from the training file, and runs a single forward pass.

    Args:
        config: same config dict accepted by train_from_config.
        max_seq_len_override: hard override for the tokenized sequence length.
    """
    cfg = resolve_training_config(config, max_seq_len_override)
    seed = int(cfg.get("seed", 42))
    deterministic = bool(cfg.get("deterministic", False))
    set_global_seed(seed, deterministic=deterministic)
    base_model = cfg.get("base_model", "Qwen/Qwen2.5-7B-Instruct")
    train_file = _resolve_path(cfg.get("train_file", "data_split/train.jsonl"))
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True, use_fast=True)
    # Reuse EOS as pad token when the tokenizer ships without one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # fp16 compute is the lowest common denominator for a quick smoke run.
    compute_dtype = torch.float16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    # Forward-pass at most two samples from the training split.
    ds_all = load_dataset("json", data_files={"train": train_file})["train"]
    n = min(2, len(ds_all))
    ds = ds_all.select(range(n))
    tok = encode_dataset(tokenizer, ds, int(cfg["max_seq_len"]))
    batch = {
        "input_ids": torch.tensor([tok[i]["input_ids"] for i in range(n)], device=model.device),
        "attention_mask": torch.tensor([tok[i]["attention_mask"] for i in range(n)], device=model.device),
    }
    with torch.no_grad():
        _ = model(**batch)
    print(f"Smoke test passed: tokenizer/model loaded and {n} samples forward-passed.")