"""AI-powered Flu & Respiratory Illness Consultant.

Pipeline: load PDF reference material -> strip bibliography sections ->
chunk -> embed into an in-memory Qdrant collection -> LangGraph
retrieve/generate graph with an optional Tavily web-search tool ->
Chainlit chat front end.
"""

import re
import operator
from typing import Annotated

import chainlit as cl
from dotenv import load_dotenv
from typing_extensions import List, TypedDict

from langchain.prompts import ChatPromptTemplate
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from langgraph.graph import START, StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

load_dotenv()

# ----------------- Document ingestion -----------------
path = "data/"
text_loader = DirectoryLoader(path, glob="*.pdf", loader_cls=PyPDFLoader)

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=600,
    chunk_overlap=200,
    length_function=len,
)


def remove_references(doc):
    """Clean a loaded document in place and return it.

    Drops everything after the first bibliography header, removes
    URLs/DOIs and bracketed numeric citations, and collapses runs of
    blank lines in ``doc.page_content``.
    """
    text = doc.page_content

    # Common headers for reference sections — keep only the content that
    # precedes the first one found.
    reference_markers = ["References", "Bibliography", "Cited Works", "Literature Cited"]
    for marker in reference_markers:
        if marker in text:
            text = text.split(marker)[0]
            break  # Stop checking after the first match.

    # Remove DOIs, links, and numeric citations such as [1], [2], ...
    text = re.sub(r"https?://\S+|doi:\S+", "", text)
    text = re.sub(r"\[\d+\]", "", text)

    # Collapse consecutive blank lines.
    text = re.sub(r"\n{2,}", "\n", text).strip()

    doc.page_content = text.strip()  # Update document content.
    return doc


# Apply reference filtering, then chunk for retrieval.
filtered_documents = [remove_references(doc) for doc in text_loader.load()]
training_documents = text_splitter.split_documents(filtered_documents)

# ----------------- Vector store -----------------
embeddings = HuggingFaceEmbeddings(model_name="Gonalb/flucold-ft-v2")

client = QdrantClient(":memory:")
client.create_collection(
    collection_name="ai_across_years",
    vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
)

vector_store = QdrantVectorStore(
    client=client,
    collection_name="ai_across_years",
    embedding=embeddings,
)
_ = vector_store.add_documents(documents=training_documents)

retriever = vector_store.as_retriever(search_kwargs={"k": 6})


class AgentState(TypedDict):
    # BUG FIX: the reducer must be the add_messages function itself, not
    # the string "add_messages"; with a string annotation LangGraph does
    # not accumulate messages across node updates.
    messages: Annotated[list, add_messages]
    question: str
    context: List[Document]  # Retrieved documents for the RAG step.


# ----------------- RAG Components -----------------
def retrieve(state):
    """Fetch the top-k documents relevant to the user's question."""
    retrieved_docs = retriever.invoke(state["question"])
    return {"context": retrieved_docs}


RAG_PROMPT = """\
You are a helpful AI-powered Flu & Respiratory Illness Consultant. Your job is to help users determine whether they have the flu, a cold, RSV, or allergies based on their symptoms. Provide recommendations based on the context provided. If symptoms are severe, advise the user to seek medical attention. Avoid giving definitive diagnoses or prescriptions—always encourage users to consult a healthcare professional for serious cases.

### Question
{question}

### Context
{context}
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
llm = ChatOpenAI(model="gpt-4o")


def generate(state):
    """Answer the question from the retrieved context via the LLM."""
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = rag_prompt.format_messages(question=state["question"], context=docs_content)
    response = llm.invoke(messages)
    return {"messages": [response]}


# ----------------- Tools & Agent -----------------
tavily_tool = TavilySearchResults(max_results=5)
tool_belt = [tavily_tool]

model = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tool_belt)
tool_node = ToolNode(tool_belt)


def call_model(state):
    """Call the tool-bound base model to generate a response.

    NOTE(review): this function is not registered as a graph node below —
    presumably kept for a planned agent path; confirm before removing.
    """
    messages = state["messages"]
    response = model.invoke(messages)
    return {
        "messages": [response],
        "question": state["question"],
        "context": state.get("context", []),
    }


# ----------------- Create graph -----------------
uncompiled_graph = StateGraph(AgentState)
uncompiled_graph.add_node("retrieve", retrieve)
uncompiled_graph.add_node("generate", generate)
uncompiled_graph.add_node("action", tool_node)
uncompiled_graph.set_entry_point("retrieve")


# ----------------- Logic -----------------
def should_continue(state):
    """Route to the tool node after `generate` if the model asked for tools."""
    last_message = state["messages"][-1]
    # getattr guards against message types without a tool_calls attribute.
    if getattr(last_message, "tool_calls", None):
        return "action"
    return END


uncompiled_graph.add_edge("retrieve", "generate")
uncompiled_graph.add_conditional_edges("generate", should_continue)
uncompiled_graph.add_edge("action", "generate")

compiled_graph = uncompiled_graph.compile()


# ----------------- Chainlit Integration -----------------
@cl.on_chat_start
async def start():
    """Initialize per-session graph handle and empty chat history."""
    cl.user_session.set("graph", compiled_graph)
    cl.user_session.set("messages", [])


@cl.on_message
async def handle(message: cl.Message):
    """Run one user turn through the graph and reply with the final message."""
    graph = cl.user_session.get("graph")
    messages = cl.user_session.get("messages")
    messages.append(HumanMessage(content=message.content))

    state = {
        "messages": messages,
        "question": message.content,
        "context": [],
    }
    response = await graph.ainvoke(state)

    # BUG FIX: ainvoke returns a new state dict and does not mutate the
    # local one — persist the returned history (which includes the AI
    # reply), not the pre-invoke `state["messages"]`.
    cl.user_session.set("messages", response["messages"])
    await cl.Message(content=response["messages"][-1].content).send()