P9 / app.py
magomerob's picture
Update app.py
d42fc52 verified
import os
import requests
import gradio as gr
import pandas as pd
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from rerankers import Reranker
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
# ──────────────────────────────────────────────
# 1. Download and process the PDF
# ──────────────────────────────────────────────
PDF_URL = "https://escueladepacientes.es/images/Pdfs/Guia_Informativa_Diabetes_1.pdf"
PDF_PATH = "Guia_Informativa_Diabetes_1.pdf"

# The guide is downloaded only when no local copy exists; later runs reuse it.
if not os.path.exists(PDF_PATH):
    print("Descargando PDF...")
    # A timeout keeps startup from hanging indefinitely on a stalled connection.
    r = requests.get(PDF_URL, timeout=60)
    # Fail loudly on an HTTP error instead of saving the error page as the
    # "PDF": because of the existence check above, a bad file would be reused
    # on every later run.
    r.raise_for_status()
    # Write to a temporary name and rename afterwards so an interrupted write
    # never leaves a truncated file at PDF_PATH.
    tmp_path = PDF_PATH + ".part"
    with open(tmp_path, "wb") as f:
        f.write(r.content)
    os.replace(tmp_path, PDF_PATH)

print("Cargando documento...")
loader = PyPDFLoader(PDF_PATH)
documents = loader.load()

# Split the loaded documents into chunks of up to 1000 characters with a
# 20-character overlap between consecutive chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
all_splits = text_splitter.split_documents(documents)
# ──────────────────────────────────────────────
# 2. Embeddings and vector database
# ──────────────────────────────────────────────
print("Creando embeddings...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Usando dispositivo: {device}")
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    model_kwargs={"device": device},
)

# The collection is persisted in "chroma_db". Without explicit ids, each start
# of the app adds the same chunks again under new ids (per the LangChain
# Chroma documentation, ids are generated when none are supplied), so
# retrieval gradually fills up with duplicates of a single chunk. Stable,
# position-based ids make re-indexing overwrite the existing entries instead.
# NOTE(review): duplicates already stored by earlier runs are not removed;
# delete the "chroma_db" directory once to start from a clean collection.
chunk_ids = [f"{PDF_PATH}-chunk-{index}" for index in range(len(all_splits))]
vectordb = Chroma.from_documents(
    documents=all_splits,
    embedding=embeddings,
    ids=chunk_ids,
    persist_directory="chroma_db",
)
print("Base de datos vectorial lista.")
# ──────────────────────────────────────────────
# 3. LLM: Qwen2.5-1.5B-Instruct (lightweight, multilingual, Spanish-capable)
# ──────────────────────────────────────────────
print("Cargando LLM...")
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Half precision when CUDA was selected as the device, full precision otherwise.
if device == "cuda":
    model_dtype = torch.float16
else:
    model_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=model_dtype,
    device_map="auto",
)

# Initial generation settings: sampling disabled, at most 512 new tokens.
generation_defaults = {"max_new_tokens": 512, "do_sample": False}
hf_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **generation_defaults,
)

# Wrap the transformers pipeline so it can be called through LangChain.
llm = HuggingFacePipeline(pipeline=hf_pipeline)
print("LLM listo.")
# ──────────────────────────────────────────────
# 4. Reranker
# ──────────────────────────────────────────────
print("Cargando reranker...")
# ColBERT-type reranker used to pick the best chunk among the retrieved ones.
ranker = Reranker(
    "answerdotai/answerai-colbert-small-v1",
    model_type="colbert",
)
print("Reranker listo.")
# ──────────────────────────────────────────────
# 5. Funciones RAG
# ──────────────────────────────────────────────
def construir_prompt_rag(context, question):
    """Build the Spanish RAG prompt for *question* grounded on *context*.

    The prompt has three parts, in this order: the context the model must
    restrict itself to, the question (to be answered in Spanish), and the
    fallback sentence the model should give when the context is insufficient.

    Args:
        context: Text inserted as the only allowed source of information.
        question: The user's question.

    Returns:
        The assembled prompt as a single string.
    """
    context_section = f"Usando únicamente la siguiente información:\n\n{context}\n\n"
    question_section = f"Responde en español a la pregunta: {question}\n\n"
    fallback_section = "Si la información no es suficiente, di: 'No tengo información para responder.'"
    return context_section + question_section + fallback_section
def rag_sin_reranking(query):
    """Answer *query* with RAG, using every retrieved chunk that passes the score filter.

    The vector store is searched for *query*; hits with a score below 7 are
    concatenated into one context, which is sent to the LLM through
    ``construir_prompt_rag``.

    Args:
        query: The user's question.

    Returns:
        A ``(answer, sources)`` tuple of strings. ``sources`` lists, for each
        chunk used, its page number, its score and a preview of up to 250
        characters. When no hit passes the filter, a fixed "no information"
        message and an empty sources string are returned without calling the
        LLM.
    """
    docs = vectordb.similarity_search_with_score(query)
    context_parts = []
    sources_parts = []
    for doc, score in docs:
        # Only hits under the cut-off are used (presumably the score is a
        # distance, so lower means closer — confirm against the store config).
        if score < 7:
            content = doc.page_content
            page = doc.metadata.get("page", "?")
            # PyPDFLoader stores a zero-based page index; show the 1-based
            # number a reader would see in the PDF.
            if isinstance(page, int):
                page += 1
            # Add the ellipsis only when the preview was actually cut.
            preview = content[:250]
            if len(content) > 250:
                preview += "..."
            context_parts.append(content)
            sources_parts.append(f"📄 Página {page} (score: {score:.2f})\n{preview}")
    if not context_parts:
        return "No tengo información para responder a esta pregunta.", ""
    context = "\n\n".join(context_parts)
    prompt = construir_prompt_rag(context, query)
    answer = llm.invoke(prompt)
    sources = "\n\n---\n\n".join(sources_parts)
    return answer, sources
def rag_con_reranking(query):
    """Answer *query* with RAG, using only the chunk the reranker places first.

    Hits from the vector store with a score below 7 are passed to the
    reranker; the text of the first ranked result becomes the sole context of
    the prompt sent to the LLM.

    Args:
        query: The user's question.

    Returns:
        A ``(answer, sources)`` tuple of strings, where ``sources`` shows the
        context that was selected. When no hit passes the score filter, a
        fixed "no information" message and an empty sources string are
        returned without calling the reranker or the LLM.
    """
    hits = vectordb.similarity_search_with_score(query)
    candidates = [doc.page_content for doc, score in hits if score < 7]
    if not candidates:
        return "No tengo información para responder a esta pregunta.", ""

    ranked = ranker.rank(query=query, docs=candidates)
    top_text = ranked[0].text
    answer = llm.invoke(construir_prompt_rag(top_text, query))
    sources = f"📄 Contexto seleccionado por reranking:\n\n{top_text}"
    return answer, sources
# ──────────────────────────────────────────────
# 6. Lógica del chat con parámetros dinámicos
# ──────────────────────────────────────────────
def actualizar_llm(temperature, top_k, top_p):
    """Rebuild the global pipeline and LLM wrapper with new decoding settings.

    Sampling is enabled only when *temperature* is greater than 0; otherwise
    the sampling parameters are passed as ``None`` and sampling is disabled.
    The already loaded ``model`` and ``tokenizer`` are reused.

    Args:
        temperature: Sampling temperature; 0 disables sampling.
        top_k: Top-k cut-off, used only when sampling.
        top_p: Nucleus (top-p) cut-off, used only when sampling.

    Side effects:
        Rebinds the module-level ``hf_pipeline`` and ``llm``.
    """
    global llm, hf_pipeline
    sampling = temperature > 0
    hf_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=sampling,
        # UI sliders may deliver numbers as floats; transformers requires
        # top_k to be an int, so normalise the types explicitly.
        temperature=float(temperature) if sampling else None,
        top_k=int(top_k) if sampling else None,
        top_p=float(top_p) if sampling else None,
        # Return only the newly generated text. By default the
        # text-generation pipeline returns prompt + completion, which leaks
        # the prompt instructions into the answer shown to the user.
        return_full_text=False,
    )
    llm = HuggingFacePipeline(pipeline=hf_pipeline)
def chat(message, history, mode, temperature, top_k, top_p):
    """Handle one chat turn and return the updated conversation.

    Args:
        message: Text typed by the user.
        history: List of ``(user_message, bot_response)`` tuples; it is
            extended in place with the new turn.
        mode: Answering mode. ``"LLM base (sin RAG)"`` sends the message
            straight to the LLM, ``"RAG sin reranking"`` uses
            ``rag_sin_reranking``, and any other value uses
            ``rag_con_reranking``.
        temperature: Decoding setting forwarded to ``actualizar_llm``.
        top_k: Decoding setting forwarded to ``actualizar_llm``.
        top_p: Decoding setting forwarded to ``actualizar_llm``.

    Returns:
        ``(history, history, "")``: the conversation twice (one per output
        slot) and an empty string for the input box.
    """
    # Blank input: return the conversation unchanged.
    if message.strip() == "":
        return history, history, ""

    # Apply the requested decoding settings before generating.
    actualizar_llm(temperature, top_k, top_p)

    if mode == "LLM base (sin RAG)":
        answer, sources = llm.invoke(message), ""
    elif mode == "RAG sin reranking":
        answer, sources = rag_sin_reranking(message)
    else:
        answer, sources = rag_con_reranking(message)

    # When the question appears in the output, keep only the text that
    # follows its last occurrence (drops an echoed prompt).
    if message in answer:
        answer = answer.split(message)[-1].strip()

    reply = (answer + f"\n\n---\n**📚 Fuentes:**\n{sources}") if sources else answer
    history.append((message, reply))
    return history, history, ""
# ──────────────────────────────────────────────
# 7. Gradio interface
# ──────────────────────────────────────────────
with gr.Blocks(title="RAG - Guía de Diabetes", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🩺 Sistema de QA sobre Diabetes
Basado en la [Guía Informativa de Diabetes](https://escueladepacientes.es/mi-enfermedad/diabetes)
de la **Escuela de Pacientes**. Modelo: `Qwen2.5-1.5B-Instruct`.
""")

    with gr.Row():
        # Wide column: conversation, question box and buttons.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversación", height=500, bubble_full_width=False)
            with gr.Row():
                msg_input = gr.Textbox(
                    placeholder="Escribe tu pregunta aquí...",
                    label="Pregunta",
                    scale=4,
                    autofocus=True,
                )
                send_btn = gr.Button("Enviar", variant="primary", scale=1)
            clear_btn = gr.Button("🗑️ Limpiar conversación", variant="secondary")

        # Narrow column: answering mode, decoding parameters and examples.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Configuración")
            response_modes = ["LLM base (sin RAG)", "RAG sin reranking", "RAG con reranking"]
            mode = gr.Radio(
                choices=response_modes,
                value="RAG con reranking",
                label="Modo de respuesta",
            )

            gr.Markdown("### 🎛️ Parámetros")
            temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.1, label="Temperature")
            top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

            gr.Markdown("### 💡 Ejemplos")
            example_questions = [
                ["¿Qué es la glucosa?"],
                ["¿Qué tratamiento tiene la diabetes tipo 1?"],
                ["¿Cuáles son los síntomas de la hipoglucemia?"],
                ["¿Qué diferencia hay entre diabetes tipo 1 y tipo 2?"],
                ["¿Cuál es la receta de la tarta de queso?"],
            ]
            gr.Examples(examples=example_questions, inputs=msg_input)

    # Conversation history, stored as a list and passed to `chat` on each turn.
    state = gr.State([])

    # The send button and pressing Enter in the textbox trigger the same handler.
    chat_inputs = [msg_input, state, mode, temperature, top_k, top_p]
    chat_outputs = [chatbot, state, msg_input]
    for register_event in (send_btn.click, msg_input.submit):
        register_event(fn=chat, inputs=chat_inputs, outputs=chat_outputs)

    # Reset the chat display, the stored history and the input box.
    clear_btn.click(fn=lambda: ([], [], ""), outputs=[chatbot, state, msg_input])

demo.launch()