| """Shared loading utilities: tokenizer, base model, LoRA adapters, 4-bit quant.""" |
| from __future__ import annotations |
|
|
| import torch |
| from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
|
|
|
| def bnb_config(qcfg: dict) -> BitsAndBytesConfig | None: |
| if not qcfg.get("load_in_4bit", False): |
| return None |
| dtype = getattr(torch, qcfg.get("bnb_4bit_compute_dtype", "bfloat16")) |
| return BitsAndBytesConfig( |
| load_in_4bit=True, |
| bnb_4bit_compute_dtype=dtype, |
| bnb_4bit_quant_type=qcfg.get("bnb_4bit_quant_type", "nf4"), |
| bnb_4bit_use_double_quant=qcfg.get("bnb_4bit_use_double_quant", True), |
| ) |
|
|
|
|
def load_tokenizer(model_cfg: dict):
    """Load the tokenizer for ``model_cfg["base_model"]``.

    If the tokenizer defines no pad token, its EOS token is reused for padding.
    Padding is always applied on the left.
    """
    source = model_cfg["base_model"]
    # NOTE(review): remote code is trusted unless the config opts out.
    remote_ok = model_cfg.get("trust_remote_code", True)
    tokenizer = AutoTokenizer.from_pretrained(source, trust_remote_code=remote_ok)

    # Without a pad token, padded batches cannot be built; fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding, presumably for batched generation with a decoder-only model.
    tokenizer.padding_side = "left"
    return tokenizer
|
|
|
|
def load_base_model(model_cfg: dict, qcfg: dict, training: bool = True):
    """Load the base causal language model described by ``model_cfg``.

    Args:
        model_cfg: Must contain ``"base_model"`` (passed to
            ``from_pretrained``). Optional keys: ``"trust_remote_code"``
            (default True) and ``"device_map"`` (default ``"auto"``).
        qcfg: Quantization settings, turned into a bitsandbytes config by
            ``bnb_config`` (None when 4-bit loading is off).
        training: When True and the model is 4-bit quantized, it is passed
            through ``prepare_model_for_kbit_training``. Also sets
            ``model.config.use_cache`` to ``not training``.

    Returns:
        The loaded model.
    """
    # bf16 weights when a GPU is present, full fp32 on CPU.
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    quant_cfg = bnb_config(qcfg)
    model = AutoModelForCausalLM.from_pretrained(
        model_cfg["base_model"],
        trust_remote_code=model_cfg.get("trust_remote_code", True),
        torch_dtype=dtype,
        quantization_config=quant_cfg,
        # Overridable (e.g. {"": local_rank}) for setups where "auto" placement
        # is not wanted; the default is unchanged.
        device_map=model_cfg.get("device_map", "auto"),
    )
    # Decide "is quantized" from the same object handed to from_pretrained, so
    # this check cannot drift from bnb_config's own enable logic.
    if training and quant_cfg is not None:
        model = prepare_model_for_kbit_training(model)
    # KV cache is kept only when the model is loaded for inference.
    model.config.use_cache = not training
    return model
|
|
|
|
def attach_lora(model, lora_cfg: dict):
    """Wrap ``model`` with new LoRA adapters built from ``lora_cfg``.

    ``lora_cfg`` must provide ``r``, ``alpha``, ``dropout`` and
    ``target_modules``. Bias is set to ``"none"`` and the task type to
    ``"CAUSAL_LM"``.
    """
    settings = {
        "r": lora_cfg["r"],
        "lora_alpha": lora_cfg["alpha"],
        "lora_dropout": lora_cfg["dropout"],
        "target_modules": lora_cfg["target_modules"],
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    return get_peft_model(model, LoraConfig(**settings))
|
|
|
|
def load_with_adapter(model_cfg: dict, qcfg: dict, adapter_dir: str, training: bool = True):
    """Load the base model and attach a saved PEFT adapter from ``adapter_dir``.

    ``training`` is forwarded both to ``load_base_model`` and as the adapter's
    ``is_trainable`` flag, so the adapter is trainable only when training.
    """
    base_model = load_base_model(model_cfg, qcfg, training=training)
    peft_model = PeftModel.from_pretrained(
        base_model,
        adapter_dir,
        is_trainable=training,
    )
    return peft_model
|
|