"""Gemma wrapper for a RAG-based chatbot.

This module wraps Google's Gemma instruction-tuned checkpoints (loaded
through Hugging Face ``transformers``) for generating responses in a RAG
architecture, configured for CPU-only inference.
"""

import gc
import os
import re
import sys
import time
from typing import Any, Callable, List, Optional, Tuple

import torch
import transformers
from loguru import logger
from requests.adapters import HTTPAdapter
from transformers import AutoModelForCausalLM, AutoTokenizer
from urllib3.util.retry import Retry

# Candidate checkpoints, tried in this order until one loads.
MODEL_VARIANTS = [
    "google/gemma-4-e4b-it",  # primary: fast and capable
    "google/gemma-4-e2b-it",  # backup: lighter
    "google/gemma-2-2b-it",  # final fallback
]
MIN_TRANSFORMERS_VERSION = "4.51.0"

# Leading characters that are not (Spanish) letters or opening punctuation.
_LEADING_JUNK_RE = re.compile(r"^[^a-zA-ZáéíóúüñÁÉÍÓÚÜÑ¿¡]+")
# "otografía" only when it is NOT already preceded by its "f".
_MISSING_F_RE = re.compile(r"(?<![fF])otografía")


def _version_tuple(version: str) -> Tuple[int, int]:
    """Return the leading ``(major, minor)`` integers of a version string.

    Non-numeric suffixes (``"51rc1"``, ``"0.dev0"``) are ignored and missing
    components count as 0, so unusual version strings never raise.
    """
    numbers = []
    for part in version.split(".")[:2]:
        digits = re.match(r"\d+", part)
        numbers.append(int(digits.group()) if digits else 0)
    while len(numbers) < 2:
        numbers.append(0)
    return numbers[0], numbers[1]


def _check_transformers_version() -> bool:
    """Log whether the installed transformers meets the Gemma 4 minimum.

    Only the major and minor components are compared against
    ``MIN_TRANSFORMERS_VERSION``. The check is advisory: it logs a warning
    or an info line and does not itself change which models are tried.

    Returns:
        True when the installed version is at least the required one.
    """
    installed = transformers.__version__
    compatible = _version_tuple(installed) >= _version_tuple(MIN_TRANSFORMERS_VERSION)
    if compatible:
        logger.info(f"✅ transformers version: {installed} (compatible con gemma-4)")
    else:
        logger.warning(
            f"⚠️ transformers {installed} incompatible con gemma-4. "
            f"Requiere >= {MIN_TRANSFORMERS_VERSION}. "
            f"Se usará fallback gemma-2-2b-it"
        )
    return compatible


class GemmaWrapper:
    """Wrapper around a Gemma causal LM with RAG prompt support.

    On construction the wrapper loads the first checkpoint that succeeds
    from the requested model (if any) followed by ``MODEL_VARIANTS``, always
    on CPU. ``model_name`` afterwards holds the checkpoint actually loaded.
    """

    # Template used by ``_build_simple_prompt``.
    # NOTE(review): the line breaks of this template were reconstructed from a
    # whitespace-collapsed copy of the file; confirm against the deployed prompt.
    PROMPT_TEMPLATE = """Responde la pregunta usando SOLO el texto que está entre === CONTEXTO === y === FIN CONTEXTO ===.

=== CONTEXTO ===
{context}
=== FIN CONTEXTO ===

PREGUNTA: {question}

REGLAS:
1. Responde con UNA sola frase corta
2. Empieza con "Sí" o "No" si es pregunta de sí/no
3. Si la respuesta no está en el contexto, di: "No encontré esa información"

RESPUESTA:"""

    # Legitimate Spanish words of one or two letters that may open a reply and
    # must never be mistaken for a stray token fragment.
    _SHORT_WORDS = frozenset({
        "a", "e", "o", "u", "y",
        "al", "da", "de", "el", "en", "es", "fe", "ha", "he", "ir", "la",
        "le", "lo", "me", "mi", "ni", "no", "os", "se", "si", "su", "te",
        "ti", "tu", "un", "va", "ve", "ya", "yo",
        "dé", "él", "mí", "sé", "sí", "té", "tú",
    })

    # Literal typo -> correction pairs applied with ``str.replace``.
    # NOTE(review): the last entry maps a string to itself as seen here (a
    # no-op); it may have lost a double space — confirm the intended key.
    _TYPO_REPLACEMENTS = {
        "constatancia": "constancia",
        "constatancoa": "constancia",
        "secondary": "secundaria",
        "credenciación": "credencial",
        "cartascompromiso": "carta compromiso",
        "carta compromiso ": "carta compromiso ",
    }

    # Set once the file sink has been registered, so that creating several
    # wrappers does not duplicate every log line.
    _log_sink_added = False

    def __init__(
        self,
        model_name: Optional[str] = None,
        cache_dir: str = "models/cache",
    ) -> None:
        """Initialize the wrapper and load a model.

        Args:
            model_name: Hugging Face model identifier to try first. When
                omitted, the checkpoints in ``MODEL_VARIANTS`` are tried in
                order. The remaining variants are always kept as fallbacks.
            cache_dir: Directory to cache the model files.

        Raises:
            RuntimeError: If none of the candidate checkpoints can be loaded.
        """
        _check_transformers_version()
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.device = "cpu"
        self.model = None
        self.tokenizer = None
        self._setup_logger()
        self._load_model()

    def _setup_logger(self) -> None:
        """Register the rotating file sink for the wrapper (once per process)."""
        if GemmaWrapper._log_sink_added:
            return
        logger.add(
            "logs/gemma_wrapper.log",
            rotation="10 MB",
            retention="7 days",
            level="INFO",
        )
        GemmaWrapper._log_sink_added = True

    def _candidate_models(self) -> List[str]:
        """Return the checkpoints to try: the requested one, then the fallbacks."""
        candidates = [self.model_name] if self.model_name else []
        candidates.extend(v for v in MODEL_VARIANTS if v != self.model_name)
        return candidates

    def _load_variant(
        self, model_variant: str, hf_token: Optional[str]
    ) -> Tuple[Any, Any, str, float]:
        """Load tokenizer and model for one checkpoint.

        Nothing is stored on ``self`` here, so a failure half-way through
        cannot leave a tokenizer from one variant paired with another model.

        Returns:
            ``(tokenizer, model, dtype_label, model_load_seconds)``.
        """
        logger.info(f"📥 [1/4] Descargando tokenizer de: {model_variant}")
        download_start = time.time()
        # NOTE(review): trust_remote_code=True lets the model repository run
        # arbitrary Python on this machine; confirm it is really required.
        tokenizer = AutoTokenizer.from_pretrained(
            model_variant,
            cache_dir=self.cache_dir,
            token=hf_token,
            trust_remote_code=True,
        )
        logger.info(f" ✅ Tokenizer descargado en {time.time() - download_start:.1f}s")

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            logger.info(" ✅ pad_token = eos_token (configurado)")
        else:
            logger.info(f" ✅ pad_token ya configurado: {tokenizer.pad_token}")

        logger.info(f"📥 [2/4] Descargando modelo: {model_variant}")
        logger.info(" ℹ️ Tamaño: ~5-8 GB, puede tomar varios minutos...")
        model_start = time.time()

        load_kwargs = {
            "device_map": "cpu",
            "cache_dir": self.cache_dir,
            "token": hf_token,
            "trust_remote_code": True,
        }
        if "gemma-4" in model_variant:
            # Gemma 4 checkpoints: half precision plus low-memory loading.
            logger.info(" ℹ️ Configuración Gemma 4: float16, device_map=cpu")
            dtype = torch.float16
            load_kwargs["low_cpu_mem_usage"] = True
        else:
            # Original configuration used for Gemma 2.
            logger.info(" ℹ️ Configuración: device_map=cpu, torch_dtype=float32")
            dtype = torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            model_variant, torch_dtype=dtype, **load_kwargs
        )
        model_time = time.time() - model_start
        logger.info(f" ✅ Modelo descargado en {model_time:.1f}s")

        logger.info(" ℹ️ Ejecutando model.eval()...")
        model.eval()
        return tokenizer, model, str(dtype).replace("torch.", ""), model_time

    def _load_model(self) -> None:
        """Load the first candidate checkpoint that succeeds.

        Any exception while loading a candidate is logged and the next one
        is tried. On success ``tokenizer``, ``model`` and ``model_name`` are
        set to the loaded checkpoint.

        Raises:
            RuntimeError: If every candidate fails; chained to the last error.
        """
        logger.info("=" * 60)
        logger.info("🚀 INICIANDO CARGA DEL MODELO GEMMA")
        logger.info("=" * 60)

        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            logger.info("✅ HF_TOKEN encontrado en variables de entorno")
        else:
            logger.warning("⚠️ HF_TOKEN no encontrado en variables de entorno")
            logger.info(" (El modelo debe ser público o tener HF_TOKEN configurado)")

        candidates = self._candidate_models()
        last_error = None

        for i, model_variant in enumerate(candidates):
            if i > 0:
                logger.info(f"🔄 Intentando fallback: {model_variant}")
            try:
                tokenizer, model, dtype_label, model_time = self._load_variant(
                    model_variant, hf_token
                )
            except Exception as e:
                # Deliberately broad: whatever went wrong, fall back to the
                # next candidate instead of aborting start-up.
                last_error = e
                if isinstance(e, KeyError):
                    logger.error(f"❌ KeyError con {model_variant}: {e}")
                elif isinstance(e, TypeError):
                    if "timeout" in str(e):
                        logger.error(f"❌ Timeout con {model_variant}: {e}")
                    else:
                        logger.error(f"❌ TypeError con {model_variant}: {e}")
                else:
                    logger.error(f"❌ Error cargando {model_variant}: {str(e)}")
                logger.info(" → Intentando siguiente variante...")
                continue

            self.tokenizer = tokenizer
            self.model = model
            self.model_name = model_variant

            logger.info("=" * 60)
            logger.info(f"✅ MODELO CARGADO EXITOSAMENTE: {self.model_name}")
            logger.info(f" 📍 Device: CPU ({dtype_label})")
            logger.info(" 💾 Memoria aproximada: ~5-6 GB RAM")
            logger.info(f" ⏱️ Tiempo total: {model_time:.1f}s")
            logger.info("=" * 60)
            return

        error_msg = (
            f"❌ Falló la carga de todas las variantes de Gemma. "
            f"Último error: {last_error}. "
            f"Modelos intentados: {candidates}. "
            f"Asegúrate de tener transformers>={MIN_TRANSFORMERS_VERSION}"
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg) from last_error

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 50,
        min_new_tokens: int = 10,
        temperature: float = 0.3,
        top_p: float = 0.85,
        repetition_penalty: float = 1.1,
        early_stopping: bool = True,
        no_repeat_ngram_size: int = 3,
        on_tokens_generated: Optional[Callable[[int, float], None]] = None,
    ) -> str:
        """Generate a response from a prompt.

        The prompt is tokenized (truncated to 1024 tokens), the model
        generates a continuation, and only the newly generated tokens are
        decoded, cleaned and returned.

        Args:
            prompt: The input prompt string.
            max_new_tokens: Maximum number of tokens to generate.
            min_new_tokens: Minimum number of tokens to generate.
            temperature: Sampling temperature (higher = more random). A value
                of 0 or less switches to greedy decoding.
            top_p: Nucleus sampling threshold (sampling only).
            repetition_penalty: Penalty for repeating tokens (1.0 = no penalty).
            early_stopping: Forwarded unchanged to ``model.generate``.
            no_repeat_ngram_size: Prevents repeating n-grams of this size.
            on_tokens_generated: Callback receiving the number of generated
                tokens and the elapsed seconds.

        Returns:
            The generated response without the prompt. A fixed fallback
            sentence is returned when the model produces nothing usable, and
            a fixed apology sentence when generation raises.
        """
        start_time = time.time()
        try:
            logger.info(f"Generating response for prompt (length: {len(prompt)})")

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=1024,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            prompt_token_count = inputs["input_ids"].shape[-1]

            generation_kwargs = {
                "max_new_tokens": max_new_tokens,
                "min_new_tokens": min_new_tokens,
                "repetition_penalty": repetition_penalty,
                "no_repeat_ngram_size": no_repeat_ngram_size,
                "pad_token_id": self.tokenizer.pad_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
                "early_stopping": early_stopping,
            }
            if temperature > 0:
                generation_kwargs.update(
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=40,
                )
            else:
                # Sampling rejects a non-positive temperature; decode greedily.
                generation_kwargs["do_sample"] = False

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **generation_kwargs)

            # Decode only the continuation. Slicing the decoded text by
            # len(prompt) is wrong whenever the prompt was truncated or the
            # tokenizer does not round-trip it character for character.
            new_tokens = outputs[0][prompt_token_count:]
            response = self.tokenizer.decode(
                new_tokens,
                skip_special_tokens=True,
            ).strip()

            response = self._clean_response(response)
            response = self.fix_common_errors(response).strip()

            if len(response) < 10:
                logger.warning(f"Response very short ({len(response)} chars)")
            if not response:
                response = "No se pudo generar una respuesta."

            elapsed = time.time() - start_time
            tokens_generated = len(new_tokens)
            tokens_per_second = tokens_generated / elapsed if elapsed > 0 else 0.0
            logger.info(
                f"Generated {tokens_generated} tokens in {elapsed:.2f}s "
                f"({tokens_per_second:.1f} tokens/s)"
            )

            if on_tokens_generated:
                on_tokens_generated(tokens_generated, elapsed)

            return response

        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            return "Lo siento, hubo un problema al generar la respuesta. Por favor, intenta de nuevo."
        finally:
            # Runs on every path, including short and failed generations.
            self._clear_cache()

    def generate_with_context(
        self,
        context: str,
        question: str,
        max_new_tokens: int = 100,
        on_tokens_generated: Optional[Callable[[int, float], None]] = None,
    ) -> str:
        """Generate a response given context and a question (RAG mode).

        Args:
            context: Retrieved context from the RAG system.
            question: User question.
            max_new_tokens: Maximum tokens to generate. The default matches
                the limit of 100 that this method previously always used.
            on_tokens_generated: Callback(token_count, elapsed_seconds).

        Returns:
            Generated response based on the context.
        """
        prompt = self._build_simple_prompt(context, question)
        logger.info(f"RAG generation - Context length: {len(context)}, Question: {question[:50]}...")

        return self.generate(
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            min_new_tokens=15,
            temperature=0.2,
            top_p=0.85,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            on_tokens_generated=on_tokens_generated,
        )

    def _build_simple_prompt(self, context: str, question: str) -> str:
        """Build the RAG prompt asking for a short, context-only answer.

        The filled template is wrapped in plain ``user`` / ``model`` role
        lines.
        """
        # NOTE(review): the role lines carry no explicit turn-marker tokens;
        # confirm this is intended for the Gemma checkpoints being loaded.
        prompt = self.PROMPT_TEMPLATE.format(context=context, question=question)
        return f"user\n{prompt}\nmodel\n"

    def _clean_response(self, text: str) -> str:
        """Normalize raw model output into a tidy sentence.

        Strips leading characters that are not letters or ``¿``/``¡``, drops
        a leading one- or two-character fragment unless it is a real short
        Spanish word (so "No ..." / "Sí ..." answers keep their meaning),
        collapses whitespace and capitalizes the first letter.
        """
        if not text:
            return text

        text = _LEADING_JUNK_RE.sub("", text)

        words = text.split()
        if len(words) > 1 and len(words[0]) <= 2:
            first_word = words[0].strip(".,;:!?¿¡").lower()
            if first_word not in self._SHORT_WORDS:
                text = " ".join(words[1:])

        text = re.sub(r"\s+", " ", text).strip()

        if text and text[0].islower():
            text = text[0].upper() + text[1:]
        return text

    def fix_common_errors(self, text: str) -> str:
        """Correct a fixed list of misspellings seen in model output."""
        for wrong, correct in self._TYPO_REPLACEMENTS.items():
            text = text.replace(wrong, correct)
        # Restore the missing "f" only where it is actually missing; a plain
        # replace would turn a correct "fotografía" into "ffotografía".
        return _MISSING_F_RE.sub("fotografía", text)

    def _clear_cache(self) -> None:
        """Run the garbage collector and empty the CUDA cache if CUDA exists."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.debug("Cleared memory cache")

    def get_model_info(self) -> dict:
        """Get information about the loaded model.

        The dtype and parameter count are read from the loaded model when it
        exposes them; otherwise the previous static values are reported.

        Returns:
            Dictionary with the keys ``model_name``, ``device``, ``dtype``,
            ``parameters`` and ``quantization``.
        """
        model = getattr(self, "model", None)

        dtype = getattr(model, "dtype", None)
        dtype_label = str(dtype).replace("torch.", "") if dtype is not None else "float32"

        parameters = "2B"
        if model is not None and hasattr(model, "num_parameters"):
            parameters = f"{model.num_parameters() / 1e9:.1f}B"

        return {
            "model_name": self.model_name,
            "device": self.device,
            "dtype": dtype_label,
            "parameters": parameters,
            "quantization": "none",
        }