Cloud / arbiter /model_utils.py
Grimxlock's picture
Arbiter: low-param (glm-edge-1.5b) config + CPU-safe training
d8cfea8 verified
"""Shared loading utilities: tokenizer, base model, LoRA adapters, 4-bit quant."""
from __future__ import annotations
import torch
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
def bnb_config(qcfg: dict) -> BitsAndBytesConfig | None:
    """Build a bitsandbytes 4-bit quantization config from ``qcfg``.

    Args:
        qcfg: Quantization settings. Keys read here:
            ``load_in_4bit`` (default False): when falsy, no config is built.
            ``bnb_4bit_compute_dtype`` (default "bfloat16"): a ``torch.dtype``
            or the name of one, with or without a ``"torch."`` prefix. A
            null/empty value falls back to the default.
            ``bnb_4bit_quant_type`` (default "nf4").
            ``bnb_4bit_use_double_quant`` (default True).

    Returns:
        A ``BitsAndBytesConfig`` for 4-bit loading, or ``None`` when 4-bit
        loading is disabled. ``None`` can be passed straight through as
        ``quantization_config``.

    Raises:
        ValueError: if ``bnb_4bit_compute_dtype`` does not name a torch dtype.
    """
    if not qcfg.get("load_in_4bit", False):
        return None

    # A bare/null YAML key yields None; treat it like "not set".
    raw_dtype = qcfg.get("bnb_4bit_compute_dtype") or "bfloat16"
    if isinstance(raw_dtype, torch.dtype):
        dtype = raw_dtype
    else:
        name = str(raw_dtype)
        if name.startswith("torch."):
            name = name[len("torch."):]
        dtype = getattr(torch, name, None)
        # getattr alone would also accept non-dtype attributes such as "cuda" or "nn".
        if not isinstance(dtype, torch.dtype):
            raise ValueError(
                f"bnb_4bit_compute_dtype={raw_dtype!r} is not a torch dtype name "
                "(expected e.g. 'bfloat16', 'float16', 'float32')"
            )

    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=dtype,
        bnb_4bit_quant_type=qcfg.get("bnb_4bit_quant_type", "nf4"),
        bnb_4bit_use_double_quant=qcfg.get("bnb_4bit_use_double_quant", True),
    )
def load_tokenizer(model_cfg: dict):
    """Load the tokenizer for ``model_cfg["base_model"]``.

    ``model_cfg["trust_remote_code"]`` is forwarded to ``from_pretrained`` and
    defaults to True. If the tokenizer has no pad token, its EOS token is
    reused as the pad token. Padding is always set to the left side.
    """
    checkpoint = model_cfg["base_model"]
    allow_remote = model_cfg.get("trust_remote_code", True)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=allow_remote)

    # Some causal-LM tokenizers ship without a pad token; borrow EOS so padded
    # batches can be built.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return tokenizer
def load_base_model(model_cfg: dict, qcfg: dict, training: bool = True):
    """Load the base causal LM, optionally quantized to 4 bits.

    Args:
        model_cfg: Needs ``base_model``. ``trust_remote_code`` is optional and
            defaults to True.
        qcfg: Quantization settings, forwarded to ``bnb_config``.
            ``load_in_4bit`` also decides whether k-bit training prep runs.
        training: When true and 4-bit loading is on, the model is passed
            through ``prepare_model_for_kbit_training``.

    Returns:
        The loaded model. ``config.use_cache`` is set to ``not training``.
    """
    # bfloat16 weights when a GPU is present, full precision on CPU.
    weight_dtype = torch.float32
    if torch.cuda.is_available():
        weight_dtype = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        model_cfg["base_model"],
        trust_remote_code=model_cfg.get("trust_remote_code", True),
        torch_dtype=weight_dtype,
        quantization_config=bnb_config(qcfg),
        device_map="auto",
    )

    quantized = qcfg.get("load_in_4bit", False)
    if quantized and training:
        model = prepare_model_for_kbit_training(model)

    # The KV cache only matters for generation, so it is kept off while training.
    model.config.use_cache = not training
    return model
def attach_lora(model, lora_cfg: dict):
    """Wrap ``model`` with new LoRA adapters for causal language modelling.

    ``lora_cfg`` must supply ``r``, ``alpha``, ``dropout`` and
    ``target_modules``. Bias terms are not adapted (``bias="none"``).
    Returns the PEFT-wrapped model from ``get_peft_model``.
    """
    adapter_params = {
        "r": lora_cfg["r"],
        "lora_alpha": lora_cfg["alpha"],
        "lora_dropout": lora_cfg["dropout"],
        "target_modules": lora_cfg["target_modules"],
    }
    peft_cfg = LoraConfig(**adapter_params, bias="none", task_type="CAUSAL_LM")
    return get_peft_model(model, peft_cfg)
def load_with_adapter(model_cfg: dict, qcfg: dict, adapter_dir: str, training: bool = True):
    """Load the base model and apply the saved adapter found in ``adapter_dir``.

    ``training`` is used both for loading the base model (via
    ``load_base_model``) and as ``is_trainable`` for the adapter. With the
    default, the adapter weights stay trainable; with False they are frozen.
    """
    return PeftModel.from_pretrained(
        load_base_model(model_cfg, qcfg, training=training),
        adapter_dir,
        is_trainable=training,
    )