precis / init /model.py
compendious's picture
Output rendering, documentation, model readiness UI
851f234
Raw
History Blame Contribute Delete
3.02 kB
"""Model loading utilities for Précis."""
import logging
from typing import Optional, Tuple
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
PreTrainedModel,
PreTrainedTokenizer,
)
from src.config import ModelConfig
logger = logging.getLogger(__name__)
def get_quantization_config(config: ModelConfig) -> Optional[BitsAndBytesConfig]:
    """Build the BitsAndBytes settings requested by ``config``.

    4-bit loading takes precedence over 8-bit when both flags are set.

    Args:
        config: Model configuration; reads ``load_in_4bit``, ``load_in_8bit``
            and the ``bnb_4bit_*`` fields.

    Returns:
        A ``BitsAndBytesConfig`` for 4-bit or 8-bit loading, or ``None`` when
        neither flag is set.
    """
    if not config.load_in_4bit:
        # No 4-bit request: fall back to 8-bit, or to no quantization at all.
        return BitsAndBytesConfig(load_in_8bit=True) if config.load_in_8bit else None

    # The compute dtype is looked up by name as an attribute of ``torch``.
    four_bit_kwargs = {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": getattr(torch, config.bnb_4bit_compute_dtype),
        "bnb_4bit_quant_type": config.bnb_4bit_quant_type,
        "bnb_4bit_use_double_quant": config.bnb_4bit_use_double_quant,
    }
    return BitsAndBytesConfig(**four_bit_kwargs)
def load_tokenizer(config: Optional[ModelConfig] = None) -> PreTrainedTokenizer:
    """Load the tokenizer for ``config.model_id`` and set up padding.

    Args:
        config: Model configuration. A default ``ModelConfig()`` is created
            when omitted.

    Returns:
        The tokenizer, with a pad token guaranteed to be set and
        ``padding_side`` set to ``"right"``.
    """
    config = ModelConfig() if config is None else config

    logger.info(f"Loading tokenizer: {config.model_id}")
    hub_options = {
        "trust_remote_code": config.trust_remote_code,
        "cache_dir": config.cache_dir,
    }
    tok = AutoTokenizer.from_pretrained(config.model_id, **hub_options)

    # Tokenizers without a pad token reuse the end-of-sequence token for padding.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id

    tok.padding_side = "right"
    return tok
def load_model(config: Optional[ModelConfig] = None) -> PreTrainedModel:
    """Load the causal-LM base model, quantized if the config asks for it.

    Args:
        config: Model configuration. A default ``ModelConfig()`` is created
            when omitted.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    config = ModelConfig() if config is None else config

    logger.info(f"Loading model: {config.model_id}")

    bnb_config = get_quantization_config(config)
    # Quantized loads pin the dtype to float16; otherwise it is left as "auto".
    if bnb_config:
        weights_dtype = torch.float16
    else:
        weights_dtype = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        config.model_id,
        quantization_config=bnb_config,
        device_map=config.device_map,
        trust_remote_code=config.trust_remote_code,
        cache_dir=config.cache_dir,
        torch_dtype=weights_dtype,
    )

    logger.info(f"Model loaded. Parameters: {model.num_parameters():,}")
    return model
def prepare_for_training(model: PreTrainedModel, gradient_checkpointing: bool = True) -> PreTrainedModel:
    """Prepare a model for training.

    Optionally enables gradient checkpointing and, for models loaded in
    4-bit or 8-bit, runs PEFT's k-bit training preparation.

    Args:
        model: The model to prepare.
        gradient_checkpointing: Whether to enable gradient checkpointing.
            The flag is honoured for quantized models too.

    Returns:
        The prepared model: the object returned by
        ``prepare_model_for_kbit_training`` for k-bit models, otherwise the
        input model.
    """
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()

    loaded_in_kbit = getattr(model, "is_loaded_in_4bit", False) or getattr(
        model, "is_loaded_in_8bit", False
    )
    if loaded_in_kbit:
        # Imported here so peft is only needed when a quantized model is used.
        from peft import prepare_model_for_kbit_training

        # Forward the flag explicitly: the PEFT helper enables gradient
        # checkpointing by default, which would silently override a caller
        # passing gradient_checkpointing=False.
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )
    return model
def load_model_and_tokenizer(config: Optional[ModelConfig] = None) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load the model and its tokenizer from one shared configuration.

    Args:
        config: Model configuration. A default ``ModelConfig()`` is created
            when omitted and used for both loads.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is loaded first.
    """
    config = ModelConfig() if config is None else config
    model = load_model(config)
    tokenizer = load_tokenizer(config)
    return model, tokenizer