"""LangGraph agent that classifies a user query as a general inquiry or a
medical task, answers inquiries with Tavily-backed RAG, and serves the whole
thing through a Gradio text interface."""

import re
from enum import Enum
from typing import Any, Dict, List, Optional

import gradio as gr
import torch
from dotenv import load_dotenv
# NOTE(review): this import is unused here and shadowed by local variables
# below; kept in case other code relies on the side effect — confirm and drop.
from langchain.chains.summarize.refine_prompts import prompt_template  # noqa: F401
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFacePipeline
from langchain_tavily import TavilySearch
from langgraph.graph import StateGraph, START, END
from transformers import pipeline
from typing_extensions import TypedDict

load_dotenv()

if torch.cuda.is_available():
    print("✅ CUDA GPU is available!")
else:
    print("❌ CUDA GPU is NOT available.")

# Tavily Search setup (requires TAVILY_API_KEY from the environment / .env)
TavilySearchTool = TavilySearch(
    max_results=5,
    topic="general",
)


class State(TypedDict):
    """Shared state passed between LangGraph nodes."""

    # Raw user query text.
    query: Optional[str]
    # Classification result: "inquiry" or "medical task" (TaskType values).
    task_type: Optional[str]
    # Chat transcript as {"role": ..., "content": ...} dicts.
    messages: List[Dict[str, Any]]


graph_builder = StateGraph(State)

# Local HF text-generation pipeline wrapped for LangChain use.
pipe = pipeline(
    "text-generation",
    model="Qwen/Qwen2.5-Coder-3B-Instruct",
    max_new_tokens=300,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.2,
)
llm = HuggingFacePipeline(pipeline=pipe)


class TaskType(Enum):
    """Closed set of routable task categories."""

    INQUIRY = "inquiry"
    MEDICAL_TASK = "medical task"


# Compiled once: finds "Assistant:" markers (with optional leading whitespace)
# so we can slice off only the model's final turn.
_ASSISTANT_MARKER = re.compile(r"[\s\n]*Assistant:\s*")


def fetch_assistant_response(response: str) -> str:
    """Return the text after the LAST 'Assistant:' marker, or '' if none.

    The HF text-generation pipeline echoes the prompt, so the reply we want
    follows the final marker we planted in the prompt template.
    """
    matches = list(_ASSISTANT_MARKER.finditer(response))
    if not matches:
        return ""
    return response[matches[-1].end():].strip()


def parse_response_as_enum(response: str) -> TaskType:
    """Map free-form classifier output onto a TaskType.

    Anything that does not mention 'medical task' defaults to INQUIRY, so a
    rambling or malformed model reply still routes somewhere sensible.
    """
    if "medical task" in response.strip().lower():
        return TaskType.MEDICAL_TASK
    return TaskType.INQUIRY


def classify_task(state: State) -> Dict[str, str]:
    """Graph node: classify the query and store its TaskType value."""
    user_query = state.get("query", "")
    classify_template = (
        "Classify the user query as either 'inquiry' or 'medical task'. "
        "'Inquiry' is a general question without medical analysis. "
        "'Medical task' involves analyzing or classifying medical data but not giving treatment advice. "
        "Examples:\n"
        "- User query: What is a heart attack? → inquiry\n"
        "- User query: What is a normal cholesterol level? → inquiry\n"
        "- User query: Classify these ECG readings: ... → medical task\n"
        "- User query: Check if these blood pressure readings are normal: 120/80 → medical task\n"
        "Only reply with 'inquiry' or 'medical task' and nothing else. "
        "User query: {query}\n\nAssistant:"
    )
    prompt = PromptTemplate.from_template(classify_template)
    # Plain functions in the chain are coerced to RunnableLambda by LCEL.
    llm_chain = prompt | llm | fetch_assistant_response | parse_response_as_enum
    task = llm_chain.invoke({"query": user_query})
    return {"task_type": task.value}


def route(state: State) -> Optional[str]:
    """Conditional-edge selector: returns the task_type string key."""
    return state.get("task_type")


def format_results(search_results: Dict[str, Any]) -> Dict[str, str]:
    """Reduce a Tavily response to {'query', 'context'} for the RAG prompt.

    Keeps only results scoring above 0.5; their contents are joined with
    blank lines. An empty context string is possible if nothing qualifies.
    """
    snippets = [
        result["content"]
        for result in search_results["results"]
        if result["score"] > 0.5
    ]
    return {
        "query": search_results["query"],
        "context": "\n\n".join(snippets),
    }


def handle_inquiry(state: State) -> Dict[str, List[Dict[str, Any]]]:
    """Graph node: answer a general inquiry with web-search RAG."""
    answer_template = (
        "Given the following context:{context}"
        "Answer the user's query:{query}"
        "\n\nAssistant:"
    )
    prompt = PromptTemplate.from_template(answer_template)
    rag_chain = (
        TavilySearchTool
        | format_results
        | prompt
        | llm
        | StrOutputParser()
        | fetch_assistant_response
    )
    assistant_response = rag_chain.invoke({"query": state.get("query")})
    return {
        "messages": [
            {"role": "user", "content": state.get("query")},
            {"role": "assistant", "content": assistant_response},
        ]
    }


def handle_medical_task(state: State) -> Dict[str, List[Dict[str, Any]]]:
    """Graph node: placeholder branch for medical-data tasks."""
    return {
        "messages": [
            {"role": "user", "content": state.get("query")},
            {
                "role": "assistant",
                "content": "This is a medical task. Further implementation needed.",
            },
        ]
    }


# Graph nodes
CLASSIFY_TASK = "classify_task"
HANDLE_INQUIRY = "handle_inquiry"
HANDLE_MEDICAL_TASK = "handle_medical_task"

graph_builder.add_node(CLASSIFY_TASK, classify_task)
graph_builder.add_node(HANDLE_INQUIRY, handle_inquiry)
graph_builder.add_node(HANDLE_MEDICAL_TASK, handle_medical_task)
graph_builder.add_edge(START, CLASSIFY_TASK)
graph_builder.add_conditional_edges(CLASSIFY_TASK, route, {
    "inquiry": HANDLE_INQUIRY,
    "medical task": HANDLE_MEDICAL_TASK,
})
graph_builder.add_edge(HANDLE_INQUIRY, END)
graph_builder.add_edge(HANDLE_MEDICAL_TASK, END)

compiled_graph = graph_builder.compile()


def gradio_agent(query: str) -> str:
    """Gradio callback: run the graph and surface the assistant's reply.

    Fix: the original only inspected the 'handle_inquiry' chunk, so queries
    routed to the medical-task branch always fell through to the fallback
    message. Both terminal nodes are now checked.
    """
    for chunk in compiled_graph.stream({"query": query}):
        for node in (HANDLE_INQUIRY, HANDLE_MEDICAL_TASK):
            if node in chunk:
                return chunk[node]["messages"][-1]["content"]
    return "No response generated."


iface = gr.Interface(
    fn=gradio_agent,
    inputs="text",
    outputs="text",
    title="LangGraph Medical Agent",
    description="Ask a question, and the agent will classify and answer appropriately.",
)

if __name__ == "__main__":
    iface.launch()