import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
from codeInsight.logger import logging
from codeInsight.exception import ExceptionHandle
import sys
def load_model_and_tokenizer(config: dict) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a 4-bit quantized causal LM and its tokenizer for k-bit (QLoRA) training.

    Args:
        config: Expects keys 'base_model_id' (HF hub id), 'attn_implementation',
            and 'quantization' — a dict with 'load_in_4bit',
            'bnb_4bit_quant_type', 'bnb_4bit_compute_dtype', and
            'bnb_4bit_use_double_quant'.  # schema inferred from the lookups below

    Returns:
        A ``(model, tokenizer)`` pair: the quantized model prepared for k-bit
        training, and its tokenizer configured with right-side padding.

    Raises:
        ExceptionHandle: wraps any exception raised while loading.
    """
    try:
        model_id = config['base_model_id']
        quant_config = config['quantization']
        # Lazy %-style args: the message is only formatted if the record is emitted.
        logging.info("Loading base model: %s", model_id)

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=quant_config['load_in_4bit'],
            bnb_4bit_quant_type=quant_config['bnb_4bit_quant_type'],
            # NOTE(review): presumably a dtype name string (e.g. "bfloat16");
            # BitsAndBytesConfig resolves strings to torch dtypes itself.
            bnb_4bit_compute_dtype=quant_config['bnb_4bit_compute_dtype'],
            bnb_4bit_use_double_quant=quant_config['bnb_4bit_use_double_quant'],
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation=config['attn_implementation'],
        )

        # KV caching is incompatible with gradient checkpointing during training.
        model.config.use_cache = False
        # prepare_model_for_kbit_training(..., use_gradient_checkpointing=True)
        # already calls model.gradient_checkpointing_enable(); the original's
        # extra explicit call was redundant and has been removed.
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
        logging.info("Base model loaded successfully with 4-bit quantization.")

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Many causal LMs ship without a pad token; reuse EOS so batching works.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
        logging.info("Tokenizer loaded successfully.")

        return model, tokenizer
    except Exception as e:
        logging.error("Failed to load model or tokenizer")
        raise ExceptionHandle(e, sys)