adikuma's picture
initial upload: cleanup code and 688-pair seed dataset
fd0b01f verified
Raw
History Blame Contribute Delete
1.84 kB
# load qwen2.5-0.5b-instruct, apply lora, and pick the right precision for the
# detected device. cpu path is reserved for the smoke test.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from cleanup.config import TrainConfig
def _resolve_dtype(cfg: TrainConfig):
    """Choose the torch dtype used to load the base model.

    Without cuda the answer is always float32. With cuda, bf16 wins when
    ``cfg.bf16`` is set and the device reports bf16 support, then fp16 when
    ``cfg.fp16`` is set; anything else falls through to float32.
    """
    if torch.cuda.is_available():
        # bf16 is only honored when the gpu actually supports it.
        if cfg.bf16 and torch.cuda.is_bf16_supported():
            return torch.bfloat16
        if cfg.fp16:
            return torch.float16
    # shared fallback: no cuda, or neither reduced precision was selected.
    return torch.float32
def load_base_and_tokenizer(cfg: TrainConfig):
    """Load the tokenizer and causal-lm weights named by ``cfg.base_model``.

    Returns a ``(model, tokenizer)`` tuple. The tokenizer is configured for
    right padding and is guaranteed to have a pad token; the model is loaded
    in the dtype chosen by ``_resolve_dtype`` with its kv cache disabled.
    """
    tok = AutoTokenizer.from_pretrained(cfg.base_model, use_fast=True)

    # fall back to eos when no pad token is defined so padding a batch does
    # not fail. (the original note says qwen already ships one.)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # training batches are padded on the right.
    # NOTE(review): attention masks are presumably supplied by the trainer's
    # collator -- confirm against the training script.
    tok.padding_side = "right"

    load_kwargs = {
        "torch_dtype": _resolve_dtype(cfg),
        # let accelerate place the weights on gpu; on cpu load normally.
        "device_map": "auto" if torch.cuda.is_available() else None,
    }
    base = AutoModelForCausalLM.from_pretrained(cfg.base_model, **load_kwargs)

    # only the kv cache is switched off here; gradient checkpointing itself
    # is not enabled in this function.
    # NOTE(review): presumably the trainer enables checkpointing and needs
    # use_cache off to be compatible -- confirm.
    base.config.use_cache = False
    return base, tok
def wrap_with_lora(model, cfg: TrainConfig):
    """Attach lora adapters to ``model`` using the settings in ``cfg.lora``.

    Builds a causal-lm ``LoraConfig`` from the rank, alpha, dropout, bias and
    target modules in the config, wraps the model with peft, prints the
    trainable-parameter summary, and returns the wrapped model.
    """
    lora = cfg.lora
    peft_config = LoraConfig(
        task_type="CAUSAL_LM",
        target_modules=lora.target_modules,
        r=lora.r,
        lora_alpha=lora.alpha,
        lora_dropout=lora.dropout,
        bias=lora.bias,
    )
    wrapped = get_peft_model(model, peft_config)
    # report how many parameters are trainable after wrapping.
    wrapped.print_trainable_parameters()
    return wrapped