# aldsouza — Update app.py (commit d64c758, verified)
import re
from enum import Enum
from typing import Optional, List, Dict, Any
import torch
from dotenv import load_dotenv
from langchain.chains.summarize.refine_prompts import prompt_template
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFacePipeline
from langchain_tavily import TavilySearch
from langgraph.graph import StateGraph, START, END
from transformers import pipeline
from typing_extensions import TypedDict
import gradio as gr
load_dotenv()

# Surface GPU availability at startup; the transformers pipeline below will
# use CUDA automatically when present.
if torch.cuda.is_available():
    print("✅ CUDA GPU is available!")  # was mojibake "βœ…" in the original
else:
    print("❌ CUDA GPU is NOT available.")

# Tavily Search setup
# Web-search tool for the RAG (inquiry) branch; reads TAVILY_API_KEY from the
# environment populated by load_dotenv() above.
TavilySearchTool = TavilySearch(
    max_results=5,
    topic="general",
)
# Define State
class State(TypedDict):
    """Shared LangGraph state threaded through every node in the graph."""
    # Raw user question driving the current run.
    query: Optional[str]
    # Classification result: "inquiry" or "medical task" (TaskType values).
    task_type: Optional[str]
    # Chat transcript as {"role": ..., "content": ...} dicts from handler nodes.
    messages: List[Dict[str, Any]]
# Graph setup
graph_builder = StateGraph(State)
# HF pipeline
# Local text-generation model (Qwen 2.5 Coder 3B Instruct) with sampling
# enabled; wrapped so it can be used as a LangChain runnable.
pipe = pipeline("text-generation", model="Qwen/Qwen2.5-Coder-3B-Instruct", max_new_tokens=300, do_sample=True,
                temperature=0.7, top_p=0.9, repetition_penalty=1.2)
llm = HuggingFacePipeline(pipeline=pipe)
# TaskType enum
class TaskType(Enum):
    """Closed set of request categories the classifier node can emit."""
    INQUIRY = "inquiry"
    # Value must match the routing key used in add_conditional_edges below.
    MEDICAL_TASK = "medical task"
def fetch_assistant_response(response: str) -> str:
    """Return the text following the final 'Assistant:' marker in *response*.

    The HF pipeline echoes the full prompt, so the actual reply is whatever
    comes after the last marker. Returns "" when no marker is present.
    """
    segments = re.split(r"[\s\n]*Assistant:\s*", response)
    if len(segments) > 1:
        # Everything after the last marker is the assistant's actual reply.
        return segments[-1].strip()
    return ""
def parse_response_as_enum(response: str):
    """Map the raw LLM classification text to a TaskType member.

    Defaults to INQUIRY unless the model explicitly said "medical task".
    """
    normalized = response.strip().lower()
    return TaskType.MEDICAL_TASK if "medical task" in normalized else TaskType.INQUIRY
def classify_task(state: State):
    """Graph node: label the user's query as 'inquiry' or 'medical task'.

    Runs a few-shot classification prompt through the local LLM, extracts the
    assistant turn, parses it into a TaskType, and stores its string value in
    the state so `route` can pick the next node.
    """
    user_query = state.get("query", "")
    # NOTE: local renamed from `prompt_template` to avoid shadowing the
    # `prompt_template` imported at the top of the file.
    # The "→" arrows below were mojibake ("β†’") in the original prompt.
    classification_template = (
        "Classify the user query as either 'inquiry' or 'medical task'. "
        "'Inquiry' is a general question without medical analysis. "
        "'Medical task' involves analyzing or classifying medical data but not giving treatment advice. "
        "Examples:\n"
        "- User query: What is a heart attack? → inquiry\n"
        "- User query: What is a normal cholesterol level? → inquiry\n"
        "- User query: Classify these ECG readings: ... → medical task\n"
        "- User query: Check if these blood pressure readings are normal: 120/80 → medical task\n"
        "Only reply with 'inquiry' or 'medical task' and nothing else. "
        "User query: {query}\n\nAssistant:"
    )
    classifier_prompt = PromptTemplate.from_template(classification_template)
    # Plain functions are coerced to RunnableLambda when piped with runnables.
    classifier_chain = classifier_prompt | llm | fetch_assistant_response | parse_response_as_enum
    task = classifier_chain.invoke({"query": user_query})
    return {"task_type": task.value}
def route(state: State):
    """Conditional-edge selector: return the routing key set by classify_task."""
    return state.get("task_type")
def format_results(search_results: Dict[str, Any], score_threshold: float = 0.5) -> Dict[str, str]:
    """Condense Tavily search output into a single context string for the RAG prompt.

    Args:
        search_results: Tavily response dict; expected keys are "query" and
            "results", where each result carries "content" and a relevance
            "score". Missing keys no longer raise (robustness fix: the
            original crashed with KeyError on a malformed response).
        score_threshold: minimum relevance score a result must exceed to be
            kept. Defaults to 0.5, the original hard-coded cutoff.

    Returns:
        {"query": ..., "context": ...} with relevant snippets joined by
        blank lines.
    """
    relevant = [
        result["content"]
        for result in search_results.get("results", [])
        if result.get("score", 0.0) > score_threshold
    ]
    return {
        "query": search_results.get("query", ""),
        "context": "\n\n".join(relevant),
    }
def handle_inquiry(state: State):
    """Graph node: answer a general inquiry via web search (RAG).

    Searches Tavily, condenses the hits into context, generates an answer
    with the local LLM, and returns the user/assistant message pair.
    """
    inquiry_prompt = PromptTemplate.from_template(
        "Given the following context:{context}"
        "Answer the user's query:{query}"
        "\n\nAssistant:"
    )
    # search -> condense -> prompt -> generate -> extract the assistant turn
    rag_chain = (
        TavilySearchTool
        | format_results
        | inquiry_prompt
        | llm
        | StrOutputParser()
        | fetch_assistant_response
    )
    user_query = state.get("query")
    answer = rag_chain.invoke({"query": user_query})
    return {
        "messages": [
            {"role": "user", "content": user_query},
            {"role": "assistant", "content": answer},
        ]
    }
def handle_medical_task(state: State):
    """Graph node: placeholder for medical-analysis tasks (not yet implemented)."""
    user_query = state.get("query")
    placeholder_reply = "This is a medical task. Further implementation needed."
    return {
        "messages": [
            {"role": "user", "content": user_query},
            {"role": "assistant", "content": placeholder_reply},
        ]
    }
# Graph nodes
# Node-name constants. The keys of the conditional-edge mapping below must
# match the TaskType string values returned by classify_task.
CLASSIFY_TASK = "classify_task"
HANDLE_INQUIRY = "handle_inquiry"
HANDLE_MEDICAL_TASK = "handle_medical_task"
graph_builder.add_node(CLASSIFY_TASK, classify_task)
graph_builder.add_node(HANDLE_INQUIRY, handle_inquiry)
graph_builder.add_node(HANDLE_MEDICAL_TASK, handle_medical_task)
# Topology: START -> classify -> (inquiry | medical task) -> END
graph_builder.add_edge(START, CLASSIFY_TASK)
graph_builder.add_conditional_edges(CLASSIFY_TASK, route, {
    "inquiry": HANDLE_INQUIRY, "medical task": HANDLE_MEDICAL_TASK
})
graph_builder.add_edge(HANDLE_INQUIRY, END)
graph_builder.add_edge(HANDLE_MEDICAL_TASK, END)
compiled_graph = graph_builder.compile()
# Gradio interface function
def gradio_agent(query):
    """Gradio callback: run *query* through the compiled graph and return the reply.

    Streams node outputs and returns the assistant message produced by
    whichever terminal handler node fires. Fixes in this revision: the unused
    `state` local is removed, and the medical-task branch is surfaced (the
    original only checked "handle_inquiry", so medical-task replies were
    silently dropped and "No response generated." was shown instead).
    """
    for chunk in compiled_graph.stream({"query": query}):
        # Each streamed chunk is keyed by the name of the node that just ran.
        for node_name in (HANDLE_INQUIRY, HANDLE_MEDICAL_TASK):
            if node_name in chunk:
                return chunk[node_name]["messages"][-1]["content"]
    return "No response generated."
# Simple single-turn UI: one text box in, the assistant's reply out.
iface = gr.Interface(fn=gradio_agent, inputs="text", outputs="text", title="LangGraph Medical Agent",
                     description="Ask a question, and the agent will classify and answer appropriately.")
if __name__ == "__main__":
    iface.launch()