Spaces:
Running
Running
NoeMartinezSanchez commited on
Commit ·
079adc2
1
Parent(s): c71b1c7
min_new_tokens=0 como parámetro, temperatura=0.7
Browse files- models/tinyllama_wrapper.py +17 -5
models/tinyllama_wrapper.py
CHANGED
|
@@ -109,8 +109,9 @@ class TinyLlamaWrapper:
|
|
| 109 |
self,
|
| 110 |
prompt: str,
|
| 111 |
max_new_tokens: int = 180,
|
| 112 |
-
|
| 113 |
-
|
|
|
|
| 114 |
repetition_penalty: float = 1.1,
|
| 115 |
early_stopping: bool = False,
|
| 116 |
no_repeat_ngram_size: int = 3,
|
|
@@ -120,6 +121,7 @@ class TinyLlamaWrapper:
|
|
| 120 |
Args:
|
| 121 |
prompt: The input prompt string.
|
| 122 |
max_new_tokens: Maximum number of tokens to generate.
|
|
|
|
| 123 |
temperature: Sampling temperature (higher = more random).
|
| 124 |
top_p: Nucleus sampling threshold.
|
| 125 |
repetition_penalty: Penalty for repeating tokens (1.0 = no penalty).
|
|
@@ -148,6 +150,7 @@ class TinyLlamaWrapper:
|
|
| 148 |
outputs = self.model.generate(
|
| 149 |
**inputs,
|
| 150 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 151 |
temperature=temperature,
|
| 152 |
top_p=top_p,
|
| 153 |
repetition_penalty=repetition_penalty,
|
|
@@ -165,6 +168,10 @@ class TinyLlamaWrapper:
|
|
| 165 |
|
| 166 |
response = generated_text[len(prompt):].strip()
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
elapsed = time.time() - start_time
|
| 169 |
tokens_generated = len(outputs[0]) - len(inputs["input_ids"][0])
|
| 170 |
logger.info(
|
|
@@ -198,6 +205,10 @@ class TinyLlamaWrapper:
|
|
| 198 |
|
| 199 |
prompt = f"""<|system|>
|
| 200 |
Eres un asistente virtual experto de Prepa en Línea SEP. Tu función es responder preguntas basándote ESTRICTAMENTE en el contexto oficial proporcionado. Siempre respondes en español neutro, de forma clara, útil y profesional. Si la respuesta no está en el contexto, dices exactamente: "Lo siento, no encontré información específica sobre eso en los materiales disponibles. Por favor, consulta los canales oficiales."
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
<|end|>
|
| 202 |
<|user|>
|
| 203 |
Contexto oficial:
|
|
@@ -207,17 +218,18 @@ Pregunta del estudiante:
|
|
| 207 |
{question}
|
| 208 |
<|end|>
|
| 209 |
<|assistant|>
|
| 210 |
-
De acuerdo a la información oficial
|
| 211 |
|
| 212 |
logger.info(f"RAG generation - Context length: {len(context)}, Question: {question[:50]}...")
|
| 213 |
return self.generate(
|
| 214 |
prompt,
|
| 215 |
max_new_tokens=max_new_tokens,
|
| 216 |
-
temperature=0.
|
| 217 |
-
top_p=0.
|
| 218 |
repetition_penalty=1.1,
|
| 219 |
no_repeat_ngram_size=3,
|
| 220 |
early_stopping=False,
|
|
|
|
| 221 |
)
|
| 222 |
|
| 223 |
def _log_error(self, error_msg: str) -> None:
|
|
|
|
| 109 |
self,
|
| 110 |
prompt: str,
|
| 111 |
max_new_tokens: int = 180,
|
| 112 |
+
min_new_tokens: int = 0,
|
| 113 |
+
temperature: float = 0.7,
|
| 114 |
+
top_p: float = 0.9,
|
| 115 |
repetition_penalty: float = 1.1,
|
| 116 |
early_stopping: bool = False,
|
| 117 |
no_repeat_ngram_size: int = 3,
|
|
|
|
| 121 |
Args:
|
| 122 |
prompt: The input prompt string.
|
| 123 |
max_new_tokens: Maximum number of tokens to generate.
|
| 124 |
+
min_new_tokens: Minimum number of tokens to generate (forces at least this many).
|
| 125 |
temperature: Sampling temperature (higher = more random).
|
| 126 |
top_p: Nucleus sampling threshold.
|
| 127 |
repetition_penalty: Penalty for repeating tokens (1.0 = no penalty).
|
|
|
|
| 150 |
outputs = self.model.generate(
|
| 151 |
**inputs,
|
| 152 |
max_new_tokens=max_new_tokens,
|
| 153 |
+
min_new_tokens=min_new_tokens,
|
| 154 |
temperature=temperature,
|
| 155 |
top_p=top_p,
|
| 156 |
repetition_penalty=repetition_penalty,
|
|
|
|
| 168 |
|
| 169 |
response = generated_text[len(prompt):].strip()
|
| 170 |
|
| 171 |
+
if len(response) < 10:
|
| 172 |
+
logger.warning(f"Response too short ({len(response)} chars), using fallback")
|
| 173 |
+
response = "Lo siento, no pude generar una respuesta específica. Por favor, reformula tu pregunta o consulta los materiales oficiales."
|
| 174 |
+
|
| 175 |
elapsed = time.time() - start_time
|
| 176 |
tokens_generated = len(outputs[0]) - len(inputs["input_ids"][0])
|
| 177 |
logger.info(
|
|
|
|
| 205 |
|
| 206 |
prompt = f"""<|system|>
|
| 207 |
Eres un asistente virtual experto de Prepa en Línea SEP. Tu función es responder preguntas basándote ESTRICTAMENTE en el contexto oficial proporcionado. Siempre respondes en español neutro, de forma clara, útil y profesional. Si la respuesta no está en el contexto, dices exactamente: "Lo siento, no encontré información específica sobre eso en los materiales disponibles. Por favor, consulta los canales oficiales."
|
| 208 |
+
|
| 209 |
+
Ejemplo de respuesta correcta:
|
| 210 |
+
Pregunta: ¿Qué pasa si no tengo mi certificado?
|
| 211 |
+
Respuesta: De acuerdo a la información oficial, tienes 6 meses para entregarlo. Durante la inscripción, deberás subir una carta compromiso y una constancia de estudios.
|
| 212 |
<|end|>
|
| 213 |
<|user|>
|
| 214 |
Contexto oficial:
|
|
|
|
| 218 |
{question}
|
| 219 |
<|end|>
|
| 220 |
<|assistant|>
|
| 221 |
+
De acuerdo a la información oficial,"""
|
| 222 |
|
| 223 |
logger.info(f"RAG generation - Context length: {len(context)}, Question: {question[:50]}...")
|
| 224 |
return self.generate(
|
| 225 |
prompt,
|
| 226 |
max_new_tokens=max_new_tokens,
|
| 227 |
+
temperature=0.7,
|
| 228 |
+
top_p=0.9,
|
| 229 |
repetition_penalty=1.1,
|
| 230 |
no_repeat_ngram_size=3,
|
| 231 |
early_stopping=False,
|
| 232 |
+
min_new_tokens=20,
|
| 233 |
)
|
| 234 |
|
| 235 |
def _log_error(self, error_msg: str) -> None:
|