|
|
import re |
|
|
|
|
|
from typing import TypedDict, Annotated, List |
|
|
from typing_extensions import List, TypedDict |
|
|
|
|
|
from dotenv import load_dotenv |
|
|
import chainlit as cl |
|
|
import operator |
|
|
|
|
|
from langchain.prompts import ChatPromptTemplate |
|
|
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain_community.document_loaders import DirectoryLoader |
|
|
from langchain_community.document_loaders import PyPDFLoader |
|
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
|
from langchain_core.documents import Document |
|
|
from langchain_core.messages import BaseMessage, HumanMessage |
|
|
from langchain_core.tools import tool |
|
|
from langchain_openai import ChatOpenAI |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
from langchain_qdrant import QdrantVectorStore |
|
|
from langgraph.graph import START, StateGraph, END |
|
|
from langgraph.graph.message import add_messages |
|
|
from langgraph.prebuilt import ToolNode |
|
|
from qdrant_client import QdrantClient |
|
|
from qdrant_client.http.models import Distance, VectorParams |
|
|
|
|
|
# Load environment variables (OpenAI / Tavily API keys, etc.) from a .env file.
load_dotenv()


# Directory containing the source PDFs that make up the retrieval corpus.
path = "data/"

# Loader for every PDF directly under `path`; each file is parsed by PyPDFLoader.
text_loader = DirectoryLoader(path, glob="*.pdf", loader_cls=PyPDFLoader)
|
|
|
|
|
|
|
|
# Chunk documents into overlapping character windows so each piece is small
# enough to embed while the overlap preserves context across chunk borders.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=600,       # maximum characters per chunk
    chunk_overlap=200,    # characters shared between consecutive chunks
    length_function=len,  # measure chunk size in raw characters
)
|
|
|
|
|
|
|
|
def remove_references(doc):
    """Strip bibliographic noise from a loaded document, in place.

    Truncates the text at the first reference-section heading found,
    deletes URLs/DOIs and bracketed numeric citations, collapses runs
    of blank lines, and writes the cleaned text back onto
    ``doc.page_content``.

    Returns the same document object, so this composes cleanly inside
    a list comprehension over loaded documents.
    """
    text = doc.page_content

    # Cut everything from the reference section onward. Headings are
    # checked in a fixed priority order; the first one present wins.
    for heading in ("References", "Bibliography", "Cited Works", "Literature Cited"):
        if heading in text:
            text = text.split(heading)[0]
            break

    # Drop web links / DOI identifiers, then bracketed citations like [12].
    text = re.sub(r"https?://\S+|doi:\S+", "", text)
    text = re.sub(r"\[\d+\]", "", text)

    # Collapse multiple consecutive newlines and trim surrounding whitespace.
    text = re.sub(r"\n{2,}", "\n", text).strip()
    doc.page_content = text.strip()
    return doc
|
|
|
|
|
|
|
|
# Clean every loaded page (strip reference sections, URLs, citations) ...
filtered_documents = [remove_references(doc) for doc in text_loader.load()]

# ... then split the cleaned pages into overlapping chunks for embedding.
training_documents = text_splitter.split_documents(filtered_documents)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fine-tuned sentence-embedding model used to vectorize the corpus.
embeddings = HuggingFaceEmbeddings(model_name="Gonalb/flucold-ft-v2")

# In-memory Qdrant instance: the index exists only for this process's lifetime.
client = QdrantClient(":memory:")

client.create_collection(
    collection_name="ai_across_years",
    # NOTE(review): `size` must equal the embedding model's output dimension
    # (assumed 1024 here) — confirm against the model card for flucold-ft-v2.
    vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
)

vector_store = QdrantVectorStore(
    client=client,
    collection_name="ai_across_years",
    embedding=embeddings,
)

# Embed and index all chunks; the returned document ids are not needed.
_ = vector_store.add_documents(documents=training_documents)

# Retriever that returns the 6 most similar chunks for a query.
retriever = vector_store.as_retriever(search_kwargs={"k": 6})
|
|
|
|
|
class AgentState(TypedDict):
    """Shared state passed between LangGraph nodes.

    messages: conversation history; the ``add_messages`` reducer appends
        messages returned by nodes instead of overwriting the list.
    question: the user's current question, consumed by the retrieve node.
    context: documents retrieved for the current question.
    """

    # BUG FIX: the reducer was previously given as the string "add_messages",
    # which LangGraph ignores — every node return then *replaced* the whole
    # history instead of appending. Pass the imported function itself.
    messages: Annotated[list, add_messages]
    question: str
    context: List[Document]
|
|
|
|
|
|
|
|
def retrieve(state):
    """Fetch the corpus chunks most relevant to the current question."""
    question = state["question"]
    return {"context": retriever.invoke(question)}
|
|
|
|
|
# System prompt for the grounded-answer node. The {question} and {context}
# placeholders are filled in by generate() before each LLM call.
RAG_PROMPT = """\
You are a helpful AI-powered Flu & Respiratory Illness Consultant. Your job is to help users determine whether they have the flu, a cold, RSV, or allergies based on their symptoms.
Provide recommendations based on the context provided. If symptoms are severe, advise the user to seek medical attention.
Avoid giving definitive diagnoses or prescriptions—always encourage users to consult a healthcare professional for serious cases.

### Question
{question}

### Context
{context}
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

# LLM used for the answer-generation step.
llm = ChatOpenAI(model="gpt-4o")
|
|
|
|
|
def generate(state):
    """Answer the user's question with the LLM, grounded in retrieved context."""
    # Join all retrieved chunks into one context string for the prompt.
    context_text = "\n\n".join(chunk.page_content for chunk in state["context"])
    prompt_messages = rag_prompt.format_messages(
        question=state["question"],
        context=context_text,
    )
    reply = llm.invoke(prompt_messages)
    return {"messages": [reply]}
|
|
|
|
|
# Web-search tool exposed to the tool-calling model (top 5 results).
tavily_tool = TavilySearchResults(max_results=5)

tool_belt = [tavily_tool]

# Deterministic (temperature=0) model with the tool belt bound for function calling.
model = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tool_belt)

# Graph node that executes whatever tool calls the model emits.
tool_node = ToolNode(tool_belt)
|
|
|
|
|
def call_model(state):
    """Invoke the tool-bound model on the conversation and record its reply."""
    reply = model.invoke(state["messages"])
    # Pass question/context through unchanged so downstream nodes keep them.
    return {
        "messages": [reply],
        "question": state["question"],
        "context": state.get("context", []),
    }
|
|
|
|
|
|
|
|
# Assemble the agent graph: retrieve -> generate, with a tool-execution
# node ("action") reachable from generate via a conditional edge.
uncompiled_graph = StateGraph(AgentState)

uncompiled_graph.add_node("retrieve", retrieve)
uncompiled_graph.add_node("generate", generate)
uncompiled_graph.add_node("action", tool_node)
# NOTE(review): call_model is defined above but never registered as a node —
# confirm whether the tool-bound model was meant to back "generate".

# Every run starts by retrieving context for the user's question.
uncompiled_graph.set_entry_point("retrieve")
|
|
|
|
|
|
|
|
def should_continue(state):
    """Route after `generate`: run tools if the last reply requested any, else stop."""
    latest = state["messages"][-1]
    return "action" if latest.tool_calls else END
|
|
|
|
|
|
|
|
# Wiring: after retrieval, generate an answer; generate may hand off to the
# tool node, whose results loop back into generate.
uncompiled_graph.add_edge("retrieve", "generate")
uncompiled_graph.add_conditional_edges("generate", should_continue)
uncompiled_graph.add_edge("action", "generate")

# Compile once at import time; every chat session reuses this graph.
compiled_graph = uncompiled_graph.compile()
|
|
|
|
|
|
|
|
@cl.on_chat_start
async def start():
    """Initialize a chat session: store the compiled graph and an empty history."""
    cl.user_session.set("graph", compiled_graph)
    cl.user_session.set("messages", [])
|
|
@cl.on_message
async def handle(message: cl.Message):
    """Handle one user turn: run the agent graph and reply with its final message.

    BUG FIX: the session history was previously saved as ``state["messages"]``,
    which is the *input* list and therefore never contains the assistant's
    reply — the model's answers were dropped between turns. The final reply
    from ``response["messages"]`` is now appended to the history before it is
    stored back on the session.
    """
    graph = cl.user_session.get("graph")

    messages = cl.user_session.get("messages")
    messages.append(HumanMessage(content=message.content))

    state = {
        "messages": messages,
        "question": message.content,
        "context": []
    }

    response = await graph.ainvoke(state)
    final_message = response["messages"][-1]

    # Persist the full turn (user message + assistant reply) for the next run.
    messages.append(final_message)
    cl.user_session.set("messages", messages)

    await cl.Message(content=final_message.content).send()