perellorets commited on
Commit
ffbd730
·
verified ·
1 Parent(s): 2a0c31e

Update rag_system.py

Browse files
Files changed (1) hide show
  1. rag_system.py +261 -261
rag_system.py CHANGED
@@ -1,261 +1,261 @@
1
- """
2
- Sistema RAG simplificado para Hugging Face Spaces
3
- Version optimizada con Salamandra 7B Instruct
4
- """
5
-
6
- import os
7
- from typing import List, Dict
8
- from dataclasses import dataclass
9
- import torch
10
- from sentence_transformers import SentenceTransformer
11
- from qdrant_client import QdrantClient
12
- from transformers import AutoModelForCausalLM, AutoTokenizer
13
- import time
14
-
15
-
16
- @dataclass
17
- class RAGResult:
18
- """Resultado de una consulta RAG."""
19
- query: str
20
- answer: str
21
- sources: List[Dict]
22
- retrieval_time: float
23
- generation_time: float
24
- total_time: float
25
-
26
-
27
- class RAGLLMSystem:
28
- """Sistema RAG + Salamandra LLM."""
29
-
30
- def __init__(self):
31
- """Inicializar sistema."""
32
-
33
- # Configuracion desde variables de entorno
34
- self.qdrant_url = os.getenv("QDRANT_URL")
35
- self.qdrant_api_key = os.getenv("QDRANT_API_KEY")
36
- self.qdrant_collection = os.getenv("QDRANT_COLLECTION", "alia_turismo_docs")
37
-
38
- # Modelo LLM
39
- self.llm_model_name = "BSC-LT/salamandra-7b-instruct"
40
-
41
- # Modelo de embeddings
42
- self.embedding_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
43
-
44
- # Detectar dispositivo
45
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
46
- print(f"[RAG] Dispositivo: {self.device}")
47
-
48
- # Inicializar componentes
49
- self._init_qdrant_client()
50
- self._init_embedding_model()
51
- self._init_salamandra_model()
52
-
53
- def _init_qdrant_client(self):
54
- """Inicializar cliente de Qdrant."""
55
- print(f"[RAG] Conectando a Qdrant Cloud...")
56
- self.qdrant_client = QdrantClient(
57
- url=self.qdrant_url,
58
- api_key=self.qdrant_api_key
59
- )
60
- print(f"[RAG] Conectado a Qdrant")
61
-
62
- def _init_embedding_model(self):
63
- """Inicializar modelo de embeddings."""
64
- print(f"[RAG] Cargando modelo de embeddings...")
65
- self.embedding_model = SentenceTransformer(
66
- self.embedding_model_name,
67
- device=self.device
68
- )
69
- print(f"[RAG] Embeddings cargados")
70
-
71
- def _init_salamandra_model(self):
72
- """Inicializar Salamandra 7B Instruct."""
73
- print(f"[RAG] Cargando Salamandra 7B Instruct...")
74
-
75
- # Cargar tokenizer
76
- self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
77
-
78
- # Cargar modelo
79
- if self.device == 'cuda':
80
- self.llm_model = AutoModelForCausalLM.from_pretrained(
81
- self.llm_model_name,
82
- torch_dtype=torch.float16,
83
- device_map="auto",
84
- low_cpu_mem_usage=True
85
- )
86
- print(f"[RAG] Salamandra cargado en GPU")
87
- else:
88
- self.llm_model = AutoModelForCausalLM.from_pretrained(
89
- self.llm_model_name,
90
- torch_dtype=torch.float32,
91
- low_cpu_mem_usage=True
92
- )
93
- print(f"[RAG] Salamandra cargado en CPU")
94
-
95
- self.llm_model.eval()
96
-
97
- def retrieve_context(
98
- self,
99
- query: str,
100
- top_k: int = 5,
101
- score_threshold: float = 0.6
102
- ) -> List[Dict]:
103
- """Recuperar documentos relevantes."""
104
-
105
- # Generar embedding
106
- query_embedding = self.embedding_model.encode(
107
- query,
108
- convert_to_numpy=True
109
- )
110
-
111
- # Buscar en Qdrant
112
- results = self.qdrant_client.query_points(
113
- collection_name=self.qdrant_collection,
114
- query=query_embedding.tolist(),
115
- limit=top_k
116
- ).points
117
-
118
- # Filtrar y formatear
119
- documents = []
120
- for result in results:
121
- if result.score >= score_threshold:
122
- documents.append({
123
- 'content': result.payload.get('full_content', ''),
124
- 'filename': result.payload.get('filename', ''),
125
- 'category': result.payload.get('category', ''),
126
- 'score': result.score,
127
- 'id': result.id
128
- })
129
-
130
- return documents
131
-
132
- def generate_answer(
133
- self,
134
- query: str,
135
- context_docs: List[Dict],
136
- max_new_tokens: int = 1024,
137
- temperature: float = 0.7,
138
- top_p: float = 0.9
139
- ) -> str:
140
- """Generar respuesta con Salamandra."""
141
-
142
- # Construir contexto
143
- context_text = "\n\n---\n\n".join([
144
- f"[Documento: {doc['filename']}]\n{doc['content'][:2000]}"
145
- for doc in context_docs
146
- ])
147
-
148
- # Prompt
149
- prompt = f"""Eres ALIA, un asistente experto en planificacion estrategica turistica de la Comunidad Valenciana.
150
-
151
- Tu funcion es ayudar a funcionarios publicos, tecnicos de turismo y responsables de destinos turisticos a:
152
- - Comprender y aplicar estrategias de planes turisticos
153
- - Obtener informacion sobre mejores practicas en turismo sostenible
154
- - Consultar casos de exito de otros municipios
155
- - Disenar e implementar planes estrategicos turisticos
156
-
157
- INSTRUCCIONES:
158
- 1. Responde SIEMPRE basandote en los documentos proporcionados
159
- 2. Si la informacion no esta en los documentos, indica claramente que no la tienes
160
- 3. Cita los documentos fuente cuando sea relevante
161
- 4. Usa un tono profesional pero accesible
162
- 5. Estructura tus respuestas de forma clara con bullets o numeracion cuando sea apropiado
163
-
164
- CONTEXTO (Documentos de planes estrategicos de turismo):
165
-
166
- {context_text}
167
-
168
- PREGUNTA DEL USUARIO:
169
- {query}
170
-
171
- RESPUESTA:"""
172
-
173
- # Tokenizar
174
- inputs = self.tokenizer(
175
- prompt,
176
- return_tensors="pt",
177
- truncation=True,
178
- max_length=4096
179
- )
180
-
181
- # Mover a dispositivo
182
- if self.device == 'cuda':
183
- inputs = {k: v.cuda() for k, v in inputs.items()}
184
-
185
- # Generar
186
- try:
187
- with torch.no_grad():
188
- outputs = self.llm_model.generate(
189
- **inputs,
190
- max_new_tokens=max_new_tokens,
191
- temperature=temperature,
192
- top_p=top_p,
193
- do_sample=True,
194
- pad_token_id=self.tokenizer.eos_token_id,
195
- eos_token_id=self.tokenizer.eos_token_id,
196
- )
197
-
198
- # Decodificar
199
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
200
-
201
- # Extraer respuesta
202
- if "RESPUESTA:" in response:
203
- response = response.split("RESPUESTA:")[-1].strip()
204
-
205
- return response
206
-
207
- except Exception as e:
208
- return f"Error generando respuesta: {str(e)}"
209
-
210
- def query(
211
- self,
212
- question: str,
213
- top_k: int = 5,
214
- score_threshold: float = 0.6,
215
- max_new_tokens: int = 1024,
216
- temperature: float = 0.7
217
- ) -> RAGResult:
218
- """Procesar consulta completa."""
219
-
220
- start_time = time.time()
221
-
222
- # Recuperar contexto
223
- retrieval_start = time.time()
224
- context_docs = self.retrieve_context(question, top_k, score_threshold)
225
- retrieval_time = time.time() - retrieval_start
226
-
227
- if not context_docs:
228
- return RAGResult(
229
- query=question,
230
- answer="No se encontraron documentos relevantes para responder tu pregunta.",
231
- sources=[],
232
- retrieval_time=retrieval_time,
233
- generation_time=0,
234
- total_time=time.time() - start_time
235
- )
236
-
237
- # Generar respuesta
238
- generation_start = time.time()
239
- answer = self.generate_answer(
240
- question,
241
- context_docs,
242
- max_new_tokens=max_new_tokens,
243
- temperature=temperature
244
- )
245
- generation_time = time.time() - generation_start
246
-
247
- # Preparar resultado
248
- sources = [{
249
- 'filename': doc['filename'],
250
- 'category': doc['category'],
251
- 'score': doc['score']
252
- } for doc in context_docs]
253
-
254
- return RAGResult(
255
- query=question,
256
- answer=answer,
257
- sources=sources,
258
- retrieval_time=retrieval_time,
259
- generation_time=generation_time,
260
- total_time=time.time() - start_time
261
- )
 
1
+ """
2
+ Sistema RAG simplificado para Hugging Face Spaces
3
+ Version optimizada con Salamandra 7B Instruct
4
+ """
5
+
6
+ import os
7
+ from typing import List, Dict
8
+ from dataclasses import dataclass
9
+ import torch
10
+ from sentence_transformers import SentenceTransformer
11
+ from qdrant_client import QdrantClient
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+ import time
14
+
15
+
16
+ @dataclass
17
+ class RAGResult:
18
+ """Resultado de una consulta RAG."""
19
+ query: str
20
+ answer: str
21
+ sources: List[Dict]
22
+ retrieval_time: float
23
+ generation_time: float
24
+ total_time: float
25
+
26
+
27
+ class RAGLLMSystem:
28
+ """Sistema RAG + Salamandra LLM."""
29
+
30
+ def __init__(self):
31
+ """Inicializar sistema."""
32
+
33
+ # Configuracion desde variables de entorno
34
+ self.qdrant_url = os.getenv("QDRANT_URL")
35
+ self.qdrant_api_key = os.getenv("QDRANT_API_KEY")
36
+ self.qdrant_collection = os.getenv("QDRANT_COLLECTION", "alia_turismo_docs")
37
+
38
+ # Modelo LLM
39
+ self.llm_model_name = "BSC-LT/salamandra-7b-instruct"
40
+
41
+ # Modelo de embeddings
42
+ self.embedding_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
43
+
44
+ # Detectar dispositivo
45
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
46
+ print(f"[RAG] Dispositivo: {self.device}")
47
+
48
+ # Inicializar componentes
49
+ self._init_qdrant_client()
50
+ self._init_embedding_model()
51
+ self._init_salamandra_model()
52
+
53
+ def _init_qdrant_client(self):
54
+ """Inicializar cliente de Qdrant."""
55
+ print(f"[RAG] Conectando a Qdrant Cloud...")
56
+ self.qdrant_client = QdrantClient(
57
+ url=self.qdrant_url,
58
+ api_key=self.qdrant_api_key
59
+ )
60
+ print(f"[RAG] Conectado a Qdrant")
61
+
62
+ def _init_embedding_model(self):
63
+ """Inicializar modelo de embeddings."""
64
+ print(f"[RAG] Cargando modelo de embeddings...")
65
+ self.embedding_model = SentenceTransformer(
66
+ self.embedding_model_name,
67
+ device=self.device
68
+ )
69
+ print(f"[RAG] Embeddings cargados")
70
+
71
+ def _init_salamandra_model(self):
72
+ """Inicializar Salamandra 7B Instruct con cuantizacion 8-bit."""
73
+ print(f"[RAG] Cargando Salamandra 7B Instruct (8-bit cuantizado)...")
74
+
75
+ # Cargar tokenizer
76
+ self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
77
+
78
+ # Cargar modelo con cuantizacion 8-bit para ahorrar memoria
79
+ if self.device == 'cuda':
80
+ self.llm_model = AutoModelForCausalLM.from_pretrained(
81
+ self.llm_model_name,
82
+ load_in_8bit=True,
83
+ device_map="auto",
84
+ low_cpu_mem_usage=True
85
+ )
86
+ print(f"[RAG] Salamandra cargado en GPU (8-bit)")
87
+ else:
88
+ self.llm_model = AutoModelForCausalLM.from_pretrained(
89
+ self.llm_model_name,
90
+ torch_dtype=torch.float32,
91
+ low_cpu_mem_usage=True
92
+ )
93
+ print(f"[RAG] Salamandra cargado en CPU")
94
+
95
+ self.llm_model.eval()
96
+
97
+ def retrieve_context(
98
+ self,
99
+ query: str,
100
+ top_k: int = 5,
101
+ score_threshold: float = 0.6
102
+ ) -> List[Dict]:
103
+ """Recuperar documentos relevantes."""
104
+
105
+ # Generar embedding
106
+ query_embedding = self.embedding_model.encode(
107
+ query,
108
+ convert_to_numpy=True
109
+ )
110
+
111
+ # Buscar en Qdrant
112
+ results = self.qdrant_client.query_points(
113
+ collection_name=self.qdrant_collection,
114
+ query=query_embedding.tolist(),
115
+ limit=top_k
116
+ ).points
117
+
118
+ # Filtrar y formatear
119
+ documents = []
120
+ for result in results:
121
+ if result.score >= score_threshold:
122
+ documents.append({
123
+ 'content': result.payload.get('full_content', ''),
124
+ 'filename': result.payload.get('filename', ''),
125
+ 'category': result.payload.get('category', ''),
126
+ 'score': result.score,
127
+ 'id': result.id
128
+ })
129
+
130
+ return documents
131
+
132
+ def generate_answer(
133
+ self,
134
+ query: str,
135
+ context_docs: List[Dict],
136
+ max_new_tokens: int = 1024,
137
+ temperature: float = 0.7,
138
+ top_p: float = 0.9
139
+ ) -> str:
140
+ """Generar respuesta con Salamandra."""
141
+
142
+ # Construir contexto
143
+ context_text = "\n\n---\n\n".join([
144
+ f"[Documento: {doc['filename']}]\n{doc['content'][:2000]}"
145
+ for doc in context_docs
146
+ ])
147
+
148
+ # Prompt
149
+ prompt = f"""Eres ALIA, un asistente experto en planificacion estrategica turistica de la Comunidad Valenciana.
150
+
151
+ Tu funcion es ayudar a funcionarios publicos, tecnicos de turismo y responsables de destinos turisticos a:
152
+ - Comprender y aplicar estrategias de planes turisticos
153
+ - Obtener informacion sobre mejores practicas en turismo sostenible
154
+ - Consultar casos de exito de otros municipios
155
+ - Disenar e implementar planes estrategicos turisticos
156
+
157
+ INSTRUCCIONES:
158
+ 1. Responde SIEMPRE basandote en los documentos proporcionados
159
+ 2. Si la informacion no esta en los documentos, indica claramente que no la tienes
160
+ 3. Cita los documentos fuente cuando sea relevante
161
+ 4. Usa un tono profesional pero accesible
162
+ 5. Estructura tus respuestas de forma clara con bullets o numeracion cuando sea apropiado
163
+
164
+ CONTEXTO (Documentos de planes estrategicos de turismo):
165
+
166
+ {context_text}
167
+
168
+ PREGUNTA DEL USUARIO:
169
+ {query}
170
+
171
+ RESPUESTA:"""
172
+
173
+ # Tokenizar
174
+ inputs = self.tokenizer(
175
+ prompt,
176
+ return_tensors="pt",
177
+ truncation=True,
178
+ max_length=4096
179
+ )
180
+
181
+ # Mover a dispositivo
182
+ if self.device == 'cuda':
183
+ inputs = {k: v.cuda() for k, v in inputs.items()}
184
+
185
+ # Generar
186
+ try:
187
+ with torch.no_grad():
188
+ outputs = self.llm_model.generate(
189
+ **inputs,
190
+ max_new_tokens=max_new_tokens,
191
+ temperature=temperature,
192
+ top_p=top_p,
193
+ do_sample=True,
194
+ pad_token_id=self.tokenizer.eos_token_id,
195
+ eos_token_id=self.tokenizer.eos_token_id,
196
+ )
197
+
198
+ # Decodificar
199
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
200
+
201
+ # Extraer respuesta
202
+ if "RESPUESTA:" in response:
203
+ response = response.split("RESPUESTA:")[-1].strip()
204
+
205
+ return response
206
+
207
+ except Exception as e:
208
+ return f"Error generando respuesta: {str(e)}"
209
+
210
+ def query(
211
+ self,
212
+ question: str,
213
+ top_k: int = 5,
214
+ score_threshold: float = 0.6,
215
+ max_new_tokens: int = 1024,
216
+ temperature: float = 0.7
217
+ ) -> RAGResult:
218
+ """Procesar consulta completa."""
219
+
220
+ start_time = time.time()
221
+
222
+ # Recuperar contexto
223
+ retrieval_start = time.time()
224
+ context_docs = self.retrieve_context(question, top_k, score_threshold)
225
+ retrieval_time = time.time() - retrieval_start
226
+
227
+ if not context_docs:
228
+ return RAGResult(
229
+ query=question,
230
+ answer="No se encontraron documentos relevantes para responder tu pregunta.",
231
+ sources=[],
232
+ retrieval_time=retrieval_time,
233
+ generation_time=0,
234
+ total_time=time.time() - start_time
235
+ )
236
+
237
+ # Generar respuesta
238
+ generation_start = time.time()
239
+ answer = self.generate_answer(
240
+ question,
241
+ context_docs,
242
+ max_new_tokens=max_new_tokens,
243
+ temperature=temperature
244
+ )
245
+ generation_time = time.time() - generation_start
246
+
247
+ # Preparar resultado
248
+ sources = [{
249
+ 'filename': doc['filename'],
250
+ 'category': doc['category'],
251
+ 'score': doc['score']
252
+ } for doc in context_docs]
253
+
254
+ return RAGResult(
255
+ query=question,
256
+ answer=answer,
257
+ sources=sources,
258
+ retrieval_time=retrieval_time,
259
+ generation_time=generation_time,
260
+ total_time=time.time() - start_time
261
+ )