Spaces:
Paused
Paused
Update rag_system.py
Browse files- rag_system.py +22 -29
rag_system.py
CHANGED
|
@@ -138,42 +138,28 @@ class RAGLLMSystem:
|
|
| 138 |
self,
|
| 139 |
query: str,
|
| 140 |
context_docs: List[Dict],
|
| 141 |
-
max_new_tokens: int =
|
| 142 |
temperature: float = 0.7,
|
| 143 |
top_p: float = 0.9
|
| 144 |
) -> str:
|
| 145 |
"""Generar respuesta con Salamandra."""
|
| 146 |
|
| 147 |
-
# Construir contexto
|
| 148 |
context_text = "\n\n---\n\n".join([
|
| 149 |
-
f"[
|
| 150 |
-
for doc in context_docs
|
| 151 |
])
|
| 152 |
|
| 153 |
-
# Prompt
|
| 154 |
-
prompt = f"""Eres ALIA,
|
| 155 |
-
|
| 156 |
-
Tu funcion es ayudar a funcionarios publicos, tecnicos de turismo y responsables de destinos turisticos a:
|
| 157 |
-
- Comprender y aplicar estrategias de planes turisticos
|
| 158 |
-
- Obtener informacion sobre mejores practicas en turismo sostenible
|
| 159 |
-
- Consultar casos de exito de otros municipios
|
| 160 |
-
- Disenar e implementar planes estrategicos turisticos
|
| 161 |
-
|
| 162 |
-
INSTRUCCIONES:
|
| 163 |
-
1. Responde SIEMPRE basandote en los documentos proporcionados
|
| 164 |
-
2. Si la informacion no esta en los documentos, indica claramente que no la tienes
|
| 165 |
-
3. Cita los documentos fuente cuando sea relevante
|
| 166 |
-
4. Usa un tono profesional pero accesible
|
| 167 |
-
5. Estructura tus respuestas de forma clara con bullets o numeracion cuando sea apropiado
|
| 168 |
|
| 169 |
-
|
| 170 |
|
| 171 |
{context_text}
|
| 172 |
|
| 173 |
-
PREGUNTA
|
| 174 |
-
{query}
|
| 175 |
|
| 176 |
-
RESPUESTA:"""
|
| 177 |
|
| 178 |
# Tokenizar
|
| 179 |
inputs = self.tokenizer(
|
|
@@ -187,29 +173,36 @@ RESPUESTA:"""
|
|
| 187 |
if self.device == 'cuda':
|
| 188 |
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 189 |
|
| 190 |
-
# Generar
|
| 191 |
try:
|
|
|
|
|
|
|
| 192 |
with torch.no_grad():
|
| 193 |
outputs = self.llm_model.generate(
|
| 194 |
**inputs,
|
| 195 |
-
max_new_tokens=max_new_tokens,
|
| 196 |
temperature=temperature,
|
| 197 |
top_p=top_p,
|
| 198 |
do_sample=True,
|
|
|
|
| 199 |
pad_token_id=self.tokenizer.eos_token_id,
|
| 200 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 201 |
)
|
| 202 |
|
|
|
|
|
|
|
| 203 |
# Decodificar
|
| 204 |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 205 |
|
| 206 |
-
# Extraer respuesta
|
| 207 |
-
if "RESPUESTA
|
| 208 |
-
response = response.split("RESPUESTA
|
|
|
|
| 209 |
|
| 210 |
-
return response
|
| 211 |
|
| 212 |
except Exception as e:
|
|
|
|
| 213 |
return f"Error generando respuesta: {str(e)}"
|
| 214 |
|
| 215 |
def query(
|
|
|
|
| 138 |
self,
|
| 139 |
query: str,
|
| 140 |
context_docs: List[Dict],
|
| 141 |
+
max_new_tokens: int = 512,
|
| 142 |
temperature: float = 0.7,
|
| 143 |
top_p: float = 0.9
|
| 144 |
) -> str:
|
| 145 |
"""Generar respuesta con Salamandra."""
|
| 146 |
|
| 147 |
+
# Construir contexto (limitado para evitar timeouts)
|
| 148 |
context_text = "\n\n---\n\n".join([
|
| 149 |
+
f"[Doc: {doc['filename'][:30]}]\n{doc['content'][:1000]}"
|
| 150 |
+
for doc in context_docs[:3] # Solo top 3 docs
|
| 151 |
])
|
| 152 |
|
| 153 |
+
# Prompt optimizado (más corto)
|
| 154 |
+
prompt = f"""Eres ALIA, asistente de turismo de la Comunidad Valenciana.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
+
Responde basandote en estos documentos:
|
| 157 |
|
| 158 |
{context_text}
|
| 159 |
|
| 160 |
+
PREGUNTA: {query}
|
|
|
|
| 161 |
|
| 162 |
+
RESPUESTA (sé conciso):"""
|
| 163 |
|
| 164 |
# Tokenizar
|
| 165 |
inputs = self.tokenizer(
|
|
|
|
| 173 |
if self.device == 'cuda':
|
| 174 |
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 175 |
|
| 176 |
+
# Generar con parametros optimizados
|
| 177 |
try:
|
| 178 |
+
print(f"[GENERATE] Iniciando generacion en {self.device}...")
|
| 179 |
+
|
| 180 |
with torch.no_grad():
|
| 181 |
outputs = self.llm_model.generate(
|
| 182 |
**inputs,
|
| 183 |
+
max_new_tokens=min(max_new_tokens, 256), # Limitar a 256 tokens max
|
| 184 |
temperature=temperature,
|
| 185 |
top_p=top_p,
|
| 186 |
do_sample=True,
|
| 187 |
+
num_beams=1, # Greedy decoding para velocidad
|
| 188 |
pad_token_id=self.tokenizer.eos_token_id,
|
| 189 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 190 |
)
|
| 191 |
|
| 192 |
+
print(f"[GENERATE] Generacion completada")
|
| 193 |
+
|
| 194 |
# Decodificar
|
| 195 |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 196 |
|
| 197 |
+
# Extraer solo la respuesta generada
|
| 198 |
+
if "RESPUESTA" in response:
|
| 199 |
+
response = response.split("RESPUESTA")[-1].strip()
|
| 200 |
+
response = response.replace("(sé conciso):", "").strip()
|
| 201 |
|
| 202 |
+
return response[:2000] # Limitar largo de respuesta
|
| 203 |
|
| 204 |
except Exception as e:
|
| 205 |
+
print(f"[ERROR] Error en generacion: {str(e)}")
|
| 206 |
return f"Error generando respuesta: {str(e)}"
|
| 207 |
|
| 208 |
def query(
|