NoeMartinezSanchez commited on
Commit
77fcd76
·
1 Parent(s): b423380

Mejora de promt

Browse files
Files changed (1) hide show
  1. models/gemma_wrapper.py +24 -41
models/gemma_wrapper.py CHANGED
@@ -201,13 +201,14 @@ class GemmaWrapper:
201
  with torch.no_grad():
202
  outputs = self.model.generate(
203
  **inputs,
204
- max_new_tokens=max_new_tokens,
205
- min_new_tokens=min_new_tokens,
206
  temperature=temperature,
207
  top_p=top_p,
 
 
 
208
  repetition_penalty=repetition_penalty,
209
  no_repeat_ngram_size=no_repeat_ngram_size,
210
- do_sample=False,
211
  pad_token_id=self.tokenizer.pad_token_id,
212
  eos_token_id=self.tokenizer.eos_token_id,
213
  early_stopping=early_stopping,
@@ -264,34 +265,35 @@ question: str,
264
  logger.info(f"RAG generation - Context length: {len(context)}, Question: {question[:50]}...")
265
  return self.generate(
266
  prompt=prompt,
267
- max_new_tokens=400,
268
- min_new_tokens=20,
269
- temperature=0.2,
270
  top_p=0.85,
271
- repetition_penalty=1.15,
272
  no_repeat_ngram_size=3,
273
  )
274
 
275
  def _build_simple_prompt(self, context: str, question: str) -> str:
276
- prompt = f"""Eres un asistente de Prepa en Línea SEP. Responde usando EXACTAMENTE la información de los siguientes fragmentos.
277
-
278
- REGLAS IMPORTANTES:
279
- 1. SOLO usa información que aparezca TEXTUALMENTE en los fragmentos
280
- 2. Si la información no está en los fragmentos, di "No encontré esa información en los documentos"
281
- 3. NO inventes, NO resumas, NO agregues información
282
- 4. Puedes copiar textualmente las listas de documentos
283
-
284
- FRAGMENTOS:
 
 
 
285
  {context}
286
 
287
- PREGUNTA: {question}
288
 
289
- RESPUESTA (SOLO con información de los fragmentos):"""
290
 
291
- prompt = f"""<start_of_turn>user
292
- {prompt}<end_of_turn>
293
- <start_of_turn>model
294
- """
295
  return prompt
296
 
297
  def _clean_response(self, text: str) -> str:
@@ -340,25 +342,6 @@ RESPUESTA (SOLO con información de los fragmentos):"""
340
  torch.cuda.empty_cache()
341
  logger.debug("Cleared memory cache")
342
 
343
- def get_model_info(self) -> dict:
344
- """Get information about the loaded model.
345
-
346
- Returns:
347
- Dictionary with model metadata.
348
- """
349
- return {
350
- "model_name": self.model_name,
351
- "device": self.device,
352
- "dtype": "float32",
353
- "parameters": "2B",
354
- "quantization": "none",
355
- }
356
- """Clear Python and PyTorch garbage and cache."""
357
- gc.collect()
358
- if torch.cuda.is_available():
359
- torch.cuda.empty_cache()
360
- logger.debug("Cleared memory cache")
361
-
362
  def get_model_info(self) -> dict:
363
  """Get information about the loaded model.
364
 
 
201
  with torch.no_grad():
202
  outputs = self.model.generate(
203
  **inputs,
204
+ do_sample=True,
 
205
  temperature=temperature,
206
  top_p=top_p,
207
+ top_k=40,
208
+ max_new_tokens=max_new_tokens,
209
+ min_new_tokens=min_new_tokens,
210
  repetition_penalty=repetition_penalty,
211
  no_repeat_ngram_size=no_repeat_ngram_size,
 
212
  pad_token_id=self.tokenizer.pad_token_id,
213
  eos_token_id=self.tokenizer.eos_token_id,
214
  early_stopping=early_stopping,
 
265
  logger.info(f"RAG generation - Context length: {len(context)}, Question: {question[:50]}...")
266
  return self.generate(
267
  prompt=prompt,
268
+ max_new_tokens=200,
269
+ min_new_tokens=15,
270
+ temperature=0.3,
271
  top_p=0.85,
272
+ repetition_penalty=1.1,
273
  no_repeat_ngram_size=3,
274
  )
275
 
276
  def _build_simple_prompt(self, context: str, question: str) -> str:
277
+ """Build a prompt for Gemma following its exact expected format."""
278
+
279
+ system_message = """Eres un asistente de Prepa en Línea SEP.
280
+
281
+ REGLAS ESTRICTAS:
282
+ 1. Responde SOLO usando la información de los FRAGMENTOS que se te proporcionan
283
+ 2. Si los fragmentos NO contienen la respuesta, responde: "No encontré esa información en los documentos disponibles"
284
+ 3. NO inventes información
285
+ 4. NO uses conocimiento externo
286
+ 5. Responde en español, de forma clara y directa"""
287
+
288
+ user_message = f"""FRAGMENTOS DE LA CONVOCATORIA:
289
  {context}
290
 
291
+ PREGUNTA DEL USUARIO: {question}
292
 
293
+ RESPUESTA (basada ESTRICTAMENTE en los fragmentos):"""
294
 
295
+ prompt = f"<start_of_turn>user\n{system_message}\n\n{user_message}<end_of_turn>\n<start_of_turn>model\n"
296
+
 
 
297
  return prompt
298
 
299
  def _clean_response(self, text: str) -> str:
 
342
  torch.cuda.empty_cache()
343
  logger.debug("Cleared memory cache")
344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  def get_model_info(self) -> dict:
346
  """Get information about the loaded model.
347