import torch
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel


class CodetteModelLoader:
    """Loads a 4-bit quantized base model and manages multiple LoRA adapters."""

    def __init__(
        self,
        base_model="meta-llama/Llama-3.1-8B-Instruct",
        adapters=None,
    ):
        self.base_model_name = base_model
        self.adapters = adapters or {}  # mapping of adapter name -> local path
        self.model = None
        self.tokenizer = None
        self.active_adapter = None
        self._load_base_model()

    def _load_base_model(self):
        # NF4 4-bit quantization with double quantization keeps the 8B base
        # model within a single-GPU memory budget.
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name, trust_remote_code=True
        )
        # Llama tokenizers ship without a pad token; reuse EOS so padding
        # during batched tokenization does not fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            quantization_config=quant_config,
            device_map="auto",
            trust_remote_code=True,
        )

    def load_adapters(self):
        """Attach every configured LoRA adapter to the base model."""
        first = True
        for name, path in self.adapters.items():
            path = str(Path(path))
            if first:
                # The first adapter wraps the base model in a PeftModel;
                # subsequent adapters are registered on that wrapper.
                self.model = PeftModel.from_pretrained(
                    self.model,
                    path,
                    adapter_name=name,
                    is_trainable=False,
                )
                self.active_adapter = name
                first = False
            else:
                self.model.load_adapter(path, adapter_name=name)

    def set_active_adapter(self, name):
        """Switch which loaded adapter is applied during generation."""
        # getattr guards against calling this before load_adapters(), when
        # self.model is still the bare base model without a peft_config.
        if name not in getattr(self.model, "peft_config", {}):
            raise ValueError(f"Adapter not loaded: {name}")
        self.model.set_adapter(name)
        self.active_adapter = name

    def format_messages(self, messages):
        """Render chat messages into the model's prompt template."""
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def tokenize(self, prompt):
        """Tokenize a formatted prompt and move tensors to the model's device."""
        return self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
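

# A minimal usage sketch of the loader end to end, assuming local LoRA adapter
# directories exist on disk. The adapter names and paths below are hypothetical
# placeholders, not part of this module, and the generation settings are
# illustrative defaults rather than tuned values.
if __name__ == "__main__":
    loader = CodetteModelLoader(
        adapters={
            "reasoning": "adapters/reasoning",  # hypothetical adapter path
            "coding": "adapters/coding",        # hypothetical adapter path
        },
    )
    loader.load_adapters()
    loader.set_active_adapter("coding")

    # Format a chat turn, tokenize it, and generate a reply with the
    # currently active adapter applied.
    prompt = loader.format_messages(
        [{"role": "user", "content": "Write a function that reverses a string."}]
    )
    inputs = loader.tokenize(prompt)
    output_ids = loader.model.generate(**inputs, max_new_tokens=256)
    print(loader.tokenizer.decode(output_ids[0], skip_special_tokens=True))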