# francis-botcon / src/model.py
# Author: Rojaldo — "Initialize Francis Botcon Gradio Space with model files" (commit 4e5fc16)
"""Model loading and inference for Francis Botcon."""
from typing import Dict, Optional
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig
)
from peft import PeftModel
from src.logger import LoggerSetup
from src.config_loader import config
logger = LoggerSetup.setup().getChild(__name__)
class FrancisModel:
    """Wrapper around the Francis Botcon causal language model.

    Loads a HuggingFace tokenizer and model (optionally 4-bit quantized via
    bitsandbytes, optionally with a LoRA adapter applied through PEFT) and
    exposes a simple :meth:`generate` text-completion API.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        adapter_path: Optional[str] = None,
        device: Optional[str] = None,
        use_quantization: Optional[bool] = None
    ):
        """Initialize Francis Botcon model.

        Args:
            model_id: HuggingFace model ID. Defaults to the
                ``model.base_model`` config entry.
            adapter_path: Path to a LoRA adapter (optional).
            device: Device to use ('cuda', 'cpu'). Defaults to the
                ``model.device`` config entry ('cpu' if unset).
            use_quantization: Whether to use 4-bit quantization. Defaults
                to the ``model.quantization`` config entry (False if unset).
        """
        self.model_id = model_id or config.get("model.base_model", "meta-llama/Llama-3.2-3B-Instruct")
        self.adapter_path = adapter_path
        self.device = device or config.get("model.device", "cpu")
        # Explicit None check so a caller-supplied ``use_quantization=False``
        # is honored instead of falling through to the config value.
        self.use_quantization = (
            use_quantization
            if use_quantization is not None
            else config.get("model.quantization", False)
        )

        logger.info("Initializing Francis Botcon model")
        logger.info(f" Base model: {self.model_id}")
        logger.info(f" Device: {self.device}")
        logger.info(f" Quantization: {self.use_quantization}")

        self.tokenizer = None  # set by _load_model()
        self.model = None      # set by _load_model()
        self._load_model()

    def _load_model(self):
        """Load tokenizer and base model; optionally apply the LoRA adapter."""
        logger.info("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        # Llama-style tokenizers ship without a pad token; reuse EOS so
        # padded batches don't crash generation.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        logger.info("✓ Tokenizer loaded")

        # device_map="auto" lets accelerate place the weights on GPU;
        # on CPU we move the model manually below.
        model_kwargs = {
            "torch_dtype": torch.float16,
            "device_map": "auto" if self.device == "cuda" else None
        }
        if self.use_quantization:
            logger.info("Configuring 4-bit quantization...")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
            model_kwargs["quantization_config"] = bnb_config

        logger.info(f"Loading base model: {self.model_id}")
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)
        # Quantized models are already placed by device_map; moving them
        # again is unsupported, so only relocate the unquantized case.
        if not self.use_quantization and self.device != "auto":
            self.model = self.model.to(self.device)
        # Inference-only wrapper: disable dropout etc.
        self.model.eval()
        logger.info("✓ Base model loaded")

        if self.adapter_path:
            logger.info(f"Loading LoRA adapter: {self.adapter_path}")
            self.model = PeftModel.from_pretrained(self.model, self.adapter_path)
            logger.info("✓ LoRA adapter loaded")

    def generate(
        self,
        prompt: str,
        max_length: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        do_sample: Optional[bool] = None,
        **kwargs
    ) -> str:
        """Generate text using the model.

        Args:
            prompt: Input prompt.
            max_length: Maximum number of NEW tokens to generate
                (passed to ``max_new_tokens``).
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            top_k: Top-k sampling parameter (0 disables top-k filtering).
            do_sample: Whether to use sampling (greedy decode if False).
            **kwargs: Additional ``model.generate`` parameters.

        Returns:
            Generated text with the prompt removed.
        """
        # Fall back to the config file for any parameter not provided.
        # ``is not None`` (not ``or``) so explicit falsy values such as
        # top_k=0 or temperature=0.0 are respected.
        gen_config = config.get_generation_config()
        max_length = max_length if max_length is not None else gen_config.get("max_tokens", 512)
        temperature = temperature if temperature is not None else gen_config.get("temperature", 0.7)
        top_p = top_p if top_p is not None else gen_config.get("top_p", 0.9)
        top_k = top_k if top_k is not None else gen_config.get("top_k", 50)
        do_sample = do_sample if do_sample is not None else gen_config.get("do_sample", True)

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        )
        if self.device != "auto":
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Remember the prompt length in tokens so we can slice it off the
        # output tensor. This is more reliable than string-prefix matching:
        # decode() does not always round-trip the prompt byte-for-byte.
        prompt_token_count = inputs["input_ids"].shape[1]

        logger.debug("Generating text...")
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
                **kwargs
            )

        # Decode only the newly generated tokens.
        generated_text = self.tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        ).strip()
        return generated_text

    def get_device(self) -> str:
        """Get the device the model is on.

        Returns:
            Device string (e.g. 'cuda' or 'cpu').
        """
        return self.device

    def __del__(self):
        """Best-effort cleanup of model memory on garbage collection."""
        try:
            # getattr guard: __init__ may have failed before self.model existed.
            if getattr(self, "model", None) is not None:
                del self.model
            # Re-import locally: module globals may be gone during
            # interpreter shutdown.
            import torch as torch_module
            if torch_module.cuda.is_available():
                torch_module.cuda.empty_cache()
        except Exception:
            # Silently ignore cleanup errors during interpreter shutdown.
            pass