# Codette-Reasoning/inference/model_loader.py
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel


class CodetteModelLoader:
    """Loads a 4-bit quantized base model and manages optional LoRA adapters."""

    def __init__(
        self,
        base_model="meta-llama/Llama-3.1-8B-Instruct",
        adapters=None,
    ):
        # `adapters` maps adapter names to local paths or Hub IDs.
        self.base_model_name = base_model
        self.adapters = adapters or {}
        self.model = None
        self.tokenizer = None
        self.active_adapter = None
        self._load_base_model()

    def _load_base_model(self):
        # 4-bit NF4 quantization with double quantization keeps the 8B
        # base model within a single consumer GPU's memory budget.
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            trust_remote_code=True,
        )
        # Llama tokenizers ship without a pad token; reuse EOS so that
        # batched tokenization and generation do not fail on padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            quantization_config=quant_config,
            device_map="auto",
            trust_remote_code=True,
        )

    def load_adapters(self):
        """Attach every configured LoRA adapter; the first one becomes active."""
        first = True
        for name, path in self.adapters.items():
            path = str(Path(path))
            if first:
                # The first adapter wraps the base model in a PeftModel.
                self.model = PeftModel.from_pretrained(
                    self.model,
                    path,
                    adapter_name=name,
                    is_trainable=False,
                )
                self.active_adapter = name
                first = False
            else:
                # Subsequent adapters are registered on the existing PeftModel.
                self.model.load_adapter(path, adapter_name=name)

    def set_active_adapter(self, name):
        # `peft_config` only exists once load_adapters() has wrapped the model,
        # so guard against calling this before any adapter is attached.
        if not hasattr(self.model, "peft_config") or name not in self.model.peft_config:
            raise ValueError(f"Adapter not loaded: {name}")
        self.model.set_adapter(name)
        self.active_adapter = name

    def format_messages(self, messages):
        """Render a list of chat messages into the model's prompt format."""
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def tokenize(self, prompt):
        # Move the input tensors to the same device the sharded model expects.
        return self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
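

# --- Usage sketch (not part of the loader) ---
# A minimal example of the intended call sequence: load the base model,
# attach adapters, pick one, format a chat prompt, and generate. The adapter
# name and path below ("reasoning", "adapters/reasoning-lora") are
# hypothetical placeholders, not files shipped with this repo.
if __name__ == "__main__":
    loader = CodetteModelLoader(
        adapters={"reasoning": "adapters/reasoning-lora"},  # hypothetical path
    )
    loader.load_adapters()
    loader.set_active_adapter("reasoning")

    prompt = loader.format_messages(
        [{"role": "user", "content": "Explain KV caching in one sentence."}]
    )
    inputs = loader.tokenize(prompt)
    # Greedy-decode a short completion; generate() is the standard
    # transformers API and works on PeftModel as well.
    output_ids = loader.model.generate(**inputs, max_new_tokens=64)
    print(loader.tokenizer.decode(output_ids[0], skip_special_tokens=True))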