Gonalb committed on
Commit
8327038
·
1 Parent(s): 3296d11
Files changed (1) hide show
  1. app.py +48 -55
app.py CHANGED
@@ -87,98 +87,91 @@ _ = vector_store.add_documents(documents=training_documents)
87
 
88
  retriever = vector_store.as_retriever(search_kwargs={"k": 6})
89
 
 
 
 
 
 
 
90
  def retrieve(state):
91
- retrieved_docs = retriever.invoke(state["question"])
92
- return {"context" : retrieved_docs}
93
 
94
  RAG_PROMPT = """\
95
  You are a helpful AI-powered Flu & Respiratory Illness Consultant. Your job is to help users determine whether they have the flu, a cold, RSV, or allergies based on their symptoms.
96
  Provide recommendations based on the context provided. If symptoms are severe, advise the user to seek medical attention.
97
  Avoid giving definitive diagnoses or prescriptions—always encourage users to consult a healthcare professional for serious cases.
 
98
  ### Question
99
  {question}
 
100
  ### Context
101
  {context}
102
  """
103
 
104
  rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
105
-
106
  llm = ChatOpenAI(model="gpt-4o")
107
 
108
  def generate(state):
109
- docs_content = "\n\n".join(doc.page_content for doc in state["context"])
110
- messages = rag_prompt.format_messages(question=state["question"], context=docs_content)
111
- response = llm.invoke(messages)
112
- return {"response" : response.content}
113
-
114
- from langgraph.graph import START, StateGraph
115
- from typing_extensions import List, TypedDict
116
- from langchain_core.documents import Document
117
-
118
- class State(TypedDict):
119
- question: str
120
- context: List[Document]
121
- response: str
122
-
123
- graph_builder = StateGraph(State).add_sequence([retrieve, generate])
124
- graph_builder.add_edge(START, "retrieve")
125
- graph = graph_builder.compile()
126
-
127
  tavily_tool = TavilySearchResults(max_results=5)
128
-
129
- tool_belt = [
130
- tavily_tool
131
- ]
132
-
133
- model = ChatOpenAI(model="gpt-4o", temperature=0)
134
- model = model.bind_tools(tool_belt)
135
-
136
-
137
- class AgentState(TypedDict):
138
- messages: Annotated[list, add_messages]
139
- context: List[Document]
140
-
141
  tool_node = ToolNode(tool_belt)
142
 
143
- uncompiled_graph = StateGraph(AgentState)
144
-
145
  def call_model(state):
146
- messages = state["messages"]
147
- response = model.invoke(messages)
148
- return {
 
149
  "messages": [response],
 
150
  "context": state.get("context", [])
151
  }
152
 
153
- uncompiled_graph.add_node("agent", call_model)
 
 
 
 
154
  uncompiled_graph.add_node("action", tool_node)
155
 
156
- uncompiled_graph.set_entry_point("agent")
157
 
 
158
  def should_continue(state):
159
- last_message = state["messages"][-1]
 
160
 
161
- if last_message.tool_calls:
162
- return "action"
163
 
164
- return END
165
 
166
- uncompiled_graph.add_conditional_edges(
167
- "agent",
168
- should_continue
169
- )
170
 
171
- uncompiled_graph.add_edge("action", "agent")
 
 
172
 
173
  compiled_graph = uncompiled_graph.compile()
174
 
 
175
  @cl.on_chat_start
176
  async def start():
177
- cl.user_session.set("graph", compiled_graph)
178
 
179
  @cl.on_message
180
  async def handle(message: cl.Message):
181
- graph = cl.user_session.get("graph")
182
- state = {"messages" : [HumanMessage(content=message.content)]}
183
- response = await graph.ainvoke(state)
184
- await cl.Message(content=response["messages"][-1].content).send()
 
 
 
 
 
 
87
 
88
# Top-6 similarity retriever over the vector store; backs the `retrieve` node.
retriever = vector_store.as_retriever(search_kwargs={"k": 6})
89
 
90
class AgentState(TypedDict):
    """Shared LangGraph state flowing through the retrieve/generate/action nodes."""
    # FIX: the reducer must be the `add_messages` function itself, not the
    # string "add_messages".  String metadata in Annotated is ignored by
    # LangGraph, so every node update would *replace* the message list
    # instead of appending to it.  The previous revision of this file used
    # the bare name, so `add_messages` should already be imported at the top
    # of the file (from langgraph.graph.message) — TODO confirm.
    messages: Annotated[list, add_messages]
    question: str             # raw user question driving retrieval
    context: List[Document]   # documents fetched for the RAG step
94
+
95
# ----------------- RAG Components -----------------
def retrieve(state):
    """Graph node: fetch documents relevant to the user's question.

    Reads ``state["question"]`` and returns a partial state update with
    the retrieved documents under the ``"context"`` key.
    """
    question = state["question"]
    docs = retriever.invoke(question)
    return {"context": docs}
99
 
100
# Prompt for the RAG answer step; `{question}` and `{context}` are filled in
# by `generate` via `rag_prompt.format_messages`.
RAG_PROMPT = """\
You are a helpful AI-powered Flu & Respiratory Illness Consultant. Your job is to help users determine whether they have the flu, a cold, RSV, or allergies based on their symptoms.
Provide recommendations based on the context provided. If symptoms are severe, advise the user to seek medical attention.
Avoid giving definitive diagnoses or prescriptions—always encourage users to consult a healthcare professional for serious cases.

### Question
{question}

### Context
{context}
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
# Model for the RAG answer step.  Note: no tools are bound to this instance.
llm = ChatOpenAI(model="gpt-4o")
114
 
115
def generate(state):
    """Graph node: answer the question from the retrieved context.

    Concatenates the page content of every document in ``state["context"]``,
    formats the RAG prompt with it, and returns the LLM reply as a
    one-element ``"messages"`` state update.
    """
    context_text = "\n\n".join(d.page_content for d in state["context"])
    prompt_messages = rag_prompt.format_messages(
        question=state["question"],
        context=context_text,
    )
    answer = llm.invoke(prompt_messages)
    return {"messages": [answer]}
120
# ----------------- Tools & Agent -----------------
# Web-search tool made available to the tool-calling model below.
tavily_tool = TavilySearchResults(max_results=5)
tool_belt = [tavily_tool]
# Tool-calling model.  NOTE(review): `model` is only referenced by
# `call_model`, and `call_model` is never added as a graph node below —
# confirm whether an agent step was meant to be wired in, or whether this
# is dead code left over from the previous revision.
model = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tool_belt)
# Node that executes any tool calls found on the last message.
tool_node = ToolNode(tool_belt)
125
 
 
 
126
def call_model(state):
    """Invoke the tool-calling model on the accumulated message history.

    Returns the model reply as a ``"messages"`` update, passing
    ``question`` and ``context`` through so they survive the state merge.

    NOTE(review): this function is never registered as a graph node below —
    confirm whether it should be part of the compiled graph.
    """
    messages = state["messages"]
    response = model.invoke(messages)
    return {
        "messages": [response],
        # Use .get so a state missing these keys doesn't raise KeyError,
        # consistent with how `context` was already handled.
        "question": state.get("question", ""),
        "context": state.get("context", []),
    }
135
 
136
# ----------------- Create graph -----------------
uncompiled_graph = StateGraph(AgentState)

# RAG pipeline nodes plus the tool-execution node.
uncompiled_graph.add_node("retrieve", retrieve)
uncompiled_graph.add_node("generate", generate)
uncompiled_graph.add_node("action", tool_node)

# Every run starts by retrieving context for the question.
uncompiled_graph.set_entry_point("retrieve")
144
 
145
# ----------------- Logic -----------------
def should_continue(state):
    """Route after `generate`: run tools if the last message requested any."""
    latest = state["messages"][-1]
    return "action" if latest.tool_calls else END
154
 
 
 
 
 
155
 
156
# Wire the flow: retrieve -> generate, then either the tool node or END.
uncompiled_graph.add_edge("retrieve", "generate")
uncompiled_graph.add_conditional_edges("generate", should_continue)
# After tools run, hand the results back to `generate`.
# NOTE(review): `generate` calls `llm`, which has no tools bound, so its
# replies should never carry tool_calls — confirm the "action" node is
# actually reachable, or whether the tool-bound `call_model` was meant to
# sit in this loop instead.
uncompiled_graph.add_edge("action", "generate")

compiled_graph = uncompiled_graph.compile()
161
 
162
# ----------------- Chainlit Integration -----------------
@cl.on_chat_start
async def start():
    """Stash the compiled graph in the user session when a chat opens."""
    cl.user_session.set("graph", compiled_graph)
166
 
167
@cl.on_message
async def handle(message: cl.Message):
    """Run one user turn through the graph and send back the final answer."""
    graph = cl.user_session.get("graph")

    # Seed the graph state: the incoming text doubles as both the first
    # chat message and the retrieval question; context starts empty.
    initial_state = {
        "messages": [HumanMessage(content=message.content)],
        "question": message.content,
        "context": [],
    }

    result = await graph.ainvoke(initial_state)
    final_text = result["messages"][-1].content
    await cl.Message(content=final_text).send()