Spaces:
Sleeping
Sleeping
| """ | |
| Sistema RAG simplificado para Hugging Face Spaces | |
| Version optimizada con Salamandra 7B Instruct | |
| """ | |
| import os | |
| from typing import List, Dict | |
| from dataclasses import dataclass | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| from qdrant_client import QdrantClient | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import time | |
@dataclass
class RAGResult:
    """Result of a single RAG query.

    Fix: the ``@dataclass`` decorator was missing. ``dataclass`` is imported
    at the top of the file, and callers build this object with keyword
    arguments (``RAGResult(query=..., answer=..., ...)``), which requires the
    generated ``__init__``; without the decorator those calls raise TypeError.

    Attributes:
        query: The original user question.
        answer: The generated answer text.
        sources: Metadata of the retrieved documents (filename, category, score).
        retrieval_time: Seconds spent retrieving context from Qdrant.
        generation_time: Seconds spent generating the answer with the LLM.
        total_time: End-to-end wall-clock seconds for the whole query.
    """
    query: str
    answer: str
    sources: List[Dict]
    retrieval_time: float
    generation_time: float
    total_time: float
class RAGLLMSystem:
    """RAG pipeline: Qdrant vector retrieval + Salamandra 7B Instruct generation.

    Lifecycle: the constructor reads configuration from environment variables,
    connects to Qdrant Cloud, loads a multilingual sentence-transformer for
    embeddings, and loads the Salamandra LLM (8-bit quantized on GPU, fp32 on
    CPU). `query()` is the public entry point; `retrieve_context()` and
    `generate_answer()` are the two pipeline stages it composes.
    """

    def __init__(self) -> None:
        """Read config, detect device, and initialize all three components.

        NOTE(review): if QDRANT_URL / QDRANT_API_KEY are unset, the debug
        prints below will show it, but QdrantClient is still constructed with
        None and will fail later — confirm the Space always sets both secrets.
        """
        # Configuration from environment variables
        self.qdrant_url = os.getenv("QDRANT_URL")
        self.qdrant_api_key = os.getenv("QDRANT_API_KEY")
        self.qdrant_collection = os.getenv("QDRANT_COLLECTION", "alia_turismo_docs")
        # Debug: confirm the required environment variables exist
        print(f"[DEBUG] QDRANT_URL configurado: {self.qdrant_url is not None}")
        print(f"[DEBUG] QDRANT_API_KEY configurado: {self.qdrant_api_key is not None}")
        print(f"[DEBUG] QDRANT_COLLECTION: {self.qdrant_collection}")
        # LLM model id (Hugging Face hub)
        self.llm_model_name = "BSC-LT/salamandra-7b-instruct"
        # Embedding model id (must match the model used to index the collection)
        self.embedding_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
        # Device detection: prefer CUDA when available
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"[RAG] Dispositivo: {self.device}")
        # Initialize components (order matters: client, embeddings, then LLM)
        self._init_qdrant_client()
        self._init_embedding_model()
        self._init_salamandra_model()

    def _init_qdrant_client(self) -> None:
        """Create the Qdrant Cloud client from the configured URL and API key."""
        print(f"[RAG] Conectando a Qdrant Cloud...")
        self.qdrant_client = QdrantClient(
            url=self.qdrant_url,
            api_key=self.qdrant_api_key
        )
        # NOTE: constructing the client does not verify connectivity; the first
        # search call is where a bad URL/key would surface.
        print(f"[RAG] Conectado a Qdrant")

    def _init_embedding_model(self) -> None:
        """Load the sentence-transformer used to embed incoming queries."""
        print(f"[RAG] Cargando modelo de embeddings...")
        self.embedding_model = SentenceTransformer(
            self.embedding_model_name,
            device=self.device
        )
        print(f"[RAG] Embeddings cargados")

    def _init_salamandra_model(self) -> None:
        """Load Salamandra 7B Instruct: 8-bit quantized on GPU, fp32 on CPU."""
        print(f"[RAG] Cargando Salamandra 7B Instruct (8-bit cuantizado)...")
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name)
        # Load model with 8-bit quantization to save memory on GPU
        if self.device == 'cuda':
            # NOTE(review): `load_in_8bit=True` is deprecated in recent
            # transformers versions in favor of BitsAndBytesConfig — confirm
            # the pinned transformers version still accepts it.
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                self.llm_model_name,
                load_in_8bit=True,
                device_map="auto",
                low_cpu_mem_usage=True
            )
            print(f"[RAG] Salamandra cargado en GPU (8-bit)")
        else:
            # CPU fallback: full-precision fp32 (no bitsandbytes on CPU)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                self.llm_model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True
            )
            print(f"[RAG] Salamandra cargado en CPU")
        # Inference-only mode (disables dropout etc.)
        self.llm_model.eval()

    def retrieve_context(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.6
    ) -> List[Dict]:
        """Retrieve up to `top_k` relevant documents from Qdrant.

        Args:
            query: Natural-language user question.
            top_k: Maximum number of points requested from Qdrant.
            score_threshold: Minimum similarity score; lower-scoring hits are
                dropped client-side after the search.

        Returns:
            A list of dicts with keys 'content', 'filename', 'category',
            'score' and 'id' (possibly empty if nothing passes the threshold).
        """
        # Embed the query with the same model family used at indexing time
        query_embedding = self.embedding_model.encode(
            query,
            convert_to_numpy=True
        )
        # Vector search in Qdrant
        results = self.qdrant_client.query_points(
            collection_name=self.qdrant_collection,
            query=query_embedding.tolist(),
            limit=top_k
        ).points
        # Filter by score and reshape payloads into plain dicts
        documents = []
        for result in results:
            if result.score >= score_threshold:
                documents.append({
                    'content': result.payload.get('full_content', ''),
                    'filename': result.payload.get('filename', ''),
                    'category': result.payload.get('category', ''),
                    'score': result.score,
                    'id': result.id
                })
        return documents

    def generate_answer(
        self,
        query: str,
        context_docs: List[Dict],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        """Generate an answer with Salamandra grounded on the retrieved docs.

        Args:
            query: The user question.
            context_docs: Documents from `retrieve_context`; only the first 3
                are used, each truncated to 1000 chars, to bound prompt size.
            max_new_tokens: Requested generation budget (hard-capped at 256).
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.

        Returns:
            The generated answer (truncated to 2000 chars), or an error
            message string if generation raised.
        """
        # Build context block (bounded to avoid timeouts on long documents)
        context_text = "\n\n---\n\n".join([
            f"[Doc: {doc['filename'][:30]}]\n{doc['content'][:1000]}"
            for doc in context_docs[:3]  # only the top-3 documents
        ])
        # Compact prompt (Spanish, user-facing — do not translate)
        prompt = f"""Eres ALIA, asistente de turismo de la Comunidad Valenciana.
Responde basandote en estos documentos:
{context_text}
PREGUNTA: {query}
RESPUESTA (sé conciso):"""
        # Tokenize, truncating to the model context window
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096
        )
        # Move tensors to the GPU when applicable
        if self.device == 'cuda':
            inputs = {k: v.cuda() for k, v in inputs.items()}
        # Generate with bounded, sampled decoding
        try:
            print(f"[GENERATE] Iniciando generacion en {self.device}...")
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=min(max_new_tokens, 256),  # hard cap at 256 tokens
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    num_beams=1,  # single beam: sampled decoding without beam search, for speed
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            print(f"[GENERATE] Generacion completada")
            # Decode the full sequence (prompt + completion)
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the text after the final "RESPUESTA" marker, then strip
            # the prompt suffix that immediately follows it
            if "RESPUESTA" in response:
                response = response.split("RESPUESTA")[-1].strip()
                response = response.replace("(sé conciso):", "").strip()
            return response[:2000]  # bound the answer length
        except Exception as e:
            # Best-effort: surface the failure as the answer text instead of raising
            print(f"[ERROR] Error en generacion: {str(e)}")
            return f"Error generando respuesta: {str(e)}"

    def query(
        self,
        question: str,
        top_k: int = 5,
        score_threshold: float = 0.6,
        max_new_tokens: int = 1024,
        temperature: float = 0.7
    ) -> RAGResult:
        """Run the full RAG pipeline: retrieve, then generate.

        Args:
            question: The user question.
            top_k: Number of documents requested from the retriever.
            score_threshold: Minimum similarity score for retrieved documents.
            max_new_tokens: Generation budget passed through to
                `generate_answer` (which caps it at 256).
            temperature: Sampling temperature for generation.

        Returns:
            A RAGResult with the answer, source metadata, and stage timings.
            When retrieval yields nothing, returns a canned "no documents"
            answer without invoking the LLM.
        """
        start_time = time.time()
        # Stage 1: retrieval (timed separately)
        retrieval_start = time.time()
        context_docs = self.retrieve_context(question, top_k, score_threshold)
        retrieval_time = time.time() - retrieval_start
        if not context_docs:
            # Short-circuit: no relevant context, skip generation entirely
            return RAGResult(
                query=question,
                answer="No se encontraron documentos relevantes para responder tu pregunta.",
                sources=[],
                retrieval_time=retrieval_time,
                generation_time=0,
                total_time=time.time() - start_time
            )
        # Stage 2: generation (timed separately)
        generation_start = time.time()
        answer = self.generate_answer(
            question,
            context_docs,
            max_new_tokens=max_new_tokens,
            temperature=temperature
        )
        generation_time = time.time() - generation_start
        # Expose only source metadata (not full content) to the caller
        sources = [{
            'filename': doc['filename'],
            'category': doc['category'],
            'score': doc['score']
        } for doc in context_docs]
        return RAGResult(
            query=question,
            answer=answer,
            sources=sources,
            retrieval_time=retrieval_time,
            generation_time=generation_time,
            total_time=time.time() - start_time
        )