# fintech-coop-api/api/services/rag_service.py
import os
import numpy as np
from typing import List, Dict, Optional
import logging
# Configure logging for Hugging Face Spaces
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the whole service.
logger = logging.getLogger(__name__)
class CooperativaAdvancedRAG:
    """Singleton RAG service: FAISS retrieval + cross-encoder rerank + HF LLM.

    Heavy models are loaded lazily on the first call to :meth:`query` so the
    process starts quickly on Hugging Face Spaces. All public error paths
    return human-readable strings instead of raising.
    """

    _instance = None
    _models_loaded = False

    def __new__(cls):
        # Classic singleton: every instantiation returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard against re-running __init__ on the shared singleton instance.
        if hasattr(self, 'initialized') and self.initialized:
            return
        self.initialized = True
        self._models_loaded = False
        logger.info("--- Inicializando RAG Service (carga perezosa) ---")
        # Resolve the FAISS index location and the HF API token up front;
        # model weights themselves are deferred to _load_models().
        self._setup_paths()

    def _setup_paths(self):
        """Locate the FAISS index directory and read the HF token from env."""
        # On Hugging Face Spaces the current working directory is the app root.
        self.backend_dir = os.getcwd()
        # Probe the usual locations where the index may have been committed.
        possible_paths = [
            os.path.join(self.backend_dir, "faiss_index"),
            os.path.join(self.backend_dir, "backend", "faiss_index"),
            os.path.join(os.path.dirname(self.backend_dir), "faiss_index"),
        ]
        self.persist_directory = None
        for path in possible_paths:
            if os.path.exists(path):
                self.persist_directory = path
                logger.info(f"FAISS index encontrado en: {path}")
                break
        # Token comes from Space secrets; either env var name is accepted.
        self.hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
        if not self.hf_token:
            logger.warning("HUGGINGFACEHUB_API_TOKEN no encontrado. El LLM no funcionará correctamente.")
        else:
            logger.info("Token de Hugging Face encontrado")

    def _load_models(self):
        """Lazily load embeddings, FAISS index, cross-encoder and LLM.

        Idempotent: returns immediately once models are in memory.

        Raises:
            RuntimeError: if the FAISS index directory cannot be found.
            Exception: any model/import failure is logged and re-raised.
        """
        if self._models_loaded:
            return
        logger.info("--- Cargando modelos de IA a la memoria ---")
        try:
            # Imported here (not at module top) so startup stays fast and the
            # API can boot even before these heavy packages are needed.
            from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
            from langchain_community.vectorstores import FAISS
            from sentence_transformers import CrossEncoder

            # Without an index there is nothing to retrieve from — fail loudly.
            if not self.persist_directory or not os.path.exists(self.persist_directory):
                error_msg = f"FAISS index no encontrado en: {self.persist_directory}"
                logger.error(error_msg)
                raise RuntimeError(error_msg)

            # Embeddings: multilingual model, normalized vectors so FAISS
            # distances are comparable across queries.
            logger.info("Cargando modelo de embeddings...")
            self.embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )

            # Vector store persisted on disk. The deserialization flag is
            # required by langchain when loading a locally created index.
            logger.info("Cargando FAISS index...")
            self.db = FAISS.load_local(
                self.persist_directory,
                self.embeddings,
                allow_dangerous_deserialization=True,
            )

            # Cross-encoder used to rerank the initial FAISS candidates.
            logger.info("Cargando CrossEncoder...")
            self.cross_encoder = CrossEncoder(
                "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1",
                device='cpu'
            )

            # Remote LLM endpoint — only initialized when a token is present.
            if self.hf_token:
                logger.info("Inicializando HuggingFaceEndpoint...")
                self.llm = HuggingFaceEndpoint(
                    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3",
                    huggingfacehub_api_token=self.hf_token,
                    task="text-generation",
                    max_new_tokens=512,
                    temperature=0.1,
                    do_sample=True,
                    top_p=0.95,
                    typical_p=0.95,
                    repetition_penalty=1.1,
                    timeout=120,
                )
                # Smoke-test the endpoint; on failure fall back to None so
                # query() can report a clear configuration error.
                try:
                    self.llm.invoke("Hola")
                    logger.info("LLM inicializado correctamente")
                except Exception as e:
                    logger.error(f"Error al probar LLM: {e}")
                    self.llm = None
            else:
                logger.warning("No hay token disponible - LLM no inicializado")
                self.llm = None

            self._models_loaded = True
            logger.info("--- Sistema RAG listo para recibir consultas ---")
        except Exception as e:
            logger.error(f"Error crítico cargando modelos: {e}")
            raise

    # ------------------------------------------------------------------
    # Internal helpers for query()
    # ------------------------------------------------------------------
    def _format_history(self, chat_history: Optional[List[Dict[str, str]]]) -> str:
        """Render the last 5 turns as 'Usuario:/Asistente:' lines."""
        history_text = ""
        if chat_history:
            for turn in chat_history[-5:]:
                role = "Usuario" if turn.get("role") == "user" else "Asistente"
                content = turn.get("content", "")
                if content:
                    history_text += f"{role}: {content}\n"
        return history_text

    def _rewrite_question(self, question: str, history_text: str) -> str:
        """Ask the LLM for a standalone rewrite; fall back to the original."""
        rewrite_prompt = f"""<s>[INST] Reformula la siguiente pregunta para que sea independiente del historial de la conversación.
Historial:
{history_text}
Pregunta actual:
{question}
Pregunta reformulada (solo la pregunta, sin explicaciones): [/INST]"""
        try:
            rewritten = self.llm.invoke(rewrite_prompt).strip()
            # Very short rewrites are likely degenerate; keep the original.
            if rewritten and len(rewritten) > 10:
                logger.info(f"Pregunta reformulada: {rewritten}")
                return rewritten
        except Exception as e:
            logger.error(f"Error en rewrite: {e}")
            # Continue with the original question.
        return question

    def _retrieve(self, standalone_question: str, top_k_initial: int) -> list:
        """FAISS similarity search filtered by a distance threshold."""
        initial_docs = self.db.similarity_search_with_score(
            standalone_question,
            k=top_k_initial
        )
        # FAISS returns distances: lower is better. 2.0 is an empirical
        # cutoff — tune as needed for the corpus.
        return [doc for doc, score in initial_docs if score < 2.0]

    def _rerank(self, standalone_question: str, valid_docs: list, top_k_final: int) -> list:
        """Rerank candidates with the cross-encoder; fall back to FAISS order."""
        try:
            cross_inputs = [
                [standalone_question, doc.page_content]
                for doc in valid_docs
            ]
            scores = self.cross_encoder.predict(cross_inputs)
            # Cross-encoder scores: higher is better.
            order = np.argsort(scores)[::-1]
            return [valid_docs[i] for i in order[:top_k_final]]
        except Exception as e:
            logger.error(f"Error en reranking: {e}")
            return valid_docs[:top_k_final]

    # --------------------------------------------------
    # MAIN QUERY
    # --------------------------------------------------
    def query(
        self,
        question: str,
        chat_history: Optional[List[Dict[str, str]]] = None,
        top_k_initial: int = 25,
        top_k_final: int = 3,
    ) -> str:
        """Answer *question* from the indexed documents.

        Args:
            question: the user question (Spanish expected by the prompts).
            chat_history: optional list of {"role", "content"} turns; only
                the last 5 are used to contextualize the question.
            top_k_initial: candidates fetched from FAISS before reranking.
            top_k_final: documents kept after cross-encoder reranking.

        Returns:
            The LLM answer, or a human-readable error/empty-result message.
            This method never raises.
        """
        # Load models on first query (lazy initialization).
        try:
            self._load_models()
        except Exception as e:
            return f"Error inicializando el sistema: {str(e)}"

        # Without an LLM there is no way to generate an answer.
        if not self.llm:
            return "Error: Token de Hugging Face no configurado. Por favor, configura HUGGINGFACEHUB_API_TOKEN en los secretos del Space."

        # Rewrite the question to be standalone only when history exists.
        history_text = self._format_history(chat_history)
        standalone_question = question
        if history_text.strip():
            standalone_question = self._rewrite_question(question, history_text)

        # Retrieve candidate documents from the vector store.
        try:
            valid_docs = self._retrieve(standalone_question, top_k_initial)
            if not valid_docs:
                return "No encontré información relevante en los documentos disponibles."
        except Exception as e:
            logger.error(f"Error en búsqueda FAISS: {e}")
            return f"Error en la búsqueda: {str(e)}"

        # Rerank and keep the best few documents as context.
        top_docs = self._rerank(standalone_question, valid_docs, top_k_final)
        context = "\n\n".join(
            f"Documento {i+1}:\n{doc.page_content}"
            for i, doc in enumerate(top_docs)
        )

        # Final grounded-answer prompt (Mistral [INST] format).
        prompt = f"""<s>[INST] Eres un asistente experto en análisis de documentos bancarios y contractuales.
INSTRUCCIONES:
- Responde SOLO usando el CONTEXTO proporcionado
- No inventes información
- Si la información no está en el contexto, responde EXACTAMENTE:
"No tengo suficiente información en los documentos disponibles para responder a esta consulta."
- Indica el documento utilizado (ej: "Según el Documento 1...")
- Sé conciso y profesional
CONTEXTO:
{context}
PREGUNTA:
{question}
RESPUESTA: [/INST]"""

        # Generate and lightly clean the answer.
        try:
            response = self.llm.invoke(prompt)
            if response:
                # Strip instruction-tag artifacts the model may echo back.
                response = response.replace("</s>", "").replace("<s>", "").strip()
            return response if response else "No se pudo generar una respuesta."
        except Exception as e:
            logger.error(f"Error generando respuesta: {e}")
            return f"Error al generar respuesta: {str(e)}"