"""Shared loading utilities: tokenizer, base model, LoRA adapters, 4-bit quant.""" from __future__ import annotations import torch from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig def bnb_config(qcfg: dict) -> BitsAndBytesConfig | None: if not qcfg.get("load_in_4bit", False): return None dtype = getattr(torch, qcfg.get("bnb_4bit_compute_dtype", "bfloat16")) return BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=dtype, bnb_4bit_quant_type=qcfg.get("bnb_4bit_quant_type", "nf4"), bnb_4bit_use_double_quant=qcfg.get("bnb_4bit_use_double_quant", True), ) def load_tokenizer(model_cfg: dict): tok = AutoTokenizer.from_pretrained( model_cfg["base_model"], trust_remote_code=model_cfg.get("trust_remote_code", True), ) if tok.pad_token is None: tok.pad_token = tok.eos_token tok.padding_side = "left" return tok def load_base_model(model_cfg: dict, qcfg: dict, training: bool = True): dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 model = AutoModelForCausalLM.from_pretrained( model_cfg["base_model"], trust_remote_code=model_cfg.get("trust_remote_code", True), torch_dtype=dtype, quantization_config=bnb_config(qcfg), device_map="auto", ) if training and qcfg.get("load_in_4bit", False): model = prepare_model_for_kbit_training(model) model.config.use_cache = not training return model def attach_lora(model, lora_cfg: dict): cfg = LoraConfig( r=lora_cfg["r"], lora_alpha=lora_cfg["alpha"], lora_dropout=lora_cfg["dropout"], target_modules=lora_cfg["target_modules"], bias="none", task_type="CAUSAL_LM", ) return get_peft_model(model, cfg) def load_with_adapter(model_cfg: dict, qcfg: dict, adapter_dir: str, training: bool = True): base = load_base_model(model_cfg, qcfg, training=training) return PeftModel.from_pretrained(base, adapter_dir, is_trainable=training)