Spaces:
Running
Running
NoeMartinezSanchez commited on
Commit ·
77fcd76
1
Parent(s): b423380
Mejora de promt
Browse files- models/gemma_wrapper.py +24 -41
models/gemma_wrapper.py
CHANGED
|
@@ -201,13 +201,14 @@ class GemmaWrapper:
|
|
| 201 |
with torch.no_grad():
|
| 202 |
outputs = self.model.generate(
|
| 203 |
**inputs,
|
| 204 |
-
|
| 205 |
-
min_new_tokens=min_new_tokens,
|
| 206 |
temperature=temperature,
|
| 207 |
top_p=top_p,
|
|
|
|
|
|
|
|
|
|
| 208 |
repetition_penalty=repetition_penalty,
|
| 209 |
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 210 |
-
do_sample=False,
|
| 211 |
pad_token_id=self.tokenizer.pad_token_id,
|
| 212 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 213 |
early_stopping=early_stopping,
|
|
@@ -264,34 +265,35 @@ question: str,
|
|
| 264 |
logger.info(f"RAG generation - Context length: {len(context)}, Question: {question[:50]}...")
|
| 265 |
return self.generate(
|
| 266 |
prompt=prompt,
|
| 267 |
-
max_new_tokens=
|
| 268 |
-
min_new_tokens=
|
| 269 |
-
temperature=0.
|
| 270 |
top_p=0.85,
|
| 271 |
-
repetition_penalty=1.
|
| 272 |
no_repeat_ngram_size=3,
|
| 273 |
)
|
| 274 |
|
| 275 |
def _build_simple_prompt(self, context: str, question: str) -> str:
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
| 285 |
{context}
|
| 286 |
|
| 287 |
-
PREGUNTA: {question}
|
| 288 |
|
| 289 |
-
RESPUESTA (
|
| 290 |
|
| 291 |
-
prompt = f"
|
| 292 |
-
|
| 293 |
-
<start_of_turn>model
|
| 294 |
-
"""
|
| 295 |
return prompt
|
| 296 |
|
| 297 |
def _clean_response(self, text: str) -> str:
|
|
@@ -340,25 +342,6 @@ RESPUESTA (SOLO con información de los fragmentos):"""
|
|
| 340 |
torch.cuda.empty_cache()
|
| 341 |
logger.debug("Cleared memory cache")
|
| 342 |
|
| 343 |
-
def get_model_info(self) -> dict:
|
| 344 |
-
"""Get information about the loaded model.
|
| 345 |
-
|
| 346 |
-
Returns:
|
| 347 |
-
Dictionary with model metadata.
|
| 348 |
-
"""
|
| 349 |
-
return {
|
| 350 |
-
"model_name": self.model_name,
|
| 351 |
-
"device": self.device,
|
| 352 |
-
"dtype": "float32",
|
| 353 |
-
"parameters": "2B",
|
| 354 |
-
"quantization": "none",
|
| 355 |
-
}
|
| 356 |
-
"""Clear Python and PyTorch garbage and cache."""
|
| 357 |
-
gc.collect()
|
| 358 |
-
if torch.cuda.is_available():
|
| 359 |
-
torch.cuda.empty_cache()
|
| 360 |
-
logger.debug("Cleared memory cache")
|
| 361 |
-
|
| 362 |
def get_model_info(self) -> dict:
|
| 363 |
"""Get information about the loaded model.
|
| 364 |
|
|
|
|
| 201 |
with torch.no_grad():
|
| 202 |
outputs = self.model.generate(
|
| 203 |
**inputs,
|
| 204 |
+
do_sample=True,
|
|
|
|
| 205 |
temperature=temperature,
|
| 206 |
top_p=top_p,
|
| 207 |
+
top_k=40,
|
| 208 |
+
max_new_tokens=max_new_tokens,
|
| 209 |
+
min_new_tokens=min_new_tokens,
|
| 210 |
repetition_penalty=repetition_penalty,
|
| 211 |
no_repeat_ngram_size=no_repeat_ngram_size,
|
|
|
|
| 212 |
pad_token_id=self.tokenizer.pad_token_id,
|
| 213 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 214 |
early_stopping=early_stopping,
|
|
|
|
| 265 |
logger.info(f"RAG generation - Context length: {len(context)}, Question: {question[:50]}...")
|
| 266 |
return self.generate(
|
| 267 |
prompt=prompt,
|
| 268 |
+
max_new_tokens=200,
|
| 269 |
+
min_new_tokens=15,
|
| 270 |
+
temperature=0.3,
|
| 271 |
top_p=0.85,
|
| 272 |
+
repetition_penalty=1.1,
|
| 273 |
no_repeat_ngram_size=3,
|
| 274 |
)
|
| 275 |
|
| 276 |
def _build_simple_prompt(self, context: str, question: str) -> str:
|
| 277 |
+
"""Build a prompt for Gemma following its exact expected format."""
|
| 278 |
+
|
| 279 |
+
system_message = """Eres un asistente de Prepa en Línea SEP.
|
| 280 |
+
|
| 281 |
+
REGLAS ESTRICTAS:
|
| 282 |
+
1. Responde SOLO usando la información de los FRAGMENTOS que se te proporcionan
|
| 283 |
+
2. Si los fragmentos NO contienen la respuesta, responde: "No encontré esa información en los documentos disponibles"
|
| 284 |
+
3. NO inventes información
|
| 285 |
+
4. NO uses conocimiento externo
|
| 286 |
+
5. Responde en español, de forma clara y directa"""
|
| 287 |
+
|
| 288 |
+
user_message = f"""FRAGMENTOS DE LA CONVOCATORIA:
|
| 289 |
{context}
|
| 290 |
|
| 291 |
+
PREGUNTA DEL USUARIO: {question}
|
| 292 |
|
| 293 |
+
RESPUESTA (basada ESTRICTAMENTE en los fragmentos):"""
|
| 294 |
|
| 295 |
+
prompt = f"<start_of_turn>user\n{system_message}\n\n{user_message}<end_of_turn>\n<start_of_turn>model\n"
|
| 296 |
+
|
|
|
|
|
|
|
| 297 |
return prompt
|
| 298 |
|
| 299 |
def _clean_response(self, text: str) -> str:
|
|
|
|
| 342 |
torch.cuda.empty_cache()
|
| 343 |
logger.debug("Cleared memory cache")
|
| 344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
def get_model_info(self) -> dict:
|
| 346 |
"""Get information about the loaded model.
|
| 347 |
|