import os
from operator import itemgetter
from typing import Literal, TypedDict

from dotenv import load_dotenv
from pydantic import BaseModel, Field, field_validator

from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.runnables import RunnableParallel
from langchain_core.documents import Document
from langchain_core.tools import Tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langgraph.graph import StateGraph, END

load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

os.environ["GROQ_API_KEY"] = GROQ_API_KEY
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = LANGCHAIN_API_KEY
os.environ["LANGCHAIN_PROJECT"] = "advanced-rag"
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
urls = [
    "https://www.webmd.com/a-to-z-guides/malaria",
    "https://www.webmd.com/diabetes/type-1-diabetes",
    "https://www.webmd.com/diabetes/type-2-diabetes",
    "https://www.webmd.com/migraines-headaches/migraines-headaches-migraines",
]
loader = WebBaseLoader(urls, bs_get_text_kwargs={"strip": True})
docs = loader.load()

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=30)
chunks = splitter.split_documents(docs)

embedding_function = HuggingFaceEmbeddings()
vector_store = Chroma.from_documents(documents=chunks, embedding=embedding_function)
retriever = vector_store.as_retriever()
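# Illustrative sanity check for the retriever; the query is a hypothetical
# example, not part of the pipeline. Uncomment to try:
# for doc in retriever.invoke("What are the symptoms of malaria?")[:2]:
#     print(doc.metadata.get("source"), "->", doc.page_content[:80])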
class VectorStore(BaseModel):
    """A vectorstore containing information about symptoms, treatment,
    risk factors, and other details of malaria, type 1 and type 2 diabetes,
    and migraines."""

    query: str


class SearchEngine(BaseModel):
    """A search engine for any other medical information on the web."""

    query: str
router_prompt_temp = (
    "You are an expert at routing user queries to either a VectorStore or a SearchEngine.\n"
    "The VectorStore contains information on malaria, diabetes, and migraines.\n"
    "Use the SearchEngine for all other medical queries.\n"
    'Note that if a query is not medically related, you must output "not medically-related"; do not try to use any tool.\n\n'
    "query: {query}"
)
llm = ChatGroq(model="llama3-70b-8192", temperature=0)
router_prompt = ChatPromptTemplate.from_template(router_prompt_temp)
llm_with_tools = llm.bind_tools([VectorStore, SearchEngine])
question_router = router_prompt | llm_with_tools
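# Illustrative router call with a hypothetical query; a tool call in
# additional_kwargs indicates which route the model chose. Uncomment to try:
# response = question_router.invoke({"query": "What causes migraines?"})
# print(response.additional_kwargs.get("tool_calls"))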
class Grader(BaseModel):
    """Binary relevance grade for a retrieved document against a query."""

    grade: Literal["relevant", "irrelevant"]

    @field_validator("grade", mode="before")
    @classmethod
    def validate_grade(cls, value):
        # Normalise a common model variant of the label.
        if value == "not relevant":
            return "irrelevant"
        return value


grader_system_prompt_template = (
    "You are a grader tasked with assessing the relevance of a given context to a query.\n"
    'If the context is relevant to the query, score it as "relevant"; otherwise, give "irrelevant".\n'
    'Do not answer the query itself; just provide the grade in JSON format with "grade" as the key, '
    "without any additional explanation."
)
grader_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", grader_system_prompt_template),
        ("human", "context: {context}\n\nquery: {query}"),
    ]
)
llm_with_structured = llm.with_structured_output(Grader, method="json_mode")
grader_chain = grader_prompt | llm_with_structured
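# Illustrative grading call with hypothetical inputs; uncomment to try:
# grade = grader_chain.invoke({"query": "malaria symptoms", "context": "Malaria causes fever..."})
# print(grade.grade)  # "relevant" or "irrelevant"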
rag_template_str = (
    "You are a helpful assistant. Answer the query below based only on the provided context.\n\n"
    "context: {context}\n\n"
    "query: {query}"
)
rag_prompt = ChatPromptTemplate.from_template(rag_template_str)
rag_chain = rag_prompt | llm | StrOutputParser()
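# Illustrative RAG call with hypothetical inputs; uncomment to try:
# print(rag_chain.invoke({"query": "How is malaria treated?", "context": "Malaria is treated with antimalarial drugs..."}))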
fall_back_template = (
    "You are a friendly medical assistant created by NHVAI.\n"
    "Do not respond to queries that are not related to health.\n"
    "If a query is not related to health, acknowledge your limitations.\n"
    "Provide concise responses to only medically-related queries.\n\n"
    "Current conversations:\n\n{chat_history}\n\n"
    "human: {query}"
)
fall_back_prompt = ChatPromptTemplate.from_template(fall_back_template)
def format_chat_history(x):
    """Render the chat history as a plain-text transcript for the fallback prompt."""
    return "\n".join(
        f"human: {msg.content}" if isinstance(msg, HumanMessage) else f"AI: {msg.content}"
        for msg in x["chat_history"]
        if hasattr(msg, "content")
    )


fallback_chain = (
    {"chat_history": format_chat_history, "query": itemgetter("query")}
    | fall_back_prompt
    | llm
    | StrOutputParser()
)
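# Illustrative fallback call; uncomment to try:
# print(fallback_chain.invoke({"query": "Tell me a joke", "chat_history": []}))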
class HallucinationGrader(BaseModel):
    """Binary score for a hallucination check on the llm's response."""

    grade: Literal["yes", "no"] = Field(
        ..., description="'yes' if the llm's response is hallucinated, otherwise 'no'"
    )


hallucination_grader_system_prompt_template = (
    "You are a grader assessing whether a response from an llm is based on a given context.\n"
    "If the llm's response is not based on the given context, give a score of 'yes', meaning it's a hallucination; "
    "otherwise give 'no'.\n"
    "Just give the grade in json with 'grade' as a key and a binary value of 'yes' or 'no' without additional explanation"
)
hallucination_grader_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", hallucination_grader_system_prompt_template),
        ("human", "context: {context}\n\nllm's response: {response}"),
    ]
)
hallucination_grader_chain = (
    RunnableParallel(
        {
            "response": itemgetter("response"),
            "context": lambda x: "\n\n".join(c.page_content for c in x["context"]),
        }
    )
    | hallucination_grader_prompt
    | llm.with_structured_output(HallucinationGrader, method="json_mode")
)
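# Illustrative hallucination check with hypothetical inputs; note that the
# context must be a list of Documents. Uncomment to try:
# grade = hallucination_grader_chain.invoke(
#     {"response": "Malaria is caused by a virus.",
#      "context": [Document(page_content="Malaria is caused by a parasite.")]}
# )
# print(grade.grade)  # "yes" means the response is not grounded in the context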
class AnswerGrader(BaseModel):
    """Binary score for an answer check based on a query."""

    grade: Literal["yes", "no"] = Field(
        ...,
        description="'yes' if the provided answer is an actual answer to the query, otherwise 'no'",
    )


answer_grader_system_prompt_template = (
    "You are a grader assessing whether a provided answer is in fact an answer to the given query.\n"
    "If the provided answer does not answer the query, give a score of 'no'; otherwise give 'yes'.\n"
    "Just give the grade in json with 'grade' as a key and a binary value of 'yes' or 'no' without additional explanation"
)
answer_grader_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", answer_grader_system_prompt_template),
        ("human", "query: {query}\n\nanswer: {response}"),
    ]
)
answer_grader_chain = answer_grader_prompt | llm.with_structured_output(
    AnswerGrader, method="json_mode"
)
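# Illustrative answer-relevance check with hypothetical inputs; uncomment to try:
# grade = answer_grader_chain.invoke(
#     {"query": "What is malaria?", "response": "Malaria is a mosquito-borne disease."}
# )
# print(grade.grade)  # "yes" or "no"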
tavily_search = TavilySearchResults()
vectorstore = Tool(
    name="VectorStore",
    func=retriever.invoke,
    description="Useful to search the vector database",
)
searchengine = Tool(
    name="SearchEngine",
    func=tavily_search.invoke,
    description="Useful to search the web",
)
tools = [vectorstore, searchengine]

agent_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Make sure to use the tavily_search tool for information if the given query doesn't relate to the vectorstore content.",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)
agent = create_tool_calling_agent(llm, tools, agent_prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
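# Note: agent_executor is not wired into the graph below; it offers a direct
# tool-calling alternative to the routed pipeline. Illustrative call (uncomment to try):
# agent_executor.invoke({"input": "What are the symptoms of malaria?"})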
class AgentState(TypedDict):
    """The dictionary that keeps track of the data required by the various nodes in the graph."""

    query: str
    generation: str
    chat_history: list[BaseMessage]
    documents: list[Document]


def retrieve_node(state: AgentState) -> dict[str, list[Document] | str]:
    """Retrieve relevant documents from the vectorstore for state['query']."""
    print("retrieve node")
    query = state["query"]
    documents = retriever.invoke(input=query)
    return {"documents": documents}
def fallback_node(state: AgentState):
    """Fall back to this node when no tool is called."""
    print("fallback node")
    query = state["query"]
    chat_history = state["chat_history"]
    generation = fallback_chain.invoke({"query": query, "chat_history": chat_history})
    return {"generation": generation}


def filter_documents_node(state: AgentState):
    """Keep only the retrieved documents that the grader scores as relevant."""
    filtered_docs = []
    query = state["query"]
    documents = state["documents"]
    print("filter docs node")
    for i, doc in enumerate(documents, start=1):
        grade = grader_chain.invoke({"query": query, "context": doc})
        if grade.grade == "relevant":
            print(f"Chunk {i} is relevant")
            filtered_docs.append(doc)
        else:
            print(f"Chunk {i} is irrelevant")
    return {"documents": filtered_docs}
def rag_node(state: AgentState):
    """Generate an answer from the filtered documents."""
    print("rag node")
    query = state["query"]
    documents = state["documents"]
    generation = rag_chain.invoke({"query": query, "context": documents})
    return {"generation": generation}


def web_search_node(state: AgentState):
    """Search the web with Tavily and wrap the results as Documents."""
    print("search node")
    query = state["query"]
    results = tavily_search.invoke(query)
    documents = [
        Document(page_content=doc["content"], metadata={"source": doc["url"]})
        for doc in results
    ]
    return {"documents": documents}
def question_router_node(state: AgentState):
    """Route the query to the vector store, the search engine, or the LLM fallback."""
    print("router node")
    query = state["query"]
    try:
        response = question_router.invoke({"query": query})
    except Exception:
        return "llm_feedback"
    if "tool_calls" not in response.additional_kwargs:
        print("-----No tools called--------")
        return "llm_feedback"
    tool_calls = response.additional_kwargs["tool_calls"]
    if len(tool_calls) == 0:
        raise ValueError("Router could not decide route!")
    route = tool_calls[0]["function"]["name"]
    if route == "VectorStore":
        print("Routing to the vector store....")
        return "VectorStore"
    if route == "SearchEngine":
        print("Routing to the search engine....")
        return "SearchEngine"
    # Unknown tool name: fall back rather than returning None.
    return "llm_feedback"
def should_generate(state: dict):
    """Decide whether to generate an answer or fall back to web search."""
    print("should generate node")
    filtered_docs = state["documents"]
    if not filtered_docs:
        print("---None of the retrieved documents are relevant---")
        return "SearchEngine"
    else:
        print("---Some retrieved documents are relevant---")
        return "generate"
def hallucination_and_answer_relevance_check(state: dict):
    """Grade the generation for hallucination, then for answer relevance."""
    print("hallucination node")
    llm_response = state["generation"]
    documents = state["documents"]
    query = state["query"]
    hallucination_grade = hallucination_grader_chain.invoke(
        {"response": llm_response, "context": documents}
    )
    if hallucination_grade.grade == "no":
        print("---Hallucination check passed---")
        answer_relevance_grade = answer_grader_chain.invoke(
            {"response": llm_response, "query": query}
        )
        if answer_relevance_grade.grade == "yes":
            print("---Answer is relevant to question---\n")
            return "useful"
        else:
            print("---Answer is not relevant to question---")
            return "not useful"
    print("---Hallucination check failed---")
    return "generate"
workflow = StateGraph(AgentState)
workflow.add_node("VectorStore", retrieve_node)
workflow.add_node("SearchEngine", web_search_node)
workflow.add_node("filter_docs", filter_documents_node)
workflow.add_node("fallback", fallback_node)
workflow.add_node("rag", rag_node)

workflow.set_conditional_entry_point(
    question_router_node,
    {
        "VectorStore": "VectorStore",
        "SearchEngine": "SearchEngine",
        "llm_feedback": "fallback",
    },
)
workflow.add_edge("VectorStore", "filter_docs")
workflow.add_edge("SearchEngine", "filter_docs")
workflow.add_conditional_edges(
    "filter_docs", should_generate, {"SearchEngine": "SearchEngine", "generate": "rag"}
)
workflow.add_conditional_edges(
    "rag",
    hallucination_and_answer_relevance_check,
    {"useful": END, "not useful": "SearchEngine", "generate": "rag"},
)
workflow.add_edge("fallback", END)
graph = workflow.compile()
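# Illustrative end-to-end call on the compiled graph (hypothetical query; uncomment to try):
# result = graph.invoke({"query": "What are the symptoms of malaria?", "chat_history": []})
# print(result["generation"])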
def ask(query, chat_history):
    """Run a query through the compiled graph with the given chat history."""
    return graph.invoke({"query": query, "chat_history": chat_history})
import gradio as gr


def respond(message, history, system_message):
    """Handle user input, send it through the LangGraph pipeline, and return the response."""
    # Convert Gradio's (user, assistant) tuple history into LangChain messages
    # so the fallback chain can format it. The system_message input is kept for
    # the UI; the graph's prompts define their own system behaviour.
    chat_history = []
    for user_msg, ai_msg in history:
        chat_history.append(HumanMessage(content=user_msg))
        chat_history.append(AIMessage(content=ai_msg))
    result = ask(message, chat_history)
    # gr.ChatInterface manages the displayed history itself, so return only the reply.
    return result.get("generation", "I'm not sure how to answer that.")


# Define the Gradio chat interface.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful medical chatbot.", label="System Message"),
    ],
)

if __name__ == "__main__":
    demo.launch()