# NOTE(review): removed scraped page chrome ("Spaces: / Sleeping / Sleeping")
# that was not part of the module and made the file invalid Python.
"""Model loading and inference for Francis Botcon."""
from typing import Dict, Optional

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel

from src.logger import LoggerSetup
from src.config_loader import config

# Module-level logger, namespaced under the project-wide logger setup.
logger = LoggerSetup.setup().getChild(__name__)
class FrancisModel:
    """Wrapper for the Francis Botcon model.

    Loads a HuggingFace causal LM (optionally 4-bit quantized via
    bitsandbytes), optionally applies a LoRA adapter, and exposes a
    simple text-generation API.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        adapter_path: Optional[str] = None,
        device: Optional[str] = None,
        use_quantization: Optional[bool] = None,
    ):
        """Initialize Francis Botcon model.

        Args:
            model_id: HuggingFace model ID. Falls back to config key
                ``model.base_model`` (default Llama-3.2-3B-Instruct).
            adapter_path: Path to LoRA adapter (optional).
            device: Device to use ('cuda', 'cpu'). Falls back to config
                key ``model.device`` (default 'cpu').
            use_quantization: Whether to use 4-bit quantization. Falls
                back to config key ``model.quantization`` (default False).
        """
        self.model_id = model_id or config.get(
            "model.base_model", "meta-llama/Llama-3.2-3B-Instruct"
        )
        self.adapter_path = adapter_path
        self.device = device or config.get("model.device", "cpu")
        # Explicit None check so a caller-supplied False is not overridden
        # by the config default.
        if use_quantization is None:
            use_quantization = config.get("model.quantization", False)
        self.use_quantization = use_quantization

        logger.info("Initializing Francis Botcon model")
        logger.info("  Base model: %s", self.model_id)
        logger.info("  Device: %s", self.device)
        logger.info("  Quantization: %s", self.use_quantization)

        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load tokenizer and base model; optionally apply the LoRA adapter."""
        logger.info("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        # Llama-family tokenizers ship without a pad token; reuse EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        logger.info("✓ Tokenizer loaded")

        model_kwargs: Dict = {
            "torch_dtype": torch.float16,
            # device_map="auto" lets accelerate shard across GPUs; on CPU we
            # place the model explicitly below.
            "device_map": "auto" if self.device == "cuda" else None,
        }
        if self.use_quantization:
            logger.info("Configuring 4-bit quantization...")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            model_kwargs["quantization_config"] = bnb_config

        logger.info("Loading base model: %s", self.model_id)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
        # Quantized models are placed by bitsandbytes/accelerate and must not
        # be moved manually.
        if not self.use_quantization and self.device != "auto":
            self.model = self.model.to(self.device)
        logger.info("✓ Base model loaded")

        if self.adapter_path:
            logger.info("Loading LoRA adapter: %s", self.adapter_path)
            self.model = PeftModel.from_pretrained(self.model, self.adapter_path)
            logger.info("✓ LoRA adapter loaded")

    def generate(
        self,
        prompt: str,
        max_length: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        do_sample: Optional[bool] = None,
        **kwargs,
    ) -> str:
        """Generate text using the model.

        Args:
            prompt: Input prompt.
            max_length: Maximum number of NEW tokens to generate
                (passed to ``max_new_tokens``).
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            top_k: Top-k sampling parameter.
            do_sample: Whether to use sampling.
            **kwargs: Additional ``model.generate`` parameters.

        Returns:
            Generated text with the prompt removed.
        """
        # Fill unspecified parameters from the config file. Use explicit
        # None checks: `x or default` would discard legitimate falsy values
        # such as top_p=0.0 or top_k=0.
        gen_config = config.get_generation_config()
        if max_length is None:
            max_length = gen_config.get("max_tokens", 512)
        if temperature is None:
            temperature = gen_config.get("temperature", 0.7)
        if top_p is None:
            top_p = gen_config.get("top_p", 0.9)
        if top_k is None:
            top_k = gen_config.get("top_k", 50)
        if do_sample is None:
            do_sample = gen_config.get("do_sample", True)

        # Cap the prompt at 2048 tokens to bound memory use.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        )
        if self.device != "auto":
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

        logger.debug("Generating text...")
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
                **kwargs,
            )

        # Strip the prompt by slicing off the input tokens before decoding.
        # This is robust where string-prefix matching is not: decoding the
        # full sequence does not always reproduce the prompt text exactly,
        # so `startswith(prompt)` could silently leave the prompt in place.
        prompt_token_count = inputs["input_ids"].shape[1]
        generated_text = self.tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        )
        return generated_text.strip()

    def get_device(self) -> str:
        """Get the device the model is on.

        Returns:
            Device string (e.g. 'cuda', 'cpu').
        """
        return self.device

    def __del__(self):
        """Best-effort cleanup of model memory; must never raise at shutdown."""
        try:
            # getattr: __init__ may have failed before self.model was set.
            if getattr(self, "model", None) is not None:
                del self.model
            # Re-import locally: module globals may already be torn down
            # during interpreter shutdown.
            import torch as torch_module

            if torch_module.cuda.is_available():
                torch_module.cuda.empty_cache()
        except Exception:
            # Silently ignore cleanup errors during interpreter shutdown.
            pass