Flu_agent / app.py
Gonalb's picture
erase comments
42e6357
import re
from typing import TypedDict, Annotated, List
from typing_extensions import List, TypedDict
from dotenv import load_dotenv
import chainlit as cl
import operator
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from langgraph.graph import START, StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
# Pull environment variables from a local .env file (presumably the API keys
# needed by the OpenAI / Tavily clients created later — confirm).
load_dotenv()

path = "data/"
# Loads every file matching *.pdf under `path` with PyPDFLoader.
text_loader = DirectoryLoader(path, glob="*.pdf", loader_cls=PyPDFLoader)

# Chunks of at most 600 characters (length measured with len()), with 200
# characters shared between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=200, length_function=len)
# Headings that usually introduce a reference section; the page is cut at the
# earliest one found.
_REFERENCE_MARKERS = ("References", "Bibliography", "Cited Works", "Literature Cited")
# URLs plus DOIs written as "doi:10.x/y" or "DOI: 10.x/y" (case-insensitive,
# optional whitespace after the colon).
_LINK_RE = re.compile(r"https?://\S+|doi:\s*\S+", re.IGNORECASE)
# Numeric bracket citations: [1], [12], [1, 2], [3-5], [3\u20135].
_CITATION_RE = re.compile(r"\[\d+(?:\s*[,\-\u2013]\s*\d+)*\]")
# Two or more consecutive newlines (i.e. blank lines).
_BLANK_LINES_RE = re.compile(r"\n{2,}")


def remove_references(doc):
    """Strip the reference section, links and numeric citations from *doc*.

    The document is modified in place and also returned, so the function can
    be used directly inside a comprehension.

    Args:
        doc: Object with a mutable ``page_content`` string attribute (here the
            documents produced by ``text_loader.load()``).

    Returns:
        The same ``doc`` whose ``page_content`` has been cleaned.
    """
    text = doc.page_content
    # Cut at the EARLIEST marker in the text rather than at whichever marker
    # comes first in the tuple; otherwise "Bibliography ... References" would
    # keep the whole bibliography.
    # NOTE(review): markers are plain case-sensitive substrings, so a sentence
    # that merely mentions "References" also truncates the page.
    hits = [pos for pos in (text.find(m) for m in _REFERENCE_MARKERS) if pos != -1]
    if hits:
        text = text[: min(hits)]
    text = _LINK_RE.sub("", text)
    text = _CITATION_RE.sub("", text)
    # Collapse blank lines to single newlines and trim the ends.
    doc.page_content = _BLANK_LINES_RE.sub("\n", text).strip()
    return doc
# Clean every loaded page, then split the cleaned pages into chunks.
filtered_documents = [remove_references(doc) for doc in text_loader.load()]
training_documents = text_splitter.split_documents(filtered_documents)

embeddings = HuggingFaceEmbeddings(model_name="Gonalb/flucold-ft-v2")
# Size the collection from the model's actual output instead of hard-coding
# 1024: if the two ever disagree, the vector store cannot use the collection.
# One short probe embedding is the only extra cost.
embedding_dim = len(embeddings.embed_query("dimension probe"))

# In-process Qdrant instance; the index is rebuilt on every start.
COLLECTION_NAME = "ai_across_years"
client = QdrantClient(":memory:")
client.create_collection(
    collection_name=COLLECTION_NAME,
    vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
)
vector_store = QdrantVectorStore(
    client=client,
    collection_name=COLLECTION_NAME,
    embedding=embeddings,
)
_ = vector_store.add_documents(documents=training_documents)
# Each query returns the 6 most similar chunks.
retriever = vector_store.as_retriever(search_kwargs={"k": 6})
class AgentState(TypedDict):
    """State shared by the nodes of the ``StateGraph`` built from this schema."""

    # Conversation messages. The reducer must be the ``add_messages`` function
    # itself, not the string "add_messages": LangGraph only treats callable
    # metadata as a reducer. A string leaves the channel as a plain last-value
    # field, so every node returning ``{"messages": [...]}`` would overwrite
    # the history instead of appending to it.
    messages: Annotated[list, add_messages]
    # The user's current question; used as the retrieval query and in the RAG prompt.
    question: str
    # Documents returned by the retriever for the RAG step.
    context: List[Document]
# ----------------- RAG Components -----------------
def retrieve(state):
    """Look up documents relevant to the current question.

    Returns:
        A partial state update ``{"context": docs}`` with the retriever's results.
    """
    query = state["question"]
    return {"context": retriever.invoke(query)}
# Instructions for the RAG answer step. `{question}` and `{context}` are
# template variables; `generate` fills them with the user's question and the
# joined text of the retrieved documents.
RAG_PROMPT = """\
You are a helpful AI-powered Flu & Respiratory Illness Consultant. Your job is to help users determine whether they have the flu, a cold, RSV, or allergies based on their symptoms.
Provide recommendations based on the context provided. If symptoms are severe, advise the user to seek medical attention.
Avoid giving definitive diagnoses or prescriptions—always encourage users to consult a healthcare professional for serious cases.
### Question
{question}
### Context
{context}
"""
# Chat prompt template built from RAG_PROMPT.
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
# Chat model used by `generate`. Note it is NOT bound to any tools, unlike
# the tool-enabled `model` defined further down.
llm = ChatOpenAI(model="gpt-4o")
def generate(state):
    """Answer the question using the RAG prompt filled with the retrieved context.

    Only ``question`` and ``context`` are read from the state; the message
    history is not part of the prompt.

    Returns:
        A partial state update ``{"messages": [reply]}`` with the model's reply.
    """
    context_text = "\n\n".join([d.page_content for d in state["context"]])
    prompt_messages = rag_prompt.format_messages(
        question=state["question"],
        context=context_text,
    )
    return {"messages": [llm.invoke(prompt_messages)]}
# ----------------- Tools & Agent -----------------
# Tavily web search, capped at 5 results per call.
tavily_tool = TavilySearchResults(max_results=5)
tool_belt = [tavily_tool]
# GPT-4o at temperature 0 with the tools in `tool_belt` bound to it.
model = ChatOpenAI(temperature=0, model="gpt-4o").bind_tools(tool_belt)
# Graph node that runs the tools in `tool_belt`.
tool_node = ToolNode(tool_belt)
def call_model(state):
    """Invoke the tool-enabled chat model on the conversation messages.

    NOTE(review): this function is not registered as a node in the graph
    built in this file (only "retrieve", "generate" and "action" are added);
    confirm whether it is still needed.

    Returns:
        A state update containing the model's reply as the only message, with
        ``question`` passed through and ``context`` defaulting to ``[]``.
    """
    reply = model.invoke(state["messages"])
    return dict(
        messages=[reply],
        question=state["question"],
        context=state.get("context", []),
    )
# ----------------- Create graph -----------------
uncompiled_graph = StateGraph(AgentState)
# Nodes: document retrieval, answer generation, and tool execution.
uncompiled_graph.add_node("retrieve", retrieve)
uncompiled_graph.add_node("generate", generate)
uncompiled_graph.add_node("action", tool_node)
# Every run starts with retrieval (same effect as set_entry_point("retrieve")).
uncompiled_graph.add_edge(START, "retrieve")
# ----------------- Logic -----------------
def should_continue(state):
    """Route after ``generate``: run tools if the latest message requests them.

    NOTE(review): ``generate`` calls the unbound ``llm``, so its replies are not
    expected to carry tool calls; confirm whether the "action" branch is reachable.

    Returns:
        ``"action"`` when the last message has tool calls, otherwise ``END``.
    """
    latest = state["messages"][-1]
    return "action" if latest.tool_calls else END
# Wiring: retrieve -> generate; tool output also flows back into generate;
# after generate, should_continue picks either the tool node or END.
uncompiled_graph.add_edge("retrieve", "generate")
uncompiled_graph.add_edge("action", "generate")
uncompiled_graph.add_conditional_edges("generate", should_continue)
compiled_graph = uncompiled_graph.compile()
# ----------------- Chainlit Integration -----------------
@cl.on_chat_start
async def start():
    """Initialise the session: store the compiled graph and an empty message history."""
    session = cl.user_session
    session.set("graph", compiled_graph)
    session.set("messages", list())
@cl.on_message
async def handle(message: cl.Message):
    """Answer one chat message by running the compiled graph.

    The user's text is appended to the session history and passed to the
    graph both as the message list and as the ``question``. The content of
    the last message in the graph's result is sent back to the user, and that
    reply is stored in the session history as well.

    Args:
        message: The incoming Chainlit message.
    """
    graph = cl.user_session.get("graph")
    history = cl.user_session.get("messages")
    history.append(HumanMessage(content=message.content))
    state = {
        "messages": history,
        "question": message.content,
        "context": [],
    }
    result = await graph.ainvoke(state)
    reply = result["messages"][-1]
    # Record the assistant's reply too. Saving ``state["messages"]`` (the
    # input list) back to the session only ever stored the user's turns.
    history.append(reply)
    cl.user_session.set("messages", history)
    await cl.Message(content=reply.content).send()