Adamsyara committed on
Commit
1f07922
·
verified ·
1 Parent(s): 9e56189

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +428 -0
app.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Import necessary libraries
3
+ import os
4
+ import chromadb
5
+ from dotenv import load_dotenv
6
+ import json
7
+ import numpy as np
8
+ from groq import Groq
9
+ from mem0 import MemoryClient
10
+ import streamlit as st
11
+ from datetime import datetime
12
+ from typing import Dict, List, Tuple, Any, TypedDict
13
+
14
+ # LangChain imports
15
+ from langchain_core.documents import Document
16
+ from langchain_core.runnables import RunnablePassthrough
17
+ from langchain_core.output_parsers import StrOutputParser
18
+ from langchain.prompts import ChatPromptTemplate
19
+ from langchain.chains.query_constructor.base import AttributeInfo
20
+ from langchain.retrievers.self_query.base import SelfQueryRetriever
21
+ from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker
22
+ from langchain.retrievers import ContextualCompressionRetriever
23
+ from langchain_community.vectorstores import Chroma
24
+ from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader
25
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
26
+ from langchain_experimental.text_splitter import SemanticChunker
27
+ from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
28
+ from langchain_core.tools import tool
29
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
30
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
31
+ from llama_index.core import Settings
32
+ from langgraph.graph import StateGraph, END, START
33
+ from pydantic import BaseModel
34
+
35
+ #====================================SETUP=====================================#
36
# Credentials and endpoints come from the process environment (configured as
# Hugging Face Spaces secrets). Any variable that is not set resolves to None.
api_key = os.getenv("API_KEY")
endpoint = os.getenv("OPENAI_API_BASE")
llama_api_key = os.getenv('GROQ_API_KEY')
MEM0_api_key = os.getenv('mem0')

# Embedding model handed to the Chroma vector store further down.
embedding_model = OpenAIEmbeddings(
    model='text-embedding-ada-002',
    openai_api_key=api_key,
    openai_api_base=endpoint,
)

# Chat model shared by every prompt chain in the LangGraph workflow.
llm = ChatOpenAI(
    model="gpt-4o-mini",
    openai_api_key=api_key,
    base_url=endpoint,
    streaming=False,
)

# Register both models on the LlamaIndex settings object.
# NOTE(review): LlamaIndex documents the embedding slot as `Settings.embed_model`;
# `Settings.embedding` may be a plain attribute with no effect - confirm intent.
Settings.llm = llm
Settings.embedding = embedding_model
60
+
61
+ #================================Creating Langgraph agent======================#
62
+
63
class AgentState(TypedDict):
    """Shared state dictionary handed from node to node in the RAG workflow."""

    # --- query and retrieval ---
    query: str                      # original user question
    expanded_query: str             # query rewritten by expand_query
    context: List[Dict[str, Any]]   # retrieved chunks: {"content", "metadata"}

    # --- generated answer and its evaluation ---
    response: str
    precision_score: float
    groundedness_score: float

    # --- loop bookkeeping ---
    groundedness_loop_count: int
    precision_loop_count: int

    # --- reviewer feedback fed back into the prompts ---
    feedback: str                   # written by refine_response
    query_feedback: str             # written by refine_query

    groundedness_check: bool
    loop_max_iter: int              # iteration budget compared against both counters
76
+
77
def expand_query(state):
    """Expand the user's query with related terms to improve retrieval.

    Reads ``state['query']`` and, when present and non-empty,
    ``state['query_feedback']``. The LLM output is stored under
    ``state['expanded_query']`` and the same state mapping is returned.
    """
    print("---------Expanding Query---------")
    system_message = '''You are a query expansion expert for nutrition and health topics.
    Expand the given query to improve information retrieval by adding relevant terms, synonyms, and related concepts.
    Focus on nutrition disorders, dietary conditions, and health topics. Return only the expanded query.'''
    expand_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}")
    ])
    chain = expand_prompt | llm | StrOutputParser()

    # The workflow seeds `query_feedback` with an empty string, so a plain
    # `.get(key, default)` would never apply the default; `or` covers both the
    # missing-key and the empty-string case.
    query_feedback = state.get("query_feedback") or "Improve retrieval effectiveness"

    expanded_query = chain.invoke({"query": state['query'], "query_feedback": query_feedback})
    print("expanded_query", expanded_query)
    state["expanded_query"] = expanded_query
    return state
91
+
92
# Vector store backed by the on-disk Chroma collection, plus a similarity
# retriever that returns the three closest chunks per query.
vector_store = Chroma(
    embedding_function=embedding_model,
    persist_directory="./nutritional_db",
    collection_name="nutritional_hypotheticals",
)
retriever = vector_store.as_retriever(search_kwargs={'k': 3}, search_type='similarity')


def retrieve_context(state):
    """Fetch documents for ``state['expanded_query']`` and store them in the state.

    Each retrieved document is flattened to a dict with its ``content`` and
    ``metadata``; the list is written to ``state['context']`` and the same
    state mapping is returned.
    """
    print("---------retrieve_context---------")
    matches = retriever.invoke(state['expanded_query'])

    passages = []
    for match in matches:
        passages.append({"content": match.page_content, "metadata": match.metadata})

    state['context'] = passages
    return state
107
+
108
def craft_response(state: Dict) -> Dict:
    """Generate an answer to the query from the retrieved context.

    Reads ``state['query']``, ``state['context']`` (a list of dicts with a
    ``content`` key) and, when present and non-empty, ``state['feedback']``.
    The generated text is stored under ``state['response']`` and the same
    state mapping is returned.
    """
    print("---------craft_response---------")
    system_message = '''You are an expert nutrition and health advisor. Provide accurate, evidence-based responses about nutrition disorders and dietary conditions.

    Guidelines:
    - Use only information from the provided context
    - Give clear, actionable advice when appropriate
    - Maintain a professional yet accessible tone
    - If context is insufficient, acknowledge limitations
    - Recommend professional consultation when appropriate

    Generate a comprehensive response based strictly on the provided context.'''
    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
    ])
    chain = response_prompt | llm | StrOutputParser()

    # `feedback` is seeded as "" by the workflow entry point, so fall back to
    # the generic instruction for an empty string as well as a missing key.
    feedback = state.get('feedback') or "Provide a helpful and accurate response"

    response = chain.invoke({
        "query": state['query'],
        "context": "\n".join(doc["content"] for doc in state['context']),
        "feedback": feedback,
    })
    state['response'] = response
    return state
132
+
133
def score_groundedness(state: Dict) -> Dict:
    """Ask the LLM how well ``state['response']`` is grounded in the context.

    The first number found in the model's reply is used as the score, capped
    at 1.0 to stay on the scale the prompt asks for. If the LLM call fails or
    the reply contains no number, the score is 0.0. Increments
    ``state['groundedness_loop_count']``, stores ``state['groundedness_score']``
    and returns the same state mapping.
    """
    import re  # kept local: the score parser is the only user in this function

    print("---------check_groundedness---------")
    system_message = '''You are an expert evaluator. Rate how well the response is grounded in the provided context.

    Scale:
    - 1.0 = Fully grounded (all information comes from context)
    - 0.8 = Mostly grounded (minor reasonable inferences)
    - 0.6 = Partially grounded (some claims supported)
    - 0.4 = Weakly grounded (few claims supported)
    - 0.2 = Poorly grounded (mostly unsupported)
    - 0.0 = Not grounded (contradicts or ignores context)

    Return ONLY a decimal number between 0.0 and 1.0.'''
    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])
    chain = groundedness_prompt | llm | StrOutputParser()

    try:
        score_str = chain.invoke({
            "context": "\n".join(doc["content"] for doc in state['context']),
            "response": state['response']
        })
    except Exception as exc:
        # Scoring is best-effort: a failed call counts as "not grounded", but
        # the reason is reported instead of being swallowed silently.
        print(f"Groundedness scoring failed: {exc}")
        groundedness_score = 0.0
    else:
        match = re.search(r"\d+(\.\d+)?", score_str)
        # Cap replies such as "85" so the stored score stays within 0.0-1.0.
        groundedness_score = min(float(match.group(0)), 1.0) if match else 0.0

    state['groundedness_loop_count'] += 1
    state['groundedness_score'] = groundedness_score
    return state
164
+
165
def check_precision(state: Dict) -> Dict:
    """Ask the LLM how precisely ``state['response']`` answers ``state['query']``.

    The first number found in the model's reply is used as the score, capped
    at 1.0 to stay on the scale the prompt asks for. If the LLM call fails or
    the reply contains no number, the score is 0.0. Stores
    ``state['precision_score']``, increments ``state['precision_loop_count']``
    and returns the same state mapping.
    """
    import re  # kept local: the score parser is the only user in this function

    print("---------check_precision---------")
    system_message = '''You are an expert evaluator. Rate how precisely the response addresses the user's query on a scale of 0.0 to 1.0.

    Consider:
    - Does the response directly answer what was asked?
    - Are all parts of the query addressed?
    - Is there unnecessary or irrelevant information?
    - Is the response focused and on-topic?

    Return ONLY a decimal number between 0.0 and 1.0.'''
    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
    ])
    chain = precision_prompt | llm | StrOutputParser()

    try:
        score_str = chain.invoke({
            "query": state['query'],
            "response": state['response']
        })
    except Exception as exc:
        # Scoring is best-effort: a failed call counts as "not precise", but
        # the reason is reported instead of being swallowed silently.
        print(f"Precision scoring failed: {exc}")
        precision_score = 0.0
    else:
        match = re.search(r"\d+(\.\d+)?", score_str)
        # Cap replies such as "85" so the stored score stays within 0.0-1.0.
        precision_score = min(float(match.group(0)), 1.0) if match else 0.0

    state['precision_score'] = precision_score
    state['precision_loop_count'] += 1
    return state
194
+
195
def refine_response(state: Dict) -> Dict:
    """Collect reviewer feedback on the current response.

    Sends ``state['query']`` and ``state['response']`` to the LLM, stores the
    suggestions under ``state['feedback']`` and returns the same state mapping.
    The response itself is not modified here.
    """
    print("---------refine_response---------")
    reviewer_prompt = ChatPromptTemplate.from_messages([
        (
            "system",
            '''You are an expert reviewer. Analyze the response and suggest specific improvements for better accuracy, completeness, and clarity. Focus on actionable recommendations.''',
        ),
        ("user", "Query: {query}\nResponse: {response}\n\nWhat improvements can be made?"),
    ])
    reviewer = reviewer_prompt | llm | StrOutputParser()

    review_inputs = {'query': state['query'], 'response': state['response']}
    state['feedback'] = reviewer.invoke(review_inputs)
    return state
206
+
207
def refine_query(state: Dict) -> Dict:
    """Collect suggestions for improving the expanded query.

    Sends ``state['query']`` and ``state['expanded_query']`` to the LLM, stores
    the suggestions under ``state['query_feedback']`` and returns the same
    state mapping.
    """
    print("---------refine_query---------")
    optimizer_prompt = ChatPromptTemplate.from_messages([
        (
            "system",
            '''You are a query optimization expert. Analyze the expanded query and suggest specific improvements to enhance information retrieval effectiveness. Focus on terminology, specificity, and comprehensiveness.''',
        ),
        ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\nWhat improvements can be made?"),
    ])
    optimizer = optimizer_prompt | llm | StrOutputParser()

    optimizer_inputs = {'query': state['query'], 'expanded_query': state['expanded_query']}
    state['query_feedback'] = optimizer.invoke(optimizer_inputs)
    return state
218
+
219
def should_continue_groundedness(state):
    """Pick the next node after groundedness scoring.

    Returns ``"check_precision"`` once the score reaches 0.8,
    ``"max_iterations_reached"`` when the loop counter exceeds
    ``state['loop_max_iter']``, and ``"refine_response"`` otherwise.
    """
    # A sufficiently grounded answer wins regardless of the loop counter.
    if state['groundedness_score'] >= 0.8:
        return "check_precision"

    budget_exhausted = state["groundedness_loop_count"] > state['loop_max_iter']
    if budget_exhausted:
        return "max_iterations_reached"

    return "refine_response"
226
+
227
def should_continue_precision(state: Dict) -> str:
    """Pick the next step after precision scoring.

    Returns ``"pass"`` once the score reaches 0.8,
    ``"max_iterations_reached"`` when the loop counter exceeds
    ``state['loop_max_iter']``, and ``"refine_query"`` otherwise.
    """
    # A sufficiently precise answer wins regardless of the loop counter.
    if state['precision_score'] >= 0.8:
        return "pass"

    budget_exhausted = state["precision_loop_count"] > state['loop_max_iter']
    if budget_exhausted:
        return "max_iterations_reached"

    return "refine_query"
234
+
235
def max_iterations_reached(state: Dict) -> Dict:
    """Replace the response with a fixed fallback message and return the state."""
    fallback = "I'm unable to refine the response further. Please provide more context or clarify your question."
    state.update(response=fallback)
    return state
238
+
239
def create_workflow() -> StateGraph:
    """Build the (uncompiled) LangGraph state graph for the RAG agent.

    Flow:
        START -> expand_query -> retrieve_context -> craft_response
              -> score_groundedness
        score_groundedness -> check_precision | refine_response
                              | max_iterations_reached
        refine_response    -> craft_response   (regenerate using the feedback)
        check_precision    -> END | refine_query | max_iterations_reached
        refine_query       -> expand_query
        max_iterations_reached -> END
    """
    workflow = StateGraph(AgentState)

    # Register every node under the name used by the routing functions.
    nodes = {
        "expand_query": expand_query,
        "retrieve_context": retrieve_context,
        "craft_response": craft_response,
        "score_groundedness": score_groundedness,
        "refine_response": refine_response,
        "check_precision": check_precision,
        "refine_query": refine_query,
        "max_iterations_reached": max_iterations_reached,
    }
    for node_name, node_fn in nodes.items():
        workflow.add_node(node_name, node_fn)

    # Linear part of the pipeline.
    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")

    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {
            "check_precision": "check_precision",
            "refine_response": "refine_response",
            "max_iterations_reached": "max_iterations_reached"
        }
    )
    # refine_response only produces feedback; the answer has to be regenerated
    # by craft_response before it is scored again. Routing straight back to
    # score_groundedness would re-score the unchanged response every loop.
    workflow.add_edge("refine_response", "craft_response")

    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,
            "refine_query": "refine_query",
            "max_iterations_reached": "max_iterations_reached"
        }
    )
    workflow.add_edge("refine_query", "expand_query")
    workflow.add_edge("max_iterations_reached", END)
    return workflow


# Compiled once at import time and reused by the agentic_rag tool.
WORKFLOW_APP = create_workflow().compile()
280
+
281
@tool
def agentic_rag(query: str):
    """Runs the RAG-based agent."""
    # The docstring above is what the @tool decorator exposes as the tool
    # description, so it is deliberately left short and unchanged.

    # Fresh state for one workflow run: empty texts, zeroed scores/counters,
    # and a budget of three refinement iterations.
    initial_state = {"query": query, "loop_max_iter": 3}
    for text_field in ("expanded_query", "response", "feedback", "query_feedback"):
        initial_state[text_field] = ""
    for score_field in ("precision_score", "groundedness_score"):
        initial_state[score_field] = 0.0
    for counter_field in ("groundedness_loop_count", "precision_loop_count"):
        initial_state[counter_field] = 0
    initial_state["context"] = []

    final_state = WORKFLOW_APP.invoke(initial_state)
    return final_state['response']
299
+
300
+ #================================ Guardrails ===========================#
301
# Groq client used exclusively for the Llama Guard moderation call.
llama_guard_client = Groq(api_key=llama_api_key)


def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"):
    """Return the moderation model's verdict text for ``user_input``.

    The user text is sent as a single chat message and the stripped content of
    the first choice is returned. If the call raises, the error is printed and
    the string ``"safe"`` is returned.
    """
    # NOTE(review): the fallback below fails open - any API error lets the
    # input through as "safe". Confirm this is the intended trade-off.
    try:
        completion = llama_guard_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": user_input}],
        )
        verdict = completion.choices[0].message.content
        return verdict.strip()
    except Exception as exc:
        print(f"Error with Llama Guard: {exc}")
        return "safe"
312
+
313
+ #============================= Memory & Chatbot ===============================#
314
class NutritionBot:
    """Nutrition support chatbot combining Mem0 memory with a tool-calling agent."""

    def __init__(self):
        """Create the memory client, the chat model and the agent executor."""
        self.memory = MemoryClient(api_key=MEM0_api_key)
        self.client = ChatOpenAI(
            model_name="gpt-4o-mini",
            api_key=api_key,
            base_url=endpoint,
            temperature=0
        )
        tools = [agentic_rag]
        system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
        Guidelines for Interaction:
        Maintain a polite, professional, and reassuring tone.
        Show genuine empathy for customer concerns and health challenges.
        Reference past interactions to provide personalized and consistent advice.
        Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
        Ensure consistent and accurate information across conversations.
        If any detail is unclear or missing, proactively ask for clarification.
        Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
        Keep track of ongoing issues and follow-ups to ensure continuity in support.
        Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.

        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}")
        ])
        agent = create_tool_calling_agent(self.client, tools, prompt)
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    def store_customer_interaction(self, user_id, message, response, metadata=None):
        """Persist one user/assistant exchange in memory.

        Args:
            user_id: Identifier the memory is stored under.
            message: The user's message.
            response: The assistant's reply.
            metadata: Optional extra metadata; a ``timestamp`` entry (ISO
                format, local time) is added to a copy, so the caller's dict
                is left untouched.
        """
        stamped = dict(metadata) if metadata else {}
        stamped["timestamp"] = datetime.now().isoformat()
        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response},
        ]
        self.memory.add(conversation, user_id=user_id, output_format="v1.1", metadata=stamped)

    def get_relevant_history(self, user_id, query):
        """Return the raw memory search result (at most 3 hits) for ``query``."""
        return self.memory.search(query=query, user_id=user_id, limit=3)

    @staticmethod
    def _memory_texts(history):
        """Extract the ``memory`` strings from a search result.

        Accepts either a plain list of hit dicts or a dict wrapping the hits
        under ``"results"``; anything else yields an empty list.
        """
        if isinstance(history, dict):
            history = history.get("results", [])
        return [
            hit["memory"]
            for hit in history or []
            if isinstance(hit, dict) and "memory" in hit
        ]

    def handle_customer_query(self, user_id, query):
        """Answer ``query`` for ``user_id`` and record the exchange.

        Relevant memories are prepended to the query as context, the agent is
        invoked, the exchange is stored, and the agent's output text is
        returned.
        """
        relevant_history = self.get_relevant_history(user_id, query)
        memory_lines = "".join(
            f"Memory: {text}\n---\n" for text in self._memory_texts(relevant_history)
        )
        context = "Previous interactions:\n" + memory_lines

        prompt = f"Context:\n{context}\nQuery: {query}"
        response = self.agent_executor.invoke({"input": prompt})
        answer = response["output"]
        self.store_customer_interaction(user_id, query, answer, metadata={"type": "query"})
        return answer
364
+
365
+ #===================== Streamlit UI ===========================#
366
def nutrition_disorder_streamlit():
    """Render the Streamlit chat UI for the nutrition assistant.

    Shows a login form until a name is entered, then replays the chat history
    and handles one chat input per script run. Typing ``exit`` logs the user
    out. Every other message is screened by Llama Guard before it is passed to
    the chatbot.
    """
    st.title("Nutrition Disorder Specialist")
    st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.write("Type 'exit' to end the conversation.")

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None

    # ---------------------------- login screen ----------------------------
    if st.session_state.user_id is None:
        # Show the goodbye left behind by a previous "exit" exactly once.
        farewell = st.session_state.pop("farewell_message", None)
        if farewell:
            st.info(farewell)

        logged_in = False
        with st.form("login_form", clear_on_submit=True):
            entered_name = st.text_input("Enter your name to begin:")
            submit_button = st.form_submit_button("Login")
            user_id = (entered_name or "").strip()
            if submit_button and user_id:
                st.session_state.user_id = user_id
                # Start from a clean history so a previous user's conversation
                # is never shown to whoever logs in next.
                st.session_state.chat_history = [
                    {"role": "assistant", "content": f"Welcome {user_id}! How can I help you?"}
                ]
                logged_in = True
        # Rerun outside the form context so the chat view replaces the form.
        if logged_in:
            st.rerun()
        return

    # ----------------------------- chat screen -----------------------------
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    user_query = st.chat_input("Type your question here (or 'exit' to end)...")
    if not user_query:
        return

    if user_query.strip().lower() == "exit":
        # st.rerun() stops this run immediately, so the goodbye is handed to
        # the login screen through session state instead of being drawn here.
        st.session_state.farewell_message = "Goodbye! Feel free to return if you have more questions."
        st.session_state.chat_history = []
        st.session_state.user_id = None
        st.rerun()

    st.session_state.chat_history.append({"role": "user", "content": user_query})
    with st.chat_message("user"):
        st.write(user_query)

    # Llama Guard answers on multiple lines (e.g. "unsafe\nS6"); flatten before
    # comparing against the allow-list.
    # NOTE(review): S6/S7 are let through on purpose by the original code -
    # presumably because nutrition advice trips those categories; confirm.
    filtered_result = filter_input_with_llama_guard(user_query).replace("\n", " ")
    if filtered_result in ("safe", "unsafe S6", "unsafe S7"):
        try:
            if 'chatbot' not in st.session_state:
                st.session_state.chatbot = NutritionBot()
            reply = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
        except Exception as e:
            reply = f"Error: {str(e)}"
    else:
        reply = "I apologize, but I cannot process that input as it may be inappropriate."

    # Render the reply in an assistant bubble so the live turn looks the same
    # as it does when replayed from the history on the next run.
    with st.chat_message("assistant"):
        st.write(reply)
    st.session_state.chat_history.append({"role": "assistant", "content": reply})


if __name__ == "__main__":
    nutrition_disorder_streamlit()