# Chatbot-RAG-v4 / models / tinyllama_wrapper.py
# NoeMartinezSanchez
# Commit 65fb599: the parameters were adjusted again because the response was too short.
"""TinyLlama Wrapper for RAG-based chatbot.
This module provides a wrapper around TinyLlama model for generating
responses in a RAG architecture, replacing the previous BERT-based approach.
"""
import time
from typing import Optional
import torch
from loguru import logger
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
class TinyLlamaWrapper:
    """Wrapper for the TinyLlama chat model with RAG integration support.

    The wrapper loads the tokenizer and the causal language model at
    construction time. When a CUDA device is available and quantization is
    requested, the model is loaded with 4-bit NF4 quantization; otherwise it
    is loaded unquantized (float32 on CPU, float16 on GPU).
    """

    # Maximum number of context characters kept in the RAG prompt.
    _MAX_CONTEXT_CHARS = 600
    # Maximum number of prompt tokens handed to the model (longer prompts
    # are truncated by the tokenizer).
    _MAX_PROMPT_TOKENS = 512
    # Responses shorter than this many characters are logged as suspicious.
    _MIN_RESPONSE_CHARS = 20
    # User-facing message returned when generation raises.
    _FALLBACK_MESSAGE = (
        "Lo siento, hubo un problema al generar la respuesta. "
        "Por favor, intenta de nuevo."
    )
    # The loguru ``logger`` is shared by the whole process, so the file sink
    # must be added only once no matter how many wrappers are created.
    _log_sink_registered = False

    def __init__(
        self,
        model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        use_quantization: bool = True,
        cache_dir: str = "models/cache",
    ) -> None:
        """Initialize the wrapper and eagerly load the model.

        Args:
            model_name: Hugging Face model identifier.
            use_quantization: Whether to use 4-bit quantization (only
                honoured when a GPU is available).
            cache_dir: Directory to cache the model files.

        Raises:
            RuntimeError: If the tokenizer or the model cannot be loaded.
        """
        self.model_name = model_name
        self.use_quantization = use_quantization
        self.cache_dir = cache_dir
        self.device: Optional[str] = None
        self.model = None
        self.tokenizer = None
        self._setup_logger()
        self._load_model()

    def _setup_logger(self) -> None:
        """Register the rotating log-file sink once per process."""
        if TinyLlamaWrapper._log_sink_registered:
            # A previous instance already added the sink; adding it again
            # would write every message to the file multiple times.
            return
        logger.add(
            "logs/tinyllama_wrapper.log",
            rotation="10 MB",
            retention="7 days",
            level="INFO",
        )
        TinyLlamaWrapper._log_sink_registered = True

    def _load_model(self) -> None:
        """Load the TinyLlama tokenizer and model.

        Sets ``self.device``, ``self.tokenizer`` and ``self.model``. On CPU
        the ``use_quantization`` flag is reset to ``False`` because the
        4-bit configuration is only applied on GPU.

        Raises:
            RuntimeError: If any step of the loading process fails; the
                original exception is chained as the cause.
        """
        try:
            has_gpu = torch.cuda.is_available()
            self.device = "cuda" if has_gpu else "cpu"
            logger.info(f"Initializing TinyLlama model: {self.model_name}")
            logger.info(
                f"GPU available: {has_gpu}, Quantization: {self.use_quantization}"
            )

            quantization_config = None
            device_map = "auto" if has_gpu else None
            if has_gpu and self.use_quantization:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                )
                logger.info("4-bit quantization enabled for GPU inference")
            elif not has_gpu:
                logger.info("Loading model on CPU with float32")
                self.use_quantization = False

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
            )
            if self.tokenizer.pad_token is None:
                # generate() needs a pad token id; reuse EOS when the
                # tokenizer does not define one.
                self.tokenizer.pad_token = self.tokenizer.eos_token
                logger.info("Set pad_token = eos_token")

            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=quantization_config,
                device_map=device_map,
                torch_dtype=torch.float16 if has_gpu else torch.float32,
                cache_dir=self.cache_dir,
            )
            # Report the device actually selected instead of a fixed "cpu".
            logger.info(f"Model loaded successfully on {self.device}")
            if quantization_config:
                logger.info("Model loaded with 4-bit quantization")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}") from e

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 180,
        min_new_tokens: int = 30,
        temperature: float = 0.3,
        top_p: float = 0.7,
        repetition_penalty: float = 1.15,
        early_stopping: bool = False,
        no_repeat_ngram_size: int = 3,
    ) -> str:
        """Generate a response from a prompt.

        Args:
            prompt: The input prompt string. It is truncated to
                ``_MAX_PROMPT_TOKENS`` tokens before generation.
            max_new_tokens: Maximum number of tokens to generate.
            min_new_tokens: Minimum number of tokens to generate. Clamped
                to ``max_new_tokens`` so the two never contradict.
            temperature: Sampling temperature (higher = more random).
                Values <= 0 switch to greedy decoding.
            top_p: Nucleus sampling threshold (used only when sampling).
            repetition_penalty: Penalty for repeating tokens (1.0 = no
                penalty).
            early_stopping: Accepted for backward compatibility only. The
                underlying flag applies to beam search, which this wrapper
                does not use, so the value is not forwarded to the model.
            no_repeat_ngram_size: Prevents repeating n-grams of this size.

        Returns:
            The generated text without the prompt, stripped of surrounding
            whitespace. If generation raises, a fixed Spanish apology
            message is returned instead of propagating the exception.
        """
        start_time = time.time()
        try:
            logger.info(f"Generating response for prompt (length: {len(prompt)})")
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self._MAX_PROMPT_TOKENS,
            )
            input_device = next(self.model.parameters()).device
            inputs = {k: v.to(input_device) for k, v in inputs.items()}
            prompt_token_count = inputs["input_ids"].shape[1]

            generation_kwargs = {
                "max_new_tokens": max_new_tokens,
                "min_new_tokens": max(0, min(min_new_tokens, max_new_tokens)),
                "repetition_penalty": repetition_penalty,
                "no_repeat_ngram_size": no_repeat_ngram_size,
                "pad_token_id": self.tokenizer.pad_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
            }
            if temperature > 0:
                generation_kwargs.update(
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                )
            else:
                # Sampling with a non-positive temperature is invalid;
                # fall back to deterministic greedy decoding.
                generation_kwargs["do_sample"] = False

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **generation_kwargs)

            # Decode only the newly generated ids. Slicing the decoded text
            # by len(prompt) is unreliable: the prompt may have been
            # truncated, and decoding does not always reproduce it exactly.
            new_tokens = outputs[0][prompt_token_count:]
            response = self.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True,
            ).strip()

            if len(response) < self._MIN_RESPONSE_CHARS:
                # No fallback here: return whatever the model produced.
                logger.warning(
                    f"Response too short ({len(response)} chars), returning anyway"
                )

            elapsed = time.time() - start_time
            logger.info(f"Generated {len(new_tokens)} tokens in {elapsed:.2f}s")
            return response
        except Exception as e:
            self._log_error(f"Generation failed: {e}")
            return self._FALLBACK_MESSAGE

    @classmethod
    def _truncate_context(cls, context: str) -> str:
        """Cap *context* at ``_MAX_CONTEXT_CHARS`` characters, marking cuts with "..."."""
        if len(context) > cls._MAX_CONTEXT_CHARS:
            return context[: cls._MAX_CONTEXT_CHARS] + "..."
        return context

    @staticmethod
    def _build_rag_prompt(context: str, question: str) -> str:
        """Format the system/user/assistant prompt used for RAG generation."""
        return (
            "<|system|>\n"
            "Eres un asesor de Prepa en Línea SEP. Responde en español "
            "usando solo la información del contexto.\n"
            "<|user|>\n"
            f"Información: {context}\n"
            f"Pregunta: {question}\n"
            "<|assistant|>\n"
        )

    def generate_with_context(
        self,
        context: str,
        question: str,
        max_new_tokens: int = 180,
    ) -> str:
        """Generate a response given context and a question (RAG mode).

        The context is capped at ``_MAX_CONTEXT_CHARS`` characters, embedded
        with the question in a chat-style prompt, and passed to
        :meth:`generate` with fixed sampling parameters.

        Args:
            context: Retrieved context from the RAG system.
            question: User question.
            max_new_tokens: Maximum tokens to generate.

        Returns:
            Generated response based on the context.
        """
        context = self._truncate_context(context)
        prompt = self._build_rag_prompt(context, question)
        logger.info(
            f"RAG generation - Context length: {len(context)}, "
            f"Question: {question[:50]}..."
        )
        return self.generate(
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            min_new_tokens=30,
            temperature=0.2,
            top_p=0.8,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3,
        )

    def _log_error(self, error_msg: str) -> None:
        """Log an error message.

        Args:
            error_msg: The error message to log.
        """
        logger.error(error_msg)