chatAPI / rag_api.py
jimytech's picture
Update rag_api.py
ce734ca verified
raw
history blame
5.82 kB
import os
import requests
import shutil
from langchain_community.vectorstores import FAISS
from fastapi import FastAPI
from pydantic import BaseModel
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
# --------------------------------------------------------
# CACHÉ EN /tmp
# --------------------------------------------------------
TEMP_CACHE_DIR = '/tmp/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = TEMP_CACHE_DIR
os.environ['HF_HOME'] = TEMP_CACHE_DIR
os.environ['SENTENCE_TRANSFORMERS_HOME'] = TEMP_CACHE_DIR
os.makedirs(TEMP_CACHE_DIR, exist_ok=True)
# --------------------------------------------------------
# 1. CONFIGURACIÓN Y PROMPTS
# --------------------------------------------------------
URL_FAISS = "https://drive.google.com/uc?export=download&id=1hiVycS4DQHO1MBdC-L_z1TXA6sJO_Y-r"
URL_PKL = "https://drive.google.com/uc?export=download&id=1vbG8unx88Kb5jn7puGv1gqSM4S6rIUQC"
DOWNLOAD_DIR = "/tmp/db_faiss"
DB_FAISS_PATH = DOWNLOAD_DIR
# --- NUEVO: PROMPT PARA RE-ESCRIBIR LA PREGUNTA ---
CONDENSE_PROMPT = PromptTemplate(
template="""Dada la siguiente conversación y una pregunta de seguimiento, reescribe la pregunta de seguimiento para que sea una pregunta independiente que contenga todo el contexto, especialmente si se refiere a la UPT Aragua.
Historial:
{chat_history}
Pregunta de seguimiento: {question}
Pregunta independiente reescrita:""",
input_variables=["chat_history", "question"]
)
INTENT_PROMPT = PromptTemplate(
template="""Eres un clasificador de intenciones para la UPT Aragua. Clasifica en: SALUDO, UNIVERSIDAD u OTRO.
Responde SOLO con la categoría.
Mensaje: {query}
Categoría:""",
input_variables=["query"]
)
SALUDO_PROMPT = PromptTemplate(
template="""Eres UPTA bot, saluda amigablemente y menciona que puedes ayudar con info de la UPT Aragua.
Mensaje: {query}
Respuesta:""",
input_variables=["query"]
)
RAG_PROMPT = PromptTemplate(
template="""Eres UPTA bot, experto de la UPT Aragua. Responde usando el contexto. Si no lo sabes, pide ser más específico.
Contexto: {context}
Pregunta: {question}
Respuesta:""",
input_variables=["context", "question"]
)
# --------------------------------------------------------
# 2. MODELOS DE DATOS
# --------------------------------------------------------
class QueryRequest(BaseModel):
query: str
history: list = [] # Aquí recibiremos el historial desde Gradio
# --------------------------------------------------------
# 3. FUNCIONES DE CARGA
# --------------------------------------------------------
def download_file(url, local_path):
headers = {'User-Agent': 'Mozilla/5.0'}
response = requests.get(url, stream=True, headers=headers, timeout=30)
os.makedirs(os.path.dirname(local_path), exist_ok=True)
with open(local_path, 'wb') as f:
shutil.copyfileobj(response.raw, f)
def load_and_configure_rag():
download_file(URL_FAISS, os.path.join(DOWNLOAD_DIR, 'index.faiss'))
download_file(URL_PKL, os.path.join(DOWNLOAD_DIR, 'index.pkl'))
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2",
model_kwargs={'device': 'cpu'},
cache_folder=TEMP_CACHE_DIR
)
vectorstore = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
# Asegúrate de tener la variable de entorno GROQ_API_KEY configurada en Hugging Face
llm = ChatGroq(temperature=0.15, model_name="openai/gpt-oss-120b")
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
# Creamos todas las cadenas
condense_chain = CONDENSE_PROMPT | llm
intent_chain = INTENT_PROMPT | llm
saludo_chain = SALUDO_PROMPT | llm
rag_chain = (
{"context": retriever, "question": RunnablePassthrough()}
| RAG_PROMPT
| llm
)
return condense_chain, intent_chain, saludo_chain, rag_chain, retriever
# --------------------------------------------------------
# 4. API FASTAPI
# --------------------------------------------------------
app = FastAPI()
condense_chain = intent_chain = saludo_chain = rag_chain = retriever = None
@app.on_event("startup")
async def startup_event():
global condense_chain, intent_chain, saludo_chain, rag_chain, retriever
condense_chain, intent_chain, saludo_chain, rag_chain, retriever = load_and_configure_rag()
@app.post("/query")
async def process_query(request: QueryRequest):
# 1. Convertir historial a texto
chat_str = ""
for user_msg, bot_msg in request.history:
chat_str += f"Usuario: {user_msg}\nBot: {bot_msg}\n"
# 2. Re-escribir consulta si hay historial
query_to_process = request.query
if request.history:
res = condense_chain.invoke({"chat_history": chat_str, "question": request.query})
query_to_process = res.content.strip()
# 3. Clasificar intención
intent_res = intent_chain.invoke({"query": query_to_process})
intent = intent_res.content.upper()
if "SALUDO" in intent:
resp = saludo_chain.invoke({"query": request.query})
return {"response": resp.content, "intent": "SALUDO"}
elif "OTRO" in intent:
return {"response": "Solo puedo ayudarte con temas de la UPT Aragua.", "intent": "OTRO"}
else:
# RAG con la consulta re-escrita
resp = rag_chain.invoke(query_to_process)
docs = retriever.invoke(query_to_process)
sources = list(set([doc.metadata.get("source", "N/A") for doc in docs]))
return {"response": resp.content, "intent": "UNIVERSIDAD", "sources": sources}
except Exception as e:
return {"error": f"Error al procesar la consulta: {e}"}