graph
Browse files
app.py
CHANGED
|
@@ -87,98 +87,91 @@ _ = vector_store.add_documents(documents=training_documents)
|
|
| 87 |
|
| 88 |
retriever = vector_store.as_retriever(search_kwargs={"k": 6})
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
def retrieve(state):
|
| 91 |
-
|
| 92 |
-
|
| 93 |
|
| 94 |
RAG_PROMPT = """\
|
| 95 |
You are a helpful AI-powered Flu & Respiratory Illness Consultant. Your job is to help users determine whether they have the flu, a cold, RSV, or allergies based on their symptoms.
|
| 96 |
Provide recommendations based on the context provided. If symptoms are severe, advise the user to seek medical attention.
|
| 97 |
Avoid giving definitive diagnoses or prescriptions—always encourage users to consult a healthcare professional for serious cases.
|
|
|
|
| 98 |
### Question
|
| 99 |
{question}
|
|
|
|
| 100 |
### Context
|
| 101 |
{context}
|
| 102 |
"""
|
| 103 |
|
| 104 |
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
|
| 105 |
-
|
| 106 |
llm = ChatOpenAI(model="gpt-4o")
|
| 107 |
|
| 108 |
def generate(state):
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
from langgraph.graph import START, StateGraph
|
| 115 |
-
from typing_extensions import List, TypedDict
|
| 116 |
-
from langchain_core.documents import Document
|
| 117 |
-
|
| 118 |
-
class State(TypedDict):
|
| 119 |
-
question: str
|
| 120 |
-
context: List[Document]
|
| 121 |
-
response: str
|
| 122 |
-
|
| 123 |
-
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
|
| 124 |
-
graph_builder.add_edge(START, "retrieve")
|
| 125 |
-
graph = graph_builder.compile()
|
| 126 |
-
|
| 127 |
tavily_tool = TavilySearchResults(max_results=5)
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
tavily_tool
|
| 131 |
-
]
|
| 132 |
-
|
| 133 |
-
model = ChatOpenAI(model="gpt-4o", temperature=0)
|
| 134 |
-
model = model.bind_tools(tool_belt)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
class AgentState(TypedDict):
|
| 138 |
-
messages: Annotated[list, add_messages]
|
| 139 |
-
context: List[Document]
|
| 140 |
-
|
| 141 |
tool_node = ToolNode(tool_belt)
|
| 142 |
|
| 143 |
-
uncompiled_graph = StateGraph(AgentState)
|
| 144 |
-
|
| 145 |
def call_model(state):
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
| 149 |
"messages": [response],
|
|
|
|
| 150 |
"context": state.get("context", [])
|
| 151 |
}
|
| 152 |
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
uncompiled_graph.add_node("action", tool_node)
|
| 155 |
|
| 156 |
-
uncompiled_graph.set_entry_point("
|
| 157 |
|
|
|
|
| 158 |
def should_continue(state):
|
| 159 |
-
|
|
|
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
| 163 |
|
| 164 |
-
|
| 165 |
|
| 166 |
-
uncompiled_graph.add_conditional_edges(
|
| 167 |
-
"agent",
|
| 168 |
-
should_continue
|
| 169 |
-
)
|
| 170 |
|
| 171 |
-
uncompiled_graph.add_edge("
|
|
|
|
|
|
|
| 172 |
|
| 173 |
compiled_graph = uncompiled_graph.compile()
|
| 174 |
|
|
|
|
| 175 |
@cl.on_chat_start
|
| 176 |
async def start():
|
| 177 |
-
|
| 178 |
|
| 179 |
@cl.on_message
|
| 180 |
async def handle(message: cl.Message):
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
retriever = vector_store.as_retriever(search_kwargs={"k": 6})
|
| 89 |
|
| 90 |
+
class AgentState(TypedDict):
|
| 91 |
+
messages: Annotated[list, "add_messages"]
|
| 92 |
+
question: str
|
| 93 |
+
context: List[Document] # Para el RAG
|
| 94 |
+
|
| 95 |
+
# ----------------- RAG Components -----------------
|
| 96 |
def retrieve(state):
|
| 97 |
+
retrieved_docs = retriever.invoke(state["question"])
|
| 98 |
+
return {"context": retrieved_docs}
|
| 99 |
|
| 100 |
RAG_PROMPT = """\
|
| 101 |
You are a helpful AI-powered Flu & Respiratory Illness Consultant. Your job is to help users determine whether they have the flu, a cold, RSV, or allergies based on their symptoms.
|
| 102 |
Provide recommendations based on the context provided. If symptoms are severe, advise the user to seek medical attention.
|
| 103 |
Avoid giving definitive diagnoses or prescriptions—always encourage users to consult a healthcare professional for serious cases.
|
| 104 |
+
|
| 105 |
### Question
|
| 106 |
{question}
|
| 107 |
+
|
| 108 |
### Context
|
| 109 |
{context}
|
| 110 |
"""
|
| 111 |
|
| 112 |
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
|
|
|
|
| 113 |
llm = ChatOpenAI(model="gpt-4o")
|
| 114 |
|
| 115 |
def generate(state):
|
| 116 |
+
docs_content = "\n\n".join(doc.page_content for doc in state["context"])
|
| 117 |
+
messages = rag_prompt.format_messages(question=state["question"], context=docs_content)
|
| 118 |
+
response = llm.invoke(messages)
|
| 119 |
+
return {"messages": [response]}
|
| 120 |
+
# ----------------- Tools & Agent -----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
tavily_tool = TavilySearchResults(max_results=5)
|
| 122 |
+
tool_belt = [tavily_tool]
|
| 123 |
+
model = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tool_belt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
tool_node = ToolNode(tool_belt)
|
| 125 |
|
|
|
|
|
|
|
| 126 |
def call_model(state):
|
| 127 |
+
"""Llama al modelo base para generar respuestas."""
|
| 128 |
+
messages = state["messages"]
|
| 129 |
+
response = model.invoke(messages)
|
| 130 |
+
return {
|
| 131 |
"messages": [response],
|
| 132 |
+
"question": state["question"],
|
| 133 |
"context": state.get("context", [])
|
| 134 |
}
|
| 135 |
|
| 136 |
+
# ----------------- Create graph -----------------
|
| 137 |
+
uncompiled_graph = StateGraph(AgentState)
|
| 138 |
+
|
| 139 |
+
uncompiled_graph.add_node("retrieve", retrieve)
|
| 140 |
+
uncompiled_graph.add_node("generate", generate)
|
| 141 |
uncompiled_graph.add_node("action", tool_node)
|
| 142 |
|
| 143 |
+
uncompiled_graph.set_entry_point("retrieve")
|
| 144 |
|
| 145 |
+
# ----------------- Logic -----------------
|
| 146 |
def should_continue(state):
|
| 147 |
+
"""Decide si usar herramientas después de `generate`."""
|
| 148 |
+
last_message = state["messages"][-1]
|
| 149 |
|
| 150 |
+
if last_message.tool_calls:
|
| 151 |
+
return "action"
|
| 152 |
|
| 153 |
+
return END
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
+
uncompiled_graph.add_edge("retrieve", "generate")
|
| 157 |
+
uncompiled_graph.add_conditional_edges("generate", should_continue)
|
| 158 |
+
uncompiled_graph.add_edge("action", "generate")
|
| 159 |
|
| 160 |
compiled_graph = uncompiled_graph.compile()
|
| 161 |
|
| 162 |
+
# ----------------- Chainlit Integration -----------------
|
| 163 |
@cl.on_chat_start
|
| 164 |
async def start():
|
| 165 |
+
cl.user_session.set("graph", compiled_graph)
|
| 166 |
|
| 167 |
@cl.on_message
|
| 168 |
async def handle(message: cl.Message):
|
| 169 |
+
graph = cl.user_session.get("graph")
|
| 170 |
+
state = {
|
| 171 |
+
"messages": [HumanMessage(content=message.content)],
|
| 172 |
+
"question": message.content,
|
| 173 |
+
"context": []
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
response = await graph.ainvoke(state)
|
| 177 |
+
await cl.Message(content=response["messages"][-1].content).send()
|