File size: 14,378 Bytes
d6d9ec6
4fa1e6b
d6d9ec6
 
4fa1e6b
 
 
8d7fb25
26ec9f1
4fa1e6b
c68dc2b
4fa1e6b
 
26ec9f1
4fa1e6b
1ea2a88
4fa1e6b
1ea2a88
4fa1e6b
 
26ec9f1
bd96ff5
 
 
26ec9f1
 
bd96ff5
26ec9f1
 
 
bd96ff5
26ec9f1
 
 
 
 
 
 
bd96ff5
2041232
bd96ff5
26ec9f1
 
bd96ff5
26ec9f1
 
4fa1e6b
bd96ff5
4fa1e6b
bd96ff5
 
4fa1e6b
 
 
 
 
 
 
 
 
 
 
 
 
26ec9f1
 
4fa1e6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26ec9f1
2041232
 
 
 
26ec9f1
 
2041232
26ec9f1
2041232
 
26ec9f1
 
 
 
e80ce4c
 
2041232
26ec9f1
 
2041232
e80ce4c
26ec9f1
 
 
 
 
 
2041232
26ec9f1
 
 
2041232
 
 
26ec9f1
2041232
d5d410a
26ec9f1
 
d5d410a
 
3165153
d5d410a
 
3165153
d5d410a
3165153
d5d410a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26ec9f1
 
2041232
 
 
 
26ec9f1
 
 
2041232
 
 
 
 
 
 
26ec9f1
 
 
2041232
 
26ec9f1
 
e80ce4c
 
2041232
 
 
 
e80ce4c
 
26ec9f1
2041232
 
26ec9f1
 
 
 
 
2041232
 
 
 
4fa1e6b
26ec9f1
 
4fa1e6b
 
 
 
daf93a4
 
 
 
4fa1e6b
 
 
c68dc2b
4fa1e6b
 
 
 
 
 
 
 
 
 
 
 
c68dc2b
4fa1e6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77fcd76
4fa1e6b
 
77fcd76
 
 
4fa1e6b
 
 
 
 
 
 
 
 
 
 
 
 
 
799e578
39758ea
799e578
4fa1e6b
 
 
 
 
 
 
 
 
 
c68dc2b
 
 
 
4fa1e6b
 
 
 
 
 
 
 
 
c380cb1
4fa1e6b
 
c68dc2b
daf93a4
c68dc2b
4fa1e6b
 
 
 
 
 
 
c68dc2b
4fa1e6b
 
 
 
1bcb90a
4fa1e6b
 
1bcb90a
4fa1e6b
bd96ff5
 
83a6605
43365fd
77fcd76
43365fd
c68dc2b
4fa1e6b
 
1bcb90a
daf93a4
77fcd76
3777c6e
37b34f7
3777c6e
8c5e054
3777c6e
507c852
83a6605
39758ea
3777c6e
 
 
 
 
83a6605
daf93a4
3777c6e
 
daf93a4
4fa1e6b
43365fd
 
 
799e578
 
39758ea
 
 
3f9a6aa
 
 
 
43365fd
 
 
 
 
 
799e578
39758ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fa1e6b
 
 
 
 
 
 
43365fd
 
fbe249f
 
43365fd
fbe249f
43365fd
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
"""Gemma Wrapper for RAG-based chatbot.

This module provides a wrapper around Google Gemma chat models (tried through a
fallback chain of model variants) for generating responses in a RAG
architecture, optimized for CPU-only inference.
"""

import gc
import os
import sys
import time
from typing import Optional, Callable

import torch
import transformers
from loguru import logger
from requests.adapters import HTTPAdapter
from transformers import AutoModelForCausalLM, AutoTokenizer
from urllib3.util.retry import Retry


# Hugging Face model ids that the loader tries, in this order, until one loads.
MODEL_VARIANTS = [
    "google/gemma-4-e4b-it",      # 🚀 Primary: fast and capable
    "google/gemma-4-e2b-it",      # ⚡ Backup: lighter
    "google/gemma-2-2b-it",      # 🔄 Final fallback
]

# Minimum transformers version that _check_transformers_version() expects for
# the gemma-4 variants (an older version only triggers a logged warning).
MIN_TRANSFORMERS_VERSION = "4.51.0"


def _check_transformers_version():
    """Verify transformers version meets minimum requirement for Gemma 4."""
    installed = transformers.__version__
    major, minor, _ = installed.split(".")[:3]
    required_major, required_minor = MIN_TRANSFORMERS_VERSION.split(".")[:2]
    
    if (int(major) < int(required_major) or 
        (int(major) == int(required_major) and int(minor) < int(required_minor))):
        logger.warning(
            f"⚠️ transformers {installed} incompatible con gemma-4. "
            f"Requiere >= {MIN_TRANSFORMERS_VERSION}. "
            f"Se usará fallback gemma-2-2b-it"
        )
    else:
        logger.info(f"✅ transformers version: {installed} (compatible con gemma-4)")


class GemmaWrapper:
    """Wrapper around Gemma chat models with RAG integration support.

    Loads a tokenizer/model pair for CPU-only inference. The candidate models
    are the caller-supplied ``model_name`` (if any) followed by the entries of
    ``MODEL_VARIANTS`` in order; the first one that loads is kept and its id is
    stored in ``self.model_name``.
    """

    # Loguru sink id shared by all instances, so constructing several wrappers
    # does not attach the same log-file sink repeatedly (which would duplicate
    # every log line).
    _log_sink_id: Optional[int] = None

    # Misspellings observed in model output -> corrections. Keys are only
    # matched at the start of a word (see fix_common_errors).
    _COMMON_ERRORS = {
        "constatancia": "constancia",
        "constatancoa": "constancia",
        "secondary": "secundaria",
        "otografía": "fotografía",
        "credenciación": "credencial",
        "cartascompromiso": "carta compromiso",
    }

    def __init__(
        self,
        model_name: Optional[str] = None,
        cache_dir: str = "models/cache",
    ) -> None:
        """Initialize the Gemma wrapper and load a model.

        Args:
            model_name: Hugging Face model id to try before MODEL_VARIANTS.
                None (the default) tries only MODEL_VARIANTS, in order. After
                loading, ``self.model_name`` holds the id that actually loaded.
            cache_dir: Directory to cache the model files.

        Raises:
            RuntimeError: If none of the candidate models could be loaded.
        """
        _check_transformers_version()

        self.model_name = model_name
        self.cache_dir = cache_dir
        self.device = "cpu"
        self.model = None
        self.tokenizer = None

        self._setup_logger()
        self._load_model()

    def _setup_logger(self) -> None:
        """Attach the rotating file sink for this wrapper (once per process)."""
        if GemmaWrapper._log_sink_id is None:
            GemmaWrapper._log_sink_id = logger.add(
                "logs/gemma_wrapper.log",
                rotation="10 MB",
                retention="7 days",
                level="INFO",
            )

    def _candidate_models(self) -> list:
        """Return the ordered, de-duplicated list of model ids to try."""
        candidates = [self.model_name] if self.model_name else []
        candidates += [variant for variant in MODEL_VARIANTS if variant not in candidates]
        return candidates

    @staticmethod
    def _describe_load_error(model_variant: str, error: Exception) -> str:
        """Build the log message for a failed load of ``model_variant``."""
        if isinstance(error, KeyError):
            return f"❌ KeyError con {model_variant}: {error}"
        if isinstance(error, TypeError):
            if "timeout" in str(error):
                return f"❌ Timeout con {model_variant}: {error}"
            return f"❌ TypeError con {model_variant}: {error}"
        return f"❌ Error cargando {model_variant}: {str(error)}"

    def _release_model(self) -> None:
        """Drop any (partially) loaded tokenizer/model and collect garbage."""
        self.model = None
        self.tokenizer = None
        gc.collect()

    def _load_variant(self, model_variant: str, hf_token: Optional[str]) -> None:
        """Load tokenizer and model for one variant into ``self``; raise on failure."""
        logger.info(f"📥 [1/4] Descargando tokenizer de: {model_variant}")
        download_start = time.time()
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository; keep it only if a variant really needs it.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_variant,
            cache_dir=self.cache_dir,
            token=hf_token,
            trust_remote_code=True,
        )
        logger.info(f"   ✅ Tokenizer descargado en {time.time() - download_start:.1f}s")

        # generate() passes pad_token_id, so make sure a pad token exists.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info("   ✅ pad_token = eos_token (configurado)")
        else:
            logger.info(f"   ✅ pad_token ya configurado: {self.tokenizer.pad_token}")

        logger.info(f"📥 [2/4] Descargando modelo: {model_variant}")
        logger.info("   ℹ️ Tamaño: ~5-8 GB, puede tomar varios minutos...")
        model_start = time.time()

        if "gemma-4" in model_variant:
            # Gemma 4 variants: half-precision weights and low_cpu_mem_usage.
            logger.info("   ℹ️ Configuración Gemma 4: float16, device_map=cpu")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_variant,
                device_map="cpu",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                cache_dir=self.cache_dir,
                token=hf_token,
            )
        else:
            # Original configuration used for Gemma 2.
            logger.info("   ℹ️ Configuración: device_map=cpu, torch_dtype=float32")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_variant,
                device_map="cpu",
                torch_dtype=torch.float32,
                cache_dir=self.cache_dir,
                token=hf_token,
                trust_remote_code=True,
            )

        logger.info(f"   ✅ Modelo descargado en {time.time() - model_start:.1f}s")
        logger.info("   ℹ️ Ejecutando model.eval()...")
        self.model.eval()

    def _load_model(self) -> None:
        """Load the first candidate model that succeeds, with fallback.

        Raises:
            RuntimeError: If every candidate fails (chained to the last error).
        """
        logger.info("=" * 60)
        logger.info("🚀 INICIANDO CARGA DEL MODELO GEMMA")
        logger.info("=" * 60)

        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            logger.info("✅ HF_TOKEN encontrado en variables de entorno")
        else:
            logger.warning("⚠️ HF_TOKEN no encontrado en variables de entorno")
            logger.info("   (El modelo debe ser público o tener HF_TOKEN configurado)")

        candidates = self._candidate_models()
        last_error = None

        for attempt, model_variant in enumerate(candidates):
            if attempt > 0:
                logger.info(f"🔄 Intentando fallback: {model_variant}")

            load_start = time.time()
            try:
                self._load_variant(model_variant, hf_token)
            except Exception as e:
                logger.error(self._describe_load_error(model_variant, e))
                logger.info("   → Intentando siguiente variante...")
                last_error = e
                # Release the half-loaded objects so the next attempt does not
                # keep two models in RAM at once.
                self._release_model()
                continue

            total_time = time.time() - load_start
            self.model_name = model_variant
            # Report the dtype actually loaded (float16 for gemma-4, float32 otherwise).
            dtype_name = str(getattr(self.model, "dtype", "unknown")).replace("torch.", "")

            logger.info("=" * 60)
            logger.info(f"✅ MODELO CARGADO EXITOSAMENTE: {self.model_name}")
            logger.info(f"   📍 Device: CPU ({dtype_name})")
            logger.info(f"   💾 Memoria aproximada: ~5-6 GB RAM")
            logger.info(f"   ⏱️ Tiempo total: {total_time:.1f}s")
            logger.info("=" * 60)
            return

        error_msg = (
            f"❌ Falló la carga de todas las variantes de Gemma. "
            f"Último error: {last_error}. "
            f"Modelos intentados: {candidates}. "
            f"Asegúrate de tener transformers>={MIN_TRANSFORMERS_VERSION}"
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg) from last_error

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 50,
        min_new_tokens: int = 10,
        temperature: float = 0.3,
        top_p: float = 0.85,
        repetition_penalty: float = 1.1,
        early_stopping: bool = True,
        no_repeat_ngram_size: int = 3,
        on_tokens_generated: Optional[Callable[[int, float], None]] = None,
    ) -> str:
        """Generate a response from a prompt using sampling.

        Args:
            prompt: The input prompt string in Gemma chat format.
            max_new_tokens: Maximum number of tokens to generate.
            min_new_tokens: Minimum number of tokens to generate.
            temperature: Sampling temperature (higher = more random).
            top_p: Nucleus sampling threshold.
            repetition_penalty: Penalty for repeating tokens (1.0 = no penalty).
            early_stopping: Forwarded to ``model.generate`` (in transformers it
                only affects beam-based generation).
            no_repeat_ngram_size: Prevents repeating n-grams of this size.
            on_tokens_generated: Optional callback invoked as
                ``callback(tokens_generated, elapsed_seconds)`` after generation.

        Returns:
            The cleaned generated text (prompt excluded). Returns
            "No se pudo generar una respuesta." when nothing usable was
            produced, and a generic apology message if generation raises.
        """
        start_time = time.time()

        try:
            logger.info(f"Generating response for prompt (length: {len(prompt)})")

            # NOTE(review): with right-side truncation (the usual tokenizer
            # default) a prompt longer than 1024 tokens loses its END, i.e. the
            # question and the model-turn marker. Consider trimming the context.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=1024,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            prompt_token_count = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=40,
                    max_new_tokens=max_new_tokens,
                    min_new_tokens=min_new_tokens,
                    repetition_penalty=repetition_penalty,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    early_stopping=early_stopping,
                )

            # Decode only the newly generated tokens. Slicing the decoded text
            # by len(prompt) is wrong: skip_special_tokens removes the chat
            # template markers, so the decoded prompt is shorter than `prompt`
            # and the slice would cut into the answer.
            new_tokens = outputs[0][prompt_token_count:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            response = self._clean_response(response.strip())
            response = self.fix_common_errors(response).strip()

            elapsed = time.time() - start_time
            tokens_generated = len(new_tokens)
            rate = tokens_generated / elapsed if elapsed > 0 else 0.0
            logger.info(
                f"Generated {tokens_generated} tokens in {elapsed:.2f}s "
                f"({rate:.1f} tokens/s)"
            )

            if on_tokens_generated:
                on_tokens_generated(tokens_generated, elapsed)

            if len(response) < 10:
                logger.warning(f"Response very short ({len(response)} chars)")
                if not response:
                    return "No se pudo generar una respuesta."

            return response

        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            return "Lo siento, hubo un problema al generar la respuesta. Por favor, intenta de nuevo."
        finally:
            # Free memory after every call, whichever path returned.
            self._clear_cache()

    def generate_with_context(
        self,
        context: str,
        question: str,
        max_new_tokens: int = 100,
        on_tokens_generated: Optional[Callable[[int, float], None]] = None,
    ) -> str:
        """Generate a response given context and a question (RAG mode).

        Args:
            context: Retrieved context from the RAG system.
            question: User question.
            max_new_tokens: Maximum tokens to generate. The default of 100
                matches the value this method has always used in practice.
            on_tokens_generated: Callback(token_count, elapsed_seconds).

        Returns:
            Generated response based on the context.
        """
        prompt = self._build_simple_prompt(context, question)

        logger.info(f"RAG generation - Context length: {len(context)}, Question: {question[:50]}...")
        return self.generate(
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            # Never ask for more mandatory tokens than the allowed maximum.
            min_new_tokens=min(15, max_new_tokens),
            temperature=0.2,
            top_p=0.85,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            on_tokens_generated=on_tokens_generated,
        )

    def _build_simple_prompt(self, context: str, question: str) -> str:
        """Build a Gemma chat-format prompt for RAG with short answers.

        The instructions ask for a single short sentence grounded only in the
        delimited context; the prompt is wrapped in a user turn followed by an
        open model turn.
        """
        PROMPT_TEMPLATE = """Responde la pregunta usando SOLO el texto que está entre === CONTEXTO === y === FIN CONTEXTO ===.

=== CONTEXTO ===
{context}
=== FIN CONTEXTO ===

PREGUNTA: {question}

REGLAS:
1. Responde con UNA sola frase corta
2. Empieza con "Sí" o "No" si es pregunta de sí/no
3. Si la respuesta no está en el contexto, di: "No encontré esa información"

RESPUESTA:"""

        prompt = PROMPT_TEMPLATE.format(context=context, question=question)

        return f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

    def _clean_response(self, text: str) -> str:
        """Tidy raw model output.

        Strips leading punctuation/whitespace, collapses runs of whitespace to
        single spaces, and upper-cases the first character.
        """
        if not text:
            return text

        import re

        # Keep any leading letter (including ñ/ü), digit, or opening ¿/¡.
        # Short leading words are deliberately preserved: answers such as
        # "No encontré esa información" must keep their "No".
        text = re.sub(r"^[^\w¿¡]+", "", text)
        text = re.sub(r"\s+", " ", text).strip()

        if text and text[0].islower():
            text = text[0].upper() + text[1:]

        return text

    def fix_common_errors(self, text: str) -> str:
        """Correct misspellings the model commonly produces.

        Each key of ``_COMMON_ERRORS`` is replaced only when it is not preceded
        by a word character, so "otografía" inside the correct "fotografía" is
        left alone (instead of becoming "ffotografía"), while suffixed forms
        such as plurals are still corrected.
        """
        import re

        for wrong, correct in self._COMMON_ERRORS.items():
            text = re.sub(r"(?<!\w)" + re.escape(wrong), correct, text)

        return text

    def _clear_cache(self) -> None:
        """Clear Python and PyTorch garbage and cache."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.debug("Cleared memory cache")

    def get_model_info(self) -> dict:
        """Get information about the loaded model.

        Returns:
            Dictionary with model metadata. ``dtype`` reflects the loaded
            model's dtype, or "float32" when no model is loaded.
        """
        dtype = "float32"
        model_dtype = getattr(self.model, "dtype", None) if self.model is not None else None
        if model_dtype is not None:
            dtype = str(model_dtype).replace("torch.", "")

        return {
            "model_name": self.model_name,
            "device": self.device,
            "dtype": dtype,
            # NOTE(review): hard-coded; not accurate for the gemma-4 variants.
            "parameters": "2B",
            "quantization": "none",
        }