# Page header left over from the hosting site's file viewer: Spaces: Sleeping; File size: 6,176 Bytes; 4e5fc16
"""Model loading and inference for Francis Botcon."""
from typing import Dict, Optional
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig
)
from peft import PeftModel
from src.logger import LoggerSetup
from src.config_loader import config
logger = LoggerSetup.setup().getChild(__name__)
class FrancisModel:
    """Wrapper for the Francis Botcon causal language model.

    Construction loads the tokenizer and base model for ``model_id`` (and a
    LoRA adapter when ``adapter_path`` is given); ``generate`` then turns a
    prompt into a completion string.
    """

    # Prompts longer than this many tokens are truncated by the tokenizer.
    MAX_PROMPT_TOKENS = 2048

    def __init__(
        self,
        model_id: Optional[str] = None,
        adapter_path: Optional[str] = None,
        device: Optional[str] = None,
        use_quantization: Optional[bool] = None
    ):
        """Initialize Francis Botcon model.

        Args:
            model_id: HuggingFace model ID. Falls back to the
                ``model.base_model`` config entry when falsy.
            adapter_path: Path to LoRA adapter (optional).
            device: Device to use ('cuda', 'cpu'). Falls back to the
                ``model.device`` config entry when falsy.
            use_quantization: Whether to use 4-bit quantization. ``None``
                means "read ``model.quantization`` from config".
        """
        # Set these first so __del__ stays safe if anything below raises.
        self.tokenizer = None
        self.model = None

        self.model_id = model_id or config.get(
            "model.base_model", "meta-llama/Llama-3.2-3B-Instruct"
        )
        self.adapter_path = adapter_path
        self.device = device or config.get("model.device", "cpu")
        self.use_quantization = self._coalesce(
            use_quantization, config.get("model.quantization", False)
        )

        logger.info("Initializing Francis Botcon model")
        logger.info(f" Base model: {self.model_id}")
        logger.info(f" Device: {self.device}")
        logger.info(f" Quantization: {self.use_quantization}")

        self._load_model()

    @staticmethod
    def _coalesce(value, fallback):
        """Return ``value`` unless it is ``None``, in which case ``fallback``.

        Unlike ``value or fallback`` this keeps explicit falsy arguments such
        as ``0``, ``0.0`` or ``False``.
        """
        return fallback if value is None else value

    def _load_model(self):
        """Load the base model and optionally apply LoRA adapter."""
        # Load tokenizer
        logger.info("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        if self.tokenizer.pad_token is None:
            # Reuse EOS as the padding token when the tokenizer defines none.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        logger.info("✓ Tokenizer loaded")

        # "auto" placement is only requested for the plain "cuda" device.
        device_map = "auto" if self.device == "cuda" else None
        model_kwargs = {
            "torch_dtype": torch.float16,
            "device_map": device_map,
        }

        # Configure quantization if needed
        if self.use_quantization:
            logger.info("Configuring 4-bit quantization...")
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )

        # Load base model
        logger.info(f"Loading base model: {self.model_id}")
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id, **model_kwargs)

        # Only move the model by hand when it was NOT placed through
        # device_map: a model dispatched with device_map="auto" is already on
        # its target device(s) and should not be moved again with .to().
        needs_manual_move = (
            device_map is None
            and not self.use_quantization
            and self.device != "auto"
        )
        if needs_manual_move:
            self.model = self.model.to(self.device)
        logger.info("✓ Base model loaded")

        # Load adapter if provided
        if self.adapter_path:
            logger.info(f"Loading LoRA adapter: {self.adapter_path}")
            self.model = PeftModel.from_pretrained(self.model, self.adapter_path)
            logger.info("✓ LoRA adapter loaded")

    def _input_device(self):
        """Return the device prompt tensors should be moved to, or ``None``.

        Prefers the device reported by the loaded model so inputs follow the
        model wherever it was actually placed; otherwise uses the configured
        device string, with "auto" meaning "leave tensors where they are".
        """
        model_device = getattr(self.model, "device", None)
        if model_device is not None:
            return model_device
        return None if self.device == "auto" else self.device

    def _build_generation_kwargs(
        self, max_new_tokens, temperature, top_p, top_k, do_sample, overrides
    ):
        """Assemble the keyword arguments passed to ``model.generate``.

        ``overrides`` (the caller's ``**kwargs``) is applied last, so a caller
        may replace any default instead of triggering a duplicate-keyword
        ``TypeError``.
        """
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "do_sample": do_sample,
            "pad_token_id": pad_token_id,
        }
        generation_kwargs.update(overrides)
        return generation_kwargs

    def _decode_completion(self, outputs, prompt_token_count):
        """Decode only the tokens generated after the prompt.

        Slicing token ids is used instead of stripping the prompt text from
        the decoded string, because decoding does not always reproduce the
        prompt exactly (and the prompt may have been truncated).
        """
        completion_ids = outputs[0][prompt_token_count:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    def generate(
        self,
        prompt: str,
        max_length: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        do_sample: Optional[bool] = None,
        **kwargs
    ) -> str:
        """Generate text using the model.

        Args:
            prompt: Input prompt
            max_length: Maximum number of newly generated tokens (passed to
                the model as ``max_new_tokens``)
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            do_sample: Whether to use sampling
            **kwargs: Additional generation parameters; these take precedence
                over the values above when the same key is given

        Returns:
            Generated text with the prompt removed and surrounding
            whitespace stripped
        """
        # Get generation config from config file if not provided
        gen_config = config.get_generation_config()
        # A token budget of 0 is not meaningful, so falsy values fall back too.
        max_new_tokens = max_length or gen_config.get("max_tokens", 512)
        temperature = self._coalesce(temperature, gen_config.get("temperature", 0.7))
        top_p = self._coalesce(top_p, gen_config.get("top_p", 0.9))
        top_k = self._coalesce(top_k, gen_config.get("top_k", 50))
        do_sample = self._coalesce(do_sample, gen_config.get("do_sample", True))

        # Tokenize input
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_PROMPT_TOKENS
        )
        target_device = self._input_device()
        if target_device is None:
            inputs = dict(encoded)
        else:
            inputs = {name: tensor.to(target_device) for name, tensor in encoded.items()}
        prompt_token_count = inputs["input_ids"].shape[-1]

        generation_kwargs = self._build_generation_kwargs(
            max_new_tokens, temperature, top_p, top_k, do_sample, kwargs
        )

        # Generate
        logger.debug("Generating text...")
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        # Decode, dropping the prompt tokens from the output
        return self._decode_completion(outputs, prompt_token_count)

    def get_device(self) -> str:
        """Get the configured device.

        Returns:
            Device string
        """
        return self.device

    def __del__(self):
        """Clean up resources (best effort)."""
        try:
            if getattr(self, "model", None) is not None:
                self.model = None
            # Imported locally: module globals may already be torn down.
            import torch as torch_module
            if torch_module.cuda.is_available():
                torch_module.cuda.empty_cache()
        except Exception:
            # Deliberately silent: cleanup errors during interpreter shutdown
            # are not actionable.
            pass
|