""" Sistema RAG simplificado para Hugging Face Spaces Version optimizada con Salamandra 7B Instruct """ import os from typing import List, Dict from dataclasses import dataclass import torch from sentence_transformers import SentenceTransformer from qdrant_client import QdrantClient from transformers import AutoModelForCausalLM, AutoTokenizer import time @dataclass class RAGResult: """Resultado de una consulta RAG.""" query: str answer: str sources: List[Dict] retrieval_time: float generation_time: float total_time: float class RAGLLMSystem: """Sistema RAG + Salamandra LLM.""" def __init__(self): """Inicializar sistema.""" # Configuracion desde variables de entorno self.qdrant_url = os.getenv("QDRANT_URL") self.qdrant_api_key = os.getenv("QDRANT_API_KEY") self.qdrant_collection = os.getenv("QDRANT_COLLECTION", "alia_turismo_docs") # Debug: verificar que las variables existen print(f"[DEBUG] QDRANT_URL configurado: {self.qdrant_url is not None}") print(f"[DEBUG] QDRANT_API_KEY configurado: {self.qdrant_api_key is not None}") print(f"[DEBUG] QDRANT_COLLECTION: {self.qdrant_collection}") # Modelo LLM self.llm_model_name = "BSC-LT/salamandra-7b-instruct" # Modelo de embeddings self.embedding_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" # Detectar dispositivo self.device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f"[RAG] Dispositivo: {self.device}") # Inicializar componentes self._init_qdrant_client() self._init_embedding_model() self._init_salamandra_model() def _init_qdrant_client(self): """Inicializar cliente de Qdrant.""" print(f"[RAG] Conectando a Qdrant Cloud...") self.qdrant_client = QdrantClient( url=self.qdrant_url, api_key=self.qdrant_api_key ) print(f"[RAG] Conectado a Qdrant") def _init_embedding_model(self): """Inicializar modelo de embeddings.""" print(f"[RAG] Cargando modelo de embeddings...") self.embedding_model = SentenceTransformer( self.embedding_model_name, device=self.device ) print(f"[RAG] Embeddings cargados") def _init_salamandra_model(self): """Inicializar Salamandra 7B Instruct con cuantizacion 8-bit.""" print(f"[RAG] Cargando Salamandra 7B Instruct (8-bit cuantizado)...") # Cargar tokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_name) # Cargar modelo con cuantizacion 8-bit para ahorrar memoria if self.device == 'cuda': self.llm_model = AutoModelForCausalLM.from_pretrained( self.llm_model_name, load_in_8bit=True, device_map="auto", low_cpu_mem_usage=True ) print(f"[RAG] Salamandra cargado en GPU (8-bit)") else: self.llm_model = AutoModelForCausalLM.from_pretrained( self.llm_model_name, torch_dtype=torch.float32, low_cpu_mem_usage=True ) print(f"[RAG] Salamandra cargado en CPU") self.llm_model.eval() def retrieve_context( self, query: str, top_k: int = 5, score_threshold: float = 0.6 ) -> List[Dict]: """Recuperar documentos relevantes.""" # Generar embedding query_embedding = self.embedding_model.encode( query, convert_to_numpy=True ) # Buscar en Qdrant results = self.qdrant_client.query_points( collection_name=self.qdrant_collection, query=query_embedding.tolist(), limit=top_k ).points # Filtrar y formatear documents = [] for result in results: if result.score >= score_threshold: documents.append({ 'content': result.payload.get('full_content', ''), 'filename': result.payload.get('filename', ''), 'category': result.payload.get('category', ''), 'score': result.score, 'id': result.id }) return documents def generate_answer( self, query: str, context_docs: List[Dict], max_new_tokens: int = 512, temperature: 
float = 0.7, top_p: float = 0.9 ) -> str: """Generar respuesta con Salamandra.""" # Construir contexto (limitado para evitar timeouts) context_text = "\n\n---\n\n".join([ f"[Doc: {doc['filename'][:30]}]\n{doc['content'][:1000]}" for doc in context_docs[:3] # Solo top 3 docs ]) # Prompt optimizado (más corto) prompt = f"""Eres ALIA, asistente de turismo de la Comunidad Valenciana. Responde basandote en estos documentos: {context_text} PREGUNTA: {query} RESPUESTA (sé conciso):""" # Tokenizar inputs = self.tokenizer( prompt, return_tensors="pt", truncation=True, max_length=4096 ) # Mover a dispositivo if self.device == 'cuda': inputs = {k: v.cuda() for k, v in inputs.items()} # Generar con parametros optimizados try: print(f"[GENERATE] Iniciando generacion en {self.device}...") with torch.no_grad(): outputs = self.llm_model.generate( **inputs, max_new_tokens=min(max_new_tokens, 256), # Limitar a 256 tokens max temperature=temperature, top_p=top_p, do_sample=True, num_beams=1, # Greedy decoding para velocidad pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id, ) print(f"[GENERATE] Generacion completada") # Decodificar response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) # Extraer solo la respuesta generada if "RESPUESTA" in response: response = response.split("RESPUESTA")[-1].strip() response = response.replace("(sé conciso):", "").strip() return response[:2000] # Limitar largo de respuesta except Exception as e: print(f"[ERROR] Error en generacion: {str(e)}") return f"Error generando respuesta: {str(e)}" def query( self, question: str, top_k: int = 5, score_threshold: float = 0.6, max_new_tokens: int = 1024, temperature: float = 0.7 ) -> RAGResult: """Procesar consulta completa.""" start_time = time.time() # Recuperar contexto retrieval_start = time.time() context_docs = self.retrieve_context(question, top_k, score_threshold) retrieval_time = time.time() - retrieval_start if not context_docs: return RAGResult( query=question, answer="No se encontraron documentos relevantes para responder tu pregunta.", sources=[], retrieval_time=retrieval_time, generation_time=0, total_time=time.time() - start_time ) # Generar respuesta generation_start = time.time() answer = self.generate_answer( question, context_docs, max_new_tokens=max_new_tokens, temperature=temperature ) generation_time = time.time() - generation_start # Preparar resultado sources = [{ 'filename': doc['filename'], 'category': doc['category'], 'score': doc['score'] } for doc in context_docs] return RAGResult( query=question, answer=answer, sources=sources, retrieval_time=retrieval_time, generation_time=generation_time, total_time=time.time() - start_time )
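

# Minimal usage sketch, assuming QDRANT_URL and QDRANT_API_KEY are set in the
# environment and the Qdrant collection is already populated. The question
# below is a hypothetical example, not part of the original module.
if __name__ == "__main__":
    rag = RAGLLMSystem()
    result = rag.query("¿Qué playas hay cerca de Valencia?")
    print(f"Answer: {result.answer}")
    print(f"Sources: {[s['filename'] for s in result.sources]}")
    print(
        f"Timing: retrieval={result.retrieval_time:.2f}s, "
        f"generation={result.generation_time:.2f}s, "
        f"total={result.total_time:.2f}s"
    )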