perellorets commited on
Commit
a267084
·
verified ·
1 Parent(s): 2fcd754

Update rag_system.py

Browse files
Files changed (1) hide show
  1. rag_system.py +22 -29
rag_system.py CHANGED
@@ -138,42 +138,28 @@ class RAGLLMSystem:
138
  self,
139
  query: str,
140
  context_docs: List[Dict],
141
- max_new_tokens: int = 1024,
142
  temperature: float = 0.7,
143
  top_p: float = 0.9
144
  ) -> str:
145
  """Generar respuesta con Salamandra."""
146
 
147
- # Construir contexto
148
  context_text = "\n\n---\n\n".join([
149
- f"[Documento: {doc['filename']}]\n{doc['content'][:2000]}"
150
- for doc in context_docs
151
  ])
152
 
153
- # Prompt
154
- prompt = f"""Eres ALIA, un asistente experto en planificacion estrategica turistica de la Comunidad Valenciana.
155
-
156
- Tu funcion es ayudar a funcionarios publicos, tecnicos de turismo y responsables de destinos turisticos a:
157
- - Comprender y aplicar estrategias de planes turisticos
158
- - Obtener informacion sobre mejores practicas en turismo sostenible
159
- - Consultar casos de exito de otros municipios
160
- - Disenar e implementar planes estrategicos turisticos
161
-
162
- INSTRUCCIONES:
163
- 1. Responde SIEMPRE basandote en los documentos proporcionados
164
- 2. Si la informacion no esta en los documentos, indica claramente que no la tienes
165
- 3. Cita los documentos fuente cuando sea relevante
166
- 4. Usa un tono profesional pero accesible
167
- 5. Estructura tus respuestas de forma clara con bullets o numeracion cuando sea apropiado
168
 
169
- CONTEXTO (Documentos de planes estrategicos de turismo):
170
 
171
  {context_text}
172
 
173
- PREGUNTA DEL USUARIO:
174
- {query}
175
 
176
- RESPUESTA:"""
177
 
178
  # Tokenizar
179
  inputs = self.tokenizer(
@@ -187,29 +173,36 @@ RESPUESTA:"""
187
  if self.device == 'cuda':
188
  inputs = {k: v.cuda() for k, v in inputs.items()}
189
 
190
- # Generar
191
  try:
 
 
192
  with torch.no_grad():
193
  outputs = self.llm_model.generate(
194
  **inputs,
195
- max_new_tokens=max_new_tokens,
196
  temperature=temperature,
197
  top_p=top_p,
198
  do_sample=True,
 
199
  pad_token_id=self.tokenizer.eos_token_id,
200
  eos_token_id=self.tokenizer.eos_token_id,
201
  )
202
 
 
 
203
  # Decodificar
204
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
205
 
206
- # Extraer respuesta
207
- if "RESPUESTA:" in response:
208
- response = response.split("RESPUESTA:")[-1].strip()
 
209
 
210
- return response
211
 
212
  except Exception as e:
 
213
  return f"Error generando respuesta: {str(e)}"
214
 
215
  def query(
 
138
  self,
139
  query: str,
140
  context_docs: List[Dict],
141
+ max_new_tokens: int = 512,
142
  temperature: float = 0.7,
143
  top_p: float = 0.9
144
  ) -> str:
145
  """Generar respuesta con Salamandra."""
146
 
147
+ # Construir contexto (limitado para evitar timeouts)
148
  context_text = "\n\n---\n\n".join([
149
+ f"[Doc: {doc['filename'][:30]}]\n{doc['content'][:1000]}"
150
+ for doc in context_docs[:3] # Solo top 3 docs
151
  ])
152
 
153
+ # Prompt optimizado (más corto)
154
+ prompt = f"""Eres ALIA, asistente de turismo de la Comunidad Valenciana.
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ Responde basandote en estos documentos:
157
 
158
  {context_text}
159
 
160
+ PREGUNTA: {query}
 
161
 
162
+ RESPUESTA (sé conciso):"""
163
 
164
  # Tokenizar
165
  inputs = self.tokenizer(
 
173
  if self.device == 'cuda':
174
  inputs = {k: v.cuda() for k, v in inputs.items()}
175
 
176
+ # Generar con parametros optimizados
177
  try:
178
+ print(f"[GENERATE] Iniciando generacion en {self.device}...")
179
+
180
  with torch.no_grad():
181
  outputs = self.llm_model.generate(
182
  **inputs,
183
+ max_new_tokens=min(max_new_tokens, 256), # Limitar a 256 tokens max
184
  temperature=temperature,
185
  top_p=top_p,
186
  do_sample=True,
187
+ num_beams=1, # Greedy decoding para velocidad
188
  pad_token_id=self.tokenizer.eos_token_id,
189
  eos_token_id=self.tokenizer.eos_token_id,
190
  )
191
 
192
+ print(f"[GENERATE] Generacion completada")
193
+
194
  # Decodificar
195
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
196
 
197
+ # Extraer solo la respuesta generada
198
+ if "RESPUESTA" in response:
199
+ response = response.split("RESPUESTA")[-1].strip()
200
+ response = response.replace("(sé conciso):", "").strip()
201
 
202
+ return response[:2000] # Limitar largo de respuesta
203
 
204
  except Exception as e:
205
+ print(f"[ERROR] Error en generacion: {str(e)}")
206
  return f"Error generando respuesta: {str(e)}"
207
 
208
  def query(