"""Model loading and inference for Francis Botcon."""

from typing import Dict, Optional

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from peft import PeftModel

from src.logger import LoggerSetup
from src.config_loader import config

logger = LoggerSetup.setup().getChild(__name__)


class FrancisModel:
    """Wrapper for Francis Botcon model.

    Loads a tokenizer and a causal language model on construction,
    optionally applies a LoRA adapter on top of the base model, and exposes
    a single ``generate`` method that returns only the newly generated text.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        adapter_path: Optional[str] = None,
        device: Optional[str] = None,
        use_quantization: Optional[bool] = None
    ) -> None:
        """Initialize Francis Botcon model.

        Any argument left as ``None`` (except ``adapter_path``) is read from
        the project config. The model is loaded immediately, so construction
        is expensive.

        Args:
            model_id: HuggingFace model ID. Falls back to the
                ``model.base_model`` config entry.
            adapter_path: Path to LoRA adapter (optional). No adapter is
                applied when omitted.
            device: Device to use ('cuda', 'cpu'). Falls back to the
                ``model.device`` config entry.
            use_quantization: Whether to use 4-bit quantization. Falls back
                to the ``model.quantization`` config entry.
        """
        # Set these first so __del__ always finds the attributes, even if a
        # later step of the constructor raises.
        self.tokenizer = None
        self.model = None

        self.model_id = model_id or config.get(
            "model.base_model", "meta-llama/Llama-3.2-3B-Instruct"
        )
        self.adapter_path = adapter_path
        self.device = device or config.get("model.device", "cpu")
        # An explicit False must win over the config value, hence `is not None`.
        self.use_quantization = (
            use_quantization
            if use_quantization is not None
            else config.get("model.quantization", False)
        )

        logger.info("Initializing Francis Botcon model")
        logger.info(f"  Base model: {self.model_id}")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Quantization: {self.use_quantization}")

        self._load_model()

    def _load_model(self) -> None:
        """Load the base model and optionally apply LoRA adapter.

        Populates ``self.tokenizer`` and ``self.model``.
        """
        # Load tokenizer
        logger.info("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        if self.tokenizer.pad_token is None:
            # Reuse EOS as the padding token when the tokenizer defines none.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        logger.info("✓ Tokenizer loaded")

        # Only CUDA loads go through automatic device placement.
        device_map = "auto" if self.device == "cuda" else None
        model_kwargs = {
            "torch_dtype": torch.float16,
            "device_map": device_map
        }

        # Configure quantization if needed
        if self.use_quantization:
            logger.info("Configuring 4-bit quantization...")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
            model_kwargs["quantization_config"] = bnb_config

        # Load base model
        logger.info(f"Loading base model: {self.model_id}")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id, **model_kwargs
        )

        # Move the model explicitly only when placement was NOT delegated to
        # device_map: a model that was already dispatched automatically
        # should not be moved again (doing so warns, or fails when modules
        # are offloaded). Quantized models are never moved either.
        if (
            device_map is None
            and not self.use_quantization
            and self.device != "auto"
        ):
            self.model = self.model.to(self.device)
        logger.info("✓ Base model loaded")

        # Load adapter if provided
        if self.adapter_path:
            logger.info(f"Loading LoRA adapter: {self.adapter_path}")
            self.model = PeftModel.from_pretrained(self.model, self.adapter_path)
            logger.info("✓ LoRA adapter loaded")

    def generate(
        self,
        prompt: str,
        max_length: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        do_sample: Optional[bool] = None,
        **kwargs
    ) -> str:
        """Generate text using the model.

        Arguments left as ``None`` are filled from the generation section of
        the project config.

        Args:
            prompt: Input prompt. It is truncated to 2048 tokens.
            max_length: Maximum number of NEW tokens to generate (passed to
                the model as ``max_new_tokens``).
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            do_sample: Whether to use sampling
            **kwargs: Additional generation parameters. These take
                precedence over the values computed here.

        Returns:
            The generated continuation, without the prompt, stripped of
            leading and trailing whitespace.
        """
        # Get generation config from config file if not provided
        gen_config = config.get_generation_config()
        # A falsy token budget (None or 0) falls back to the configured
        # default, as it always has.
        max_length = max_length or gen_config.get("max_tokens", 512)
        # For the sampling knobs an explicit 0 / 0.0 / False is a real value
        # and must not be replaced by the default, so test against None.
        if temperature is None:
            temperature = gen_config.get("temperature", 0.7)
        if top_p is None:
            top_p = gen_config.get("top_p", 0.9)
        if top_k is None:
            top_k = gen_config.get("top_k", 50)
        if do_sample is None:
            do_sample = gen_config.get("do_sample", True)

        # Tokenize input
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        )
        if self.device != "auto":
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Collect the arguments in one dict so that caller-supplied kwargs
        # override them instead of raising a duplicate-keyword TypeError.
        generation_kwargs = {
            "max_new_tokens": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        generation_kwargs.update(kwargs)

        # Generate
        logger.debug("Generating text...")
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        # The output sequence starts with the prompt tokens. Drop them by
        # position rather than comparing decoded text with the prompt: the
        # text comparison fails whenever the prompt was truncated or does
        # not decode back to exactly the same string.
        prompt_token_count = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_token_count:]

        # Decode
        generated_text = self.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        )
        return generated_text.strip()

    def get_device(self) -> str:
        """Get the device the model is on.

        Returns:
            Device string as configured on this wrapper (``self.device``).
        """
        return self.device

    def __del__(self):
        """Clean up resources.

        Best-effort: drops the model reference and empties the CUDA cache.
        Errors are deliberately ignored because this may run during
        interpreter shutdown, when modules and attributes can already be gone.
        """
        try:
            # getattr: the attribute is missing if __init__ failed very early.
            if getattr(self, "model", None) is not None:
                del self.model
            # Only try to empty cache if torch is still available
            import torch as torch_module
            if torch_module.cuda.is_available():
                torch_module.cuda.empty_cache()
        except Exception:
            # Silently ignore cleanup errors during interpreter shutdown
            pass