Chatbot-RAG-v4 / models /gemma_wrapper.py
NoeMartinezSanchez
Fix Gemma 4 with transformers==4.46.3
3165153
"""Phi-2 Wrapper for RAG-based chatbot.
This module provides a wrapper around the microsoft/phi-2 model for generating
responses in a RAG architecture, optimized for CPU-only inference.
"""
import gc
import os
import sys
import time
from typing import Optional, Callable
import torch
import transformers
from loguru import logger
from requests.adapters import HTTPAdapter
from transformers import AutoModelForCausalLM, AutoTokenizer
from urllib3.util.retry import Retry
MODEL_VARIANTS = [
"google/gemma-4-e4b-it", # 🚀 Principal: rápido y potente
"google/gemma-4-e2b-it", # ⚡ Backup: más ligero
"google/gemma-2-2b-it", # 🔄 Fallback final
]
MIN_TRANSFORMERS_VERSION = "4.51.0"
def _check_transformers_version():
"""Verify transformers version meets minimum requirement for Gemma 4."""
installed = transformers.__version__
major, minor, _ = installed.split(".")[:3]
required_major, required_minor = MIN_TRANSFORMERS_VERSION.split(".")[:2]
if (int(major) < int(required_major) or
(int(major) == int(required_major) and int(minor) < int(required_minor))):
logger.warning(
f"⚠️ transformers {installed} incompatible con gemma-4. "
f"Requiere >= {MIN_TRANSFORMERS_VERSION}. "
f"Se usará fallback gemma-2-2b-it"
)
else:
logger.info(f"✅ transformers version: {installed} (compatible con gemma-4)")
class GemmaWrapper:
"""Wrapper for Gemma-4 E4B model with RAG integration support.
This class provides an interface to the Gemma-4-e4b-it model optimized for
CPU-only inference.
"""
def __init__(
self,
model_name: str = "google/gemma-2-2b-it",
cache_dir: str = "models/cache",
) -> None:
"""Initialize the Gemma wrapper.
Args:
model_name: Hugging Face model identifier.
cache_dir: Directory to cache the model files.
"""
_check_transformers_version()
self.model_name = model_name
self.cache_dir = cache_dir
self.device = "cpu"
self.model = None
self.tokenizer = None
self._setup_logger()
self._load_model()
def _setup_logger(self) -> None:
"""Configure logging for the wrapper."""
logger.add(
"logs/gemma_wrapper.log",
rotation="10 MB",
retention="7 days",
level="INFO",
)
def _load_model(self) -> None:
"""Load the Gemma model and tokenizer with fallback strategy."""
logger.info("=" * 60)
logger.info("🚀 INICIANDO CARGA DEL MODELO GEMMA")
logger.info("=" * 60)
hf_token = os.getenv("HF_TOKEN")
if hf_token:
logger.info("✅ HF_TOKEN encontrado en variables de entorno")
else:
logger.warning("⚠️ HF_TOKEN no encontrado en variables de entorno")
logger.info(" (El modelo debe ser público o tener HF_TOKEN configurado)")
loaded = False
last_error = None
for i, model_variant in enumerate(MODEL_VARIANTS):
if i > 0:
logger.info(f"🔄 Intentando fallback: {model_variant}")
try:
logger.info(f"📥 [1/4] Descargando tokenizer de: {model_variant}")
download_start = time.time()
self.tokenizer = AutoTokenizer.from_pretrained(
model_variant,
cache_dir=self.cache_dir,
token=hf_token,
trust_remote_code=True,
)
logger.info(f" ✅ Tokenizer descargado en {time.time() - download_start:.1f}s")
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
logger.info(" ✅ pad_token = eos_token (configurado)")
else:
logger.info(f" ✅ pad_token ya configurado: {self.tokenizer.pad_token}")
logger.info(f"📥 [2/4] Descargando modelo: {model_variant}")
logger.info(" ℹ️ Tamaño: ~5-8 GB, puede tomar varios minutos...")
model_start = time.time()
# Configuración especial para Gemma 4
if "gemma-4" in model_variant:
logger.info(" ℹ️ Configuración Gemma 4: float16, device_map=cpu")
self.model = AutoModelForCausalLM.from_pretrained(
model_variant,
device_map="cpu",
torch_dtype=torch.float16,
trust_remote_code=True,
low_cpu_mem_usage=True,
cache_dir=self.cache_dir,
token=hf_token,
)
else:
# Configuración original para Gemma 2
logger.info(" ℹ️ Configuración: device_map=cpu, torch_dtype=float32")
self.model = AutoModelForCausalLM.from_pretrained(
model_variant,
device_map="cpu",
torch_dtype=torch.float32,
cache_dir=self.cache_dir,
token=hf_token,
trust_remote_code=True,
)
model_time = time.time() - model_start
logger.info(f" ✅ Modelo descargado en {model_time:.1f}s")
logger.info(" ℹ️ Ejecutando model.eval()...")
self.model.eval()
self.model_name = model_variant
loaded = True
logger.info("=" * 60)
logger.info(f"✅ MODELO CARGADO EXITOSAMENTE: {self.model_name}")
logger.info(f" 📍 Device: CPU (float32)")
logger.info(f" 💾 Memoria aproximada: ~5-6 GB RAM")
logger.info(f" ⏱️ Tiempo total: {model_time:.1f}s")
logger.info("=" * 60)
break
except KeyError as e:
logger.error(f"❌ KeyError con {model_variant}: {e}")
logger.info(" → Intentando siguiente variante...")
last_error = e
continue
except TypeError as e:
if "timeout" in str(e):
logger.error(f"❌ Timeout con {model_variant}: {e}")
else:
logger.error(f"❌ TypeError con {model_variant}: {e}")
logger.info(" → Intentando siguiente variante...")
last_error = e
continue
except Exception as e:
logger.error(f"❌ Error cargando {model_variant}: {str(e)}")
logger.info(" → Intentando siguiente variante...")
last_error = e
continue
if not loaded:
error_msg = (
f"❌ Falló la carga de todas las variantes de Gemma. "
f"Último error: {last_error}. "
f"Modelos intentados: {MODEL_VARIANTS}. "
f"Asegúrate de tener transformers>={MIN_TRANSFORMERS_VERSION}"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
def generate(
self,
prompt: str,
max_new_tokens: int = 50,
min_new_tokens: int = 10,
temperature: float = 0.3,
top_p: float = 0.85,
repetition_penalty: float = 1.1,
early_stopping: bool = True,
no_repeat_ngram_size: int = 3,
on_tokens_generated: Optional[Callable[[int, float], None]] = None,
) -> str:
"""Generate a response from a prompt.
Args:
prompt: The input prompt string in Gemma chat format.
max_new_tokens: Maximum number of tokens to generate.
min_new_tokens: Minimum number of tokens to generate.
temperature: Sampling temperature (higher = more random).
top_p: Nucleus sampling threshold.
repetition_penalty: Penalty for repeating tokens (1.0 = no penalty).
early_stopping: Whether to stop when reaching end of sentence.
no_repeat_ngram_size: Prevents repeating n-grams of this size.
on_tokens_generated: Callback to report tokens and elapsed time.
Returns:
Generated response string (without the prompt).
"""
start_time = time.time()
try:
logger.info(f"Generating response for prompt (length: {len(prompt)})")
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=1024,
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model.generate(
**inputs,
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=40,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
early_stopping=early_stopping,
)
generated_text = self.tokenizer.decode(
outputs[0],
skip_special_tokens=True,
)
response = generated_text[len(prompt):].strip()
response = self._clean_response(response)
response = self.fix_common_errors(response)
if len(response) < 10:
logger.warning(f"Response very short ({len(response)} chars)")
return response.strip() if response.strip() else "No se pudo generar una respuesta."
elapsed = time.time() - start_time
tokens_generated = len(outputs[0]) - len(inputs["input_ids"][0])
logger.info(
f"Generated {tokens_generated} tokens in {elapsed:.2f}s "
f"({tokens_generated/elapsed:.1f} tokens/s)"
)
if on_tokens_generated:
on_tokens_generated(tokens_generated, elapsed)
self._clear_cache()
return response
except Exception as e:
logger.error(f"Generation failed: {str(e)}")
self._clear_cache()
return "Lo siento, hubo un problema al generar la respuesta. Por favor, intenta de nuevo."
def generate_with_context(
self,
context: str,
question: str,
max_new_tokens: int = 60,
on_tokens_generated: Optional[Callable[[int, float], None]] = None,
) -> str:
"""Generate a response given context and a question (RAG mode).
Args:
context: Retrieved context from the RAG system.
question: User question.
max_new_tokens: Maximum tokens to generate.
on_tokens_generated: Callback(token_count, elapsed_seconds).
Returns:
Generated response based on the context.
"""
prompt = self._build_simple_prompt(context, question)
logger.info(f"RAG generation - Context length: {len(context)}, Question: {question[:50]}...")
return self.generate(
prompt=prompt,
max_new_tokens=100, # Respuestas más largas
min_new_tokens=15,
temperature=0.2,
top_p=0.85,
repetition_penalty=1.1,
no_repeat_ngram_size=3,
on_tokens_generated=on_tokens_generated,
)
def _build_simple_prompt(self, context: str, question: str) -> str:
"""Build a prompt for Gemma optimized for RAG with short responses."""
PROMPT_TEMPLATE = """Responde la pregunta usando SOLO el texto que está entre === CONTEXTO === y === FIN CONTEXTO ===.
=== CONTEXTO ===
{context}
=== FIN CONTEXTO ===
PREGUNTA: {question}
REGLAS:
1. Responde con UNA sola frase corta
2. Empieza con "Sí" o "No" si es pregunta de sí/no
3. Si la respuesta no está en el contexto, di: "No encontré esa información"
RESPUESTA:"""
prompt = PROMPT_TEMPLATE.format(context=context, question=question)
return f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
def _clean_response(self, text: str) -> str:
if not text:
return text
import re
text = re.sub(r'^[^a-zA-ZáéíóúÁÉÍÓÚ¿¡]+', '', text)
words = text.split()
if len(words) > 1 and len(words[0]) <= 2:
text = ' '.join(words[1:])
text = re.sub(r'\s+', ' ', text).strip()
if text and text[0].islower():
text = text[0].upper() + text[1:]
return text
def fix_common_errors(self, text: str) -> str:
replacements = {
"constatancia": "constancia",
"constatancoa": "constancia",
"secondary": "secundaria",
"otografía": "fotografía",
"credenciación": "credencial",
"cartascompromiso": "carta compromiso",
"carta compromiso ": "carta compromiso ",
}
for wrong, correct in replacements.items():
text = text.replace(wrong, correct)
return text
def _clear_cache(self) -> None:
"""Clear Python and PyTorch garbage and cache."""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.debug("Cleared memory cache")
def get_model_info(self) -> dict:
"""Get information about the loaded model.
Returns:
Dictionary with model metadata.
"""
return {
"model_name": self.model_name,
"device": self.device,
"dtype": "float32",
"parameters": "2B",
"quantization": "none",
}