import os
from operator import itemgetter
from typing import Literal, TypedDict

from dotenv import load_dotenv

load_dotenv()

from langchain_community.document_loaders import WebBaseLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel
from langchain_core.tools import Tool
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langgraph.graph import StateGraph, END
from pydantic import BaseModel, Field, field_validator
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = LANGCHAIN_API_KEY
os.environ["LANGCHAIN_PROJECT"] = "advanced-rag"
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
urls = [
    "https://www.webmd.com/a-to-z-guides/malaria",
    "https://www.webmd.com/diabetes/type-1-diabetes",
    "https://www.webmd.com/diabetes/type-2-diabetes",
    "https://www.webmd.com/migraines-headaches/migraines-headaches-migraines",
]
loader = WebBaseLoader(
    urls,
    bs_get_text_kwargs={"strip": True},
)
docs = loader.load()
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=30)
chunks = splitter.split_documents(docs)
embedding_function = HuggingFaceEmbeddings()
vector_store = Chroma.from_documents(documents=chunks, embedding=embedding_function)
retriever = vector_store.as_retriever()
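# At this point the four WebMD pages are chunked and indexed in an in-memory
# Chroma collection. Illustrative usage (assumed query, not from the original):
#   retriever.invoke("What are the symptoms of malaria?")  # -> list[Document]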
class VectorStore(BaseModel):
    """A vectorstore containing information about symptoms, treatment,
    risk factors and other information about malaria, type 1 and
    type 2 diabetes, and migraines."""

    query: str


class SearchEngine(BaseModel):
    """Search engine for looking up other medical information on the web."""

    query: str
router_prompt_temp = (
    "You are an expert at routing user queries to either a VectorStore or a SearchEngine.\n"
    "The VectorStore contains information on malaria, diabetes, and migraines.\n"
    "Use the SearchEngine for all other medical queries that are not related to malaria, diabetes, or migraines.\n"
    'Note that if a query is not medically-related, you must output "not medically-related"; don\'t try to use any tool.\n\n'
    "query: {query}"
)
llm = ChatGroq(model="llama3-70b-8192", temperature=0)
prompt = ChatPromptTemplate.from_template(router_prompt_temp)
tools = [VectorStore, SearchEngine]
llm_with_tools = llm.bind_tools(tools)
question_router = prompt | llm_with_tools
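# The router is a plain tool-calling chain: the model either emits a tool call
# named "VectorStore" / "SearchEngine", or answers in plain text for
# non-medical queries. question_router_node below inspects additional_kwargs
# on the response to recover the chosen route.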
class Grader(BaseModel):
    """This format checks how relevant the retrieved docs are."""

    grade: Literal["relevant", "irrelevant"]

    @field_validator("grade", mode="before")
    @classmethod
    def validate_grade(cls, value):
        # Normalise the occasional "not relevant" the llm may emit.
        if value == "not relevant":
            return "irrelevant"
        return value
grader_system_prompt_template = """You are a grader tasked with assessing the relevance of a given context to a query.
If the context is relevant to the query, score it as "relevant". Otherwise, give "irrelevant".
Do not provide the actual answer; just provide the grade in JSON format with "grade" as the key, without any additional explanation.
"""
grader_prompt = ChatPromptTemplate.from_messages([
    ("system", grader_system_prompt_template),
    ("human", "context is : {context}\n\n query : {query}"),
])
llm_with_structured = llm.with_structured_output(Grader, method="json_mode")
grader_chain = grader_prompt | llm_with_structured
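# Because of with_structured_output, grader_chain returns a Grader instance
# rather than raw text, so callers below check its `.grade` field directly.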
rag_template_str = (
    "You are a helpful assistant. Answer the query below based only on the provided context.\n\n"
    "context: {context}\n\n"
    "query: {query}"
)
rag_prompt = ChatPromptTemplate.from_template(rag_template_str)
rag_chain = rag_prompt | llm | StrOutputParser()
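# rag_chain stuffs the filtered documents straight into the prompt; with a
# 500-character chunk size, a handful of retrieved chunks should stay well
# within the model's 8192-token context window.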
# Parentheses are required here so the adjacent string literals concatenate
# into a single template (without them only the first line is assigned).
fall_back_template = (
    "You are a friendly medical assistant created by NHVAI.\n"
    "Do not respond to queries that are not related to health.\n"
    "If a query is not related to health, acknowledge your limitations.\n"
    "Provide concise responses to only medically-related queries.\n\n"
    "Current conversations:\n\n{chat_history}\n\n"
    "human: {query}"
)
fall_back_prompt = ChatPromptTemplate.from_template(fall_back_template)
def chat_history(x):
    """Render message objects as a plain-text transcript (entries without .content are skipped)."""
    return "\n".join(
        f"human: {msg.content}" if isinstance(msg, HumanMessage) else f"AI: {msg.content}"
        for msg in x["chat_history"]
        if hasattr(msg, "content")
    )
fallback_chain = (
    {"chat_history": chat_history, "query": itemgetter("query")}
    | fall_back_prompt
    | llm
    | StrOutputParser()
)
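# The fallback chain answers directly from the chat history when the router
# makes no tool call, i.e. when the query is not medically-related.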
class HallucinationGrader(BaseModel):
    """Binary score for hallucination check in the llm's response."""

    grade: Literal["yes", "no"] = Field(
        ..., description="'yes' if the llm's response is hallucinated otherwise 'no'"
    )


hallucination_grader_system_prompt_template = (
    "You are a grader assessing whether a response from an llm is based on a given context.\n"
    "If the llm's response is not based on the given context give a score of 'yes' meaning it's a hallucination, "
    "otherwise give 'no'.\n"
    "Just give the grade in json with 'grade' as a key and a binary value of 'yes' or 'no' without additional explanation"
)
hallucination_grader_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", hallucination_grader_system_prompt_template),
        ("human", "context: {context}\n\nllm's response: {response}"),
    ]
)
hallucination_grader_chain = (
    RunnableParallel(
        {
            "response": itemgetter("response"),
            "context": lambda x: "\n\n".join([c.page_content for c in x["context"]]),
        }
    )
    | hallucination_grader_prompt
    | llm.with_structured_output(HallucinationGrader, method="json_mode")
)
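# Note: the RunnableParallel flattens the retrieved Documents into one string,
# so "context" must be a list of Documents when this chain is invoked.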
class AnswerGrader(BaseModel):
    """Binary score for an answer check based on a query."""

    grade: Literal["yes", "no"] = Field(
        ...,
        description="'yes' if the provided answer is an actual answer to the query otherwise 'no'",
    )
answer_grader_system_prompt_template = (
    "You are a grader assessing whether a provided answer is in fact an answer to the given query.\n"
    "If the provided answer does not answer the query give a score of 'no' otherwise give 'yes'\n"
    "Just give the grade in json with 'grade' as a key and a binary value of 'yes' or 'no' without additional explanation"
)
answer_grader_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", answer_grader_system_prompt_template),
        ("human", "query: {query}\n\nanswer: {response}"),
    ]
)
answer_grader_chain = answer_grader_prompt | llm.with_structured_output(
AnswerGrader, method="json_mode"
)
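# Both grader chains return pydantic instances; the conditional edges in the
# graph below branch on their `.grade` field.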
tavily_search = TavilySearchResults()
vectorstore = Tool(
    name="VectorStore",
    func=retriever.invoke,
    description="Useful to search the vector database",
)
searchengine = Tool(
    name="SearchEngine",
    func=tavily_search.invoke,  # pass the callable, not the tool object itself
    description="useful to search the web",
)
tools = [vectorstore, searchengine]
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Make sure to use the tavily_search tool for information if the given query doesn't relate to the vectorstore content.",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
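# Note: agent_executor is constructed here but never used by the LangGraph
# workflow below; the graph routes queries itself via question_router_node.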
class AgentState(TypedDict):
    """The dictionary keeps track of the data required by the various nodes in the graph."""

    query: str
    generation: str
    chat_history: list[BaseMessage]
    documents: list[Document]
def retrieve_node(state: AgentState) -> dict[str, list[Document] | str]:
    """
    Retrieve relevant documents from the vectorstore.
    query: str
    return list[Document]
    """
    print("retrieve node")
    query = state["query"]
    documents = retriever.invoke(input=query)
    return {"documents": documents}
def fallback_node(state: AgentState):
    """Fallback to this node when there is no tool call."""
    print("fallback node")
    query = state["query"]
    chat_history = state["chat_history"]
    generation = fallback_chain.invoke({"query": query, "chat_history": chat_history})
    return {"generation": generation}
def filter_documents_node(state: AgentState):
    filtered_docs = list()
    query = state["query"]
    documents = state["documents"]
    print("filter docs node")
    for i, doc in enumerate(documents, start=1):
        grade = grader_chain.invoke({"query": query, "context": doc})
        if grade.grade == "relevant":
            print(f"Chunk {i} is relevant")
            filtered_docs.append(doc)
        else:
            print(f"Chunk {i} is irrelevant")
    return {"documents": filtered_docs}
def rag_node(state: AgentState):
    print("rag node")
    query = state["query"]
    documents = state["documents"]
    generation = rag_chain.invoke({"query": query, "context": documents})
    return {"generation": generation}
def web_search_node(state: AgentState):
    print("search node")
    query = state["query"]
    results = tavily_search.invoke(query)
    documents = [
        Document(page_content=doc["content"], metadata={"source": doc["url"]})
        for doc in results
    ]
    return {"documents": documents}
def question_router_node(state: AgentState):
    print("router node")
    query = state["query"]
    try:
        response = question_router.invoke({"query": query})
    except Exception:
        return "llm_feedback"
    if "tool_calls" not in response.additional_kwargs:
        print("-----No tools called--------")
        return "llm_feedback"
    if len(response.additional_kwargs["tool_calls"]) == 0:
        raise ValueError("Router could not decide route!")
    route = response.additional_kwargs["tool_calls"][0]["function"]["name"]
    if route == "VectorStore":
        print("Routing to the vector store....")
        return "VectorStore"
    elif route == "SearchEngine":
        print("Routing to search engine")
        return "SearchEngine"
    # Unknown tool name: fall back rather than returning None.
    return "llm_feedback"
def should_generate(state: dict):
    print("should generate node")
    filtered_docs = state["documents"]
    if not filtered_docs:
        print("---No retrieved documents are relevant---")
        return "SearchEngine"
    else:
        print("---Some retrieved documents are relevant---")
        return "generate"
def hallucination_and_answer_relevance_check(state: dict):
    print("hallucination node")
    llm_response = state["generation"]
    documents = state["documents"]
    query = state["query"]
    hallucination_grade = hallucination_grader_chain.invoke(
        {"response": llm_response, "context": documents}
    )
    if hallucination_grade.grade == "no":
        print("---Hallucination check passed---")
        answer_relevance_grade = answer_grader_chain.invoke(
            {"response": llm_response, "query": query}
        )
        if answer_relevance_grade.grade == "yes":
            print("---Answer is relevant to question---\n")
            return "useful"
        else:
            print("---Answer is not relevant to question---")
            return "not useful"
    print("---Hallucination check failed---")
    return "generate"
workflow = StateGraph(AgentState)
workflow.add_node("VectorStore", retrieve_node)
workflow.add_node("SearchEngine", web_search_node)
workflow.add_node("filter_docs", filter_documents_node)
workflow.add_node("fallback", fallback_node)
workflow.add_node("rag", rag_node)
workflow.set_conditional_entry_point(
    question_router_node,
    {
        "VectorStore": "VectorStore",
        "SearchEngine": "SearchEngine",
        "llm_feedback": "fallback",
    },
)
workflow.add_edge("VectorStore", "filter_docs")
workflow.add_edge("SearchEngine", "filter_docs")
workflow.add_conditional_edges(
    "filter_docs", should_generate, {"SearchEngine": "SearchEngine", "generate": "rag"}
)
workflow.add_conditional_edges(
    "rag",
    hallucination_and_answer_relevance_check,
    {"useful": END, "not useful": "SearchEngine", "generate": "rag"},
)
workflow.add_edge("fallback", END)
graph = workflow.compile()
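# Illustrative smoke test (assumed query, not part of the original app;
# uncomment to try locally):
#   result = graph.invoke({"query": "What are the symptoms of malaria?", "chat_history": []})
#   print(result["generation"])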
def ask(query, chat_history):
    return graph.invoke({"query": query, "chat_history": chat_history})
import gradio as gr


def respond(message, history, system_message):
    """
    Handles user input, sends it to the LangGraph pipeline, and returns the response.
    """
    # Append system message only for the first query
    if not history:
        history.append(("System", system_message))
    # Invoke the LangGraph pipeline
    result = ask(message, history)
    # Extract AI response
    response = result.get("generation", "I'm not sure how to answer that.")
    # Append user message and AI response to history
    history.append((message, response))
    # gr.ChatInterface expects the function to return just the reply string;
    # it manages the visible chat history itself.
    return response
# Define Gradio Chat Interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful medical chatbot.", label="System Message"),
    ],
)

if __name__ == "__main__":
    demo.launch()