"""TinyLlama wrapper for a RAG-based chatbot.

This module provides a wrapper around the TinyLlama model for generating
responses in a RAG architecture, replacing the previous BERT-based approach.
"""

import time
from typing import Optional

import torch
from loguru import logger
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)


class TinyLlamaWrapper:
    """Wrapper for the TinyLlama model with RAG integration support.

    This class provides an interface to the TinyLlama-1.1B-Chat model with
    support for 4-bit quantization for memory-efficient inference.

    The model and tokenizer are loaded eagerly in ``__init__``; construction
    raises ``RuntimeError`` if loading fails.
    """

    # Maximum number of prompt tokens kept by the tokenizer (longer prompts
    # are truncated).
    MAX_INPUT_TOKENS = 512
    # Maximum number of context characters inserted into the RAG prompt.
    MAX_CONTEXT_CHARS = 600
    # Responses shorter than this many characters trigger a warning.
    MIN_RESPONSE_CHARS = 20

    # Handler id returned by ``logger.add``; shared by all instances so the
    # file sink is registered only once per process.
    _log_sink_id: Optional[int] = None

    def __init__(
        self,
        model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        use_quantization: bool = True,
        cache_dir: str = "models/cache",
    ) -> None:
        """Initialize the TinyLlama wrapper and load the model.

        Args:
            model_name: Hugging Face model identifier.
            use_quantization: Whether to use 4-bit quantization. Only
                honoured when a CUDA device is available; it is reset to
                ``False`` on CPU.
            cache_dir: Directory to cache the model files.

        Raises:
            RuntimeError: If the tokenizer or model cannot be loaded.
        """
        self.model_name = model_name
        self.use_quantization = use_quantization
        self.cache_dir = cache_dir
        self.device: Optional[str] = None
        self.model = None
        self.tokenizer = None

        self._setup_logger()
        self._load_model()

    def _setup_logger(self) -> None:
        """Configure file logging for the wrapper.

        The sink is added only once per process: ``logger`` is a global
        object, so adding the same file on every instantiation would write
        each message several times.
        """
        if TinyLlamaWrapper._log_sink_id is not None:
            return
        TinyLlamaWrapper._log_sink_id = logger.add(
            "logs/tinyllama_wrapper.log",
            rotation="10 MB",
            retention="7 days",
            level="INFO",
        )

    def _load_model(self) -> None:
        """Load the TinyLlama model and tokenizer.

        Selects CUDA when available (optionally with 4-bit NF4 quantization)
        and falls back to float32 on CPU.

        Raises:
            RuntimeError: If loading the tokenizer or the model fails; the
                original exception is chained as the cause.
        """
        try:
            has_gpu = torch.cuda.is_available()
            self.device = "cuda" if has_gpu else "cpu"

            logger.info(f"Initializing TinyLlama model: {self.model_name}")
            logger.info(f"GPU available: {has_gpu}, Quantization: {self.use_quantization}")

            quantization_config = None
            device_map = "auto" if has_gpu else None

            if has_gpu and self.use_quantization:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                )
                logger.info("4-bit quantization enabled for GPU inference")
            elif not has_gpu:
                logger.info("Loading model on CPU with float32")
                # Quantization is only configured on GPU; keep the flag honest.
                self.use_quantization = False

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
            )

            # generate() needs a pad token id; reuse EOS when none is defined.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                logger.info("Set pad_token = eos_token")

            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=quantization_config,
                device_map=device_map,
                torch_dtype=torch.float16 if has_gpu else torch.float32,
                cache_dir=self.cache_dir,
            )

            # Report the device actually selected (was hard-coded to "cpu").
            logger.info(f"Model loaded successfully on {self.device}")
            if quantization_config:
                logger.info("Model loaded with 4-bit quantization")

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise RuntimeError(f"Model initialization failed: {str(e)}") from e

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 180,
        min_new_tokens: int = 30,
        temperature: float = 0.3,
        top_p: float = 0.7,
        repetition_penalty: float = 1.15,
        early_stopping: bool = False,
        no_repeat_ngram_size: int = 3,
    ) -> str:
        """Generate a response from a prompt.

        The prompt is tokenized (truncated to ``MAX_INPUT_TOKENS`` tokens),
        a continuation is sampled, and only the newly generated tokens are
        decoded and returned.

        Args:
            prompt: The input prompt string.
            max_new_tokens: Maximum number of tokens to generate.
            min_new_tokens: Minimum number of tokens to generate (forces at
                least this many).
            temperature: Sampling temperature (higher = more random).
            top_p: Nucleus sampling threshold.
            repetition_penalty: Penalty for repeating tokens (1.0 = no penalty).
            early_stopping: Forwarded unchanged to ``model.generate``.
                NOTE(review): per the Hugging Face generation docs this flag
                governs beam-based methods; with the sampling used here it is
                not expected to change the output - confirm before relying
                on it.
            no_repeat_ngram_size: Prevents repeating n-grams of this size.

        Returns:
            Generated response string (without the prompt). If generation
            raises, the error is logged and a fixed Spanish apology message
            is returned instead of propagating the exception.
        """
        start_time = time.time()

        try:
            logger.info(f"Generating response for prompt (length: {len(prompt)})")

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.MAX_INPUT_TOKENS,
            )

            # Move inputs to wherever the model weights actually live.
            input_device = next(self.model.parameters()).device
            inputs = {k: v.to(input_device) for k, v in inputs.items()}
            input_len = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    min_new_tokens=min_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    early_stopping=early_stopping,
                )

            # Decode only the continuation. Slicing the decoded text by
            # len(prompt) is unreliable: the prompt may have been truncated,
            # and decoding does not necessarily reproduce it character for
            # character.
            new_token_ids = outputs[0][input_len:]
            response = self.tokenizer.decode(
                new_token_ids,
                skip_special_tokens=True,
            ).strip()

            elapsed = time.time() - start_time
            tokens_generated = len(new_token_ids)
            logger.info(
                f"Generated {tokens_generated} tokens in {elapsed:.2f}s"
            )

            if len(response) < self.MIN_RESPONSE_CHARS:
                # No fallback text: return whatever the model produced.
                logger.warning(f"Response too short ({len(response)} chars), returning anyway")

            return response

        except Exception as e:
            # Deliberate best effort: callers always get a user-facing string.
            self._log_error(f"Generation failed: {str(e)}")
            return "Lo siento, hubo un problema al generar la respuesta. Por favor, intenta de nuevo."

    def generate_with_context(
        self,
        context: str,
        question: str,
        max_new_tokens: int = 180,
    ) -> str:
        """Generate a response given context and a question (RAG mode).

        The context is cut to ``MAX_CONTEXT_CHARS`` characters (with a
        trailing ``"..."``) and embedded in a Spanish system/user/assistant
        prompt before calling :meth:`generate` with fixed sampling settings.

        Args:
            context: Retrieved context from the RAG system.
            question: User question.
            max_new_tokens: Maximum tokens to generate.

        Returns:
            Generated response based on the context.
        """
        if len(context) > self.MAX_CONTEXT_CHARS:
            context = context[:self.MAX_CONTEXT_CHARS] + "..."

        prompt = f"""<|system|>
Eres un asesor de Prepa en Línea SEP. Responde en español usando solo la información del contexto.

<|user|>
Información: {context}
Pregunta: {question}
<|assistant|>
"""

        logger.info(f"RAG generation - Context length: {len(context)}, Question: {question[:50]}...")

        return self.generate(
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=0.2,
            top_p=0.8,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3,
            early_stopping=True,
            min_new_tokens=30,
        )

    def _log_error(self, error_msg: str) -> None:
        """Log an error message.

        Args:
            error_msg: The error message to log.
        """
        logger.error(error_msg)