Spaces:
Sleeping
Sleeping
| import re | |
| from enum import Enum | |
| from typing import Optional, List, Dict, Any | |
| import torch | |
| from dotenv import load_dotenv | |
| from langchain.chains.summarize.refine_prompts import prompt_template | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_huggingface import HuggingFacePipeline | |
| from langchain_tavily import TavilySearch | |
| from langgraph.graph import StateGraph, START, END | |
| from transformers import pipeline | |
| from typing_extensions import TypedDict | |
| import gradio as gr | |
# Load variables from a local .env file into the process environment before
# any client below is constructed.
# NOTE(review): presumably this supplies TAVILY_API_KEY for TavilySearch —
# confirm against the deployment's secrets.
load_dotenv()

# Startup diagnostic: this check only prints which kind of device is visible.
if torch.cuda.is_available():
    print("✅ CUDA GPU is available!")
else:
    print("❌ CUDA GPU is NOT available.")

# Tavily Search setup: web-search tool used as the retrieval step of the
# inquiry chain; returns at most 5 results per query.
TavilySearchTool = TavilySearch(
    max_results=5,
    topic="general",
)
# Shared LangGraph state schema.
class State(TypedDict):
    """State dictionary threaded through every node of the agent graph.

    Nodes return partial dicts holding only the keys they update: the
    classifier writes ``task_type`` and the handlers write ``messages``.
    """

    # The user's question as entered in the UI (may be absent/None).
    query: Optional[str]
    # Routing label chosen by the classifier ("inquiry" or "medical task").
    task_type: Optional[str]
    # Transcript of {"role": ..., "content": ...} dicts built by a handler.
    messages: List[Dict[str, Any]]
# Graph setup: builder for the agent graph over the shared State schema.
graph_builder = StateGraph(State)

# HF pipeline.
# Pin the model to the first CUDA device when one exists. Otherwise pass
# None, which is the library default, so transformers picks the device itself.
# Without an explicit device, some transformers releases keep pipelines on CPU.
_pipeline_device = 0 if torch.cuda.is_available() else None
pipe = pipeline(
    "text-generation",
    model="Qwen/Qwen2.5-Coder-3B-Instruct",
    device=_pipeline_device,
    max_new_tokens=300,
    # Sampling settings are shared by the classifier and the answer chain.
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.2,
)
# LangChain wrapper so the pipeline composes with prompts via `|`.
llm = HuggingFacePipeline(pipeline=pipe)
# Labels the classifier can assign to a query.
class TaskType(Enum):
    """Routing categories for a user query.

    Each member's ``value`` is the exact string stored in the state's
    ``task_type`` and matched by the graph's conditional-edge mapping.
    """

    # A general question, answered with web search + the LLM.
    INQUIRY = "inquiry"
    # Analysis/classification of medical data (no treatment advice).
    MEDICAL_TASK = "medical task"
def fetch_assistant_response(response: str) -> str:
    """Extract the assistant's answer from a raw LLM completion.

    The prompts in this module end with an ``"Assistant:"`` cue. The answer
    is whatever follows the *last* occurrence of that marker. The last one
    is used because the completion presumably echoes the prompt, and the
    model may repeat the marker.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        The text after the last ``"Assistant:"`` marker, with surrounding
        whitespace stripped. If the marker is absent, the whole response is
        returned stripped instead of discarding the model output.
    """
    _, marker, answer = response.rpartition("Assistant:")
    if not marker:
        # No echoed prompt to split on: treat the whole completion as the answer.
        return response.strip()
    return answer.strip()
def parse_response_as_enum(response: str):
    """Map the classifier's free-text reply onto a ``TaskType`` member.

    The match is case-insensitive and substring-based. Any reply that
    mentions "medical task" maps to MEDICAL_TASK. Everything else, including
    an empty reply, maps to INQUIRY.
    """
    normalized = response.strip().lower()
    is_medical = "medical task" in normalized
    return TaskType.MEDICAL_TASK if is_medical else TaskType.INQUIRY
def classify_task(state: State):
    """Classify the user's query and record the routing label.

    Sends a few-shot prompt through the LLM, keeps the text after the final
    ``Assistant:`` marker, and maps it to a ``TaskType``.

    Args:
        state: Graph state; only ``state["query"]`` is read.

    Returns:
        A partial state update ``{"task_type": <TaskType value string>}``.
    """
    # `or ""` also covers an explicit None, which would otherwise be
    # rendered into the prompt as the literal text "None".
    user_query = state.get("query") or ""
    # Named so it does not shadow the module-level `prompt_template` import.
    # The example arrows are plain ASCII "->" so the prompt cannot be garbled
    # by a file-encoding mismatch.
    classify_template = (
        "Classify the user query as either 'inquiry' or 'medical task'. "
        "'Inquiry' is a general question without medical analysis. "
        "'Medical task' involves analyzing or classifying medical data but not giving treatment advice. "
        "Examples:\n"
        "- User query: What is a heart attack? -> inquiry\n"
        "- User query: What is a normal cholesterol level? -> inquiry\n"
        "- User query: Classify these ECG readings: ... -> medical task\n"
        "- User query: Check if these blood pressure readings are normal: 120/80 -> medical task\n"
        "Only reply with 'inquiry' or 'medical task' and nothing else. "
        "User query: {query}\n\nAssistant:"
    )
    prompt = PromptTemplate.from_template(classify_template)
    llm_chain = prompt | llm | fetch_assistant_response | parse_response_as_enum
    task_type = llm_chain.invoke({"query": user_query})
    return {"task_type": task_type.value}
def route(state: State):
    """Return the conditional-edge key for ``state``: its ``task_type``.

    Returns None when no ``task_type`` has been recorded.
    """
    return state.get("task_type")
def format_results(
    search_results: Dict[str, Any], *, min_score: float = 0.5
) -> Dict[str, str]:
    """Turn a Tavily search response into the inputs of the inquiry prompt.

    Args:
        search_results: Dict produced by the search tool. ``"query"`` is
            required. ``"results"`` is optional and should be a list of dicts
            carrying ``"content"`` and ``"score"``.
        min_score: Keep only results whose score is strictly greater than
            this value. The default of 0.5 matches the previous fixed cut-off.

    Returns:
        ``{"query": ..., "context": ...}``, where ``context`` is the content
        of the kept results joined by blank lines ("" when none qualify).

    Raises:
        KeyError: If ``search_results`` has no ``"query"`` key.
    """
    # A missing or None "results" means "no context", not a crash mid-chain.
    results = search_results.get("results") or []
    relevant = [
        result["content"]
        for result in results
        # A missing/None score counts as 0.0. Empty snippets would only add
        # stray blank separators to the context.
        if (result.get("score") or 0.0) > min_score and result.get("content")
    ]
    return {"query": search_results["query"], "context": "\n\n".join(relevant)}
def handle_inquiry(state: State):
    """Answer a general question with web-search-augmented generation.

    Chain: Tavily search -> ``format_results`` (keeps sufficiently scored
    snippets) -> prompt -> LLM -> string -> text after the last
    ``Assistant:`` marker.

    Args:
        state: Graph state; ``state["query"]`` is the question.

    Returns:
        A partial state update whose ``messages`` holds the user turn followed
        by the assistant's answer.
    """
    user_query = state.get("query")
    # Explicit newlines keep the retrieved context, the question and the
    # answer cue apart. Adjacent string literals with no separator would fuse
    # the context and the question into one run-on line.
    inquiry_template = (
        "Given the following context:\n{context}\n\n"
        "Answer the user's query: {query}"
        "\n\nAssistant:"
    )
    prompt = PromptTemplate.from_template(inquiry_template)
    rag_chain = (
        TavilySearchTool
        | format_results
        | prompt
        | llm
        | StrOutputParser()
        | fetch_assistant_response
    )
    assistant_response = rag_chain.invoke({"query": user_query})
    return {"messages": [{"role": "user", "content": user_query},
                         {"role": "assistant", "content": assistant_response}]}
def handle_medical_task(state: State):
    """Placeholder handler for queries classified as medical tasks.

    Records the user's query in the transcript and replies with a fixed stub
    message. No medical analysis is performed yet.
    """
    user_turn = {"role": "user", "content": state.get("query")}
    stub_reply = {
        "role": "assistant",
        "content": "This is a medical task. Further implementation needed.",
    }
    return {"messages": [user_turn, stub_reply]}
# Graph node names.
CLASSIFY_TASK = "classify_task"
HANDLE_INQUIRY = "handle_inquiry"
HANDLE_MEDICAL_TASK = "handle_medical_task"

# Register each node function under its name.
for _node_name, _node_fn in (
    (CLASSIFY_TASK, classify_task),
    (HANDLE_INQUIRY, handle_inquiry),
    (HANDLE_MEDICAL_TASK, handle_medical_task),
):
    graph_builder.add_node(_node_name, _node_fn)

# START -> classifier -> (handler picked by `route`) -> END.
graph_builder.add_edge(START, CLASSIFY_TASK)
graph_builder.add_conditional_edges(
    CLASSIFY_TASK,
    route,
    {
        # Keys are the TaskType values that classify_task stores in the state.
        TaskType.INQUIRY.value: HANDLE_INQUIRY,
        TaskType.MEDICAL_TASK.value: HANDLE_MEDICAL_TASK,
    },
)
for _handler_name in (HANDLE_INQUIRY, HANDLE_MEDICAL_TASK):
    graph_builder.add_edge(_handler_name, END)

# Drop the loop temporaries so they do not linger as module globals.
del _node_name, _node_fn, _handler_name

compiled_graph = graph_builder.compile()
# Gradio interface function
def gradio_agent(query):
    """Run the agent graph on one query and return the assistant's reply.

    The graph is streamed node by node. The first node update that carries
    a non-empty ``messages`` list supplies the answer. Either handler
    (inquiry or medical task) can produce it, so both routes reach the UI.

    Args:
        query: Text entered in the Gradio input box.

    Returns:
        The ``content`` of the last message from the first node update that
        has messages, or ``"No response generated."`` if no node produced any.
    """
    state = {"query": query, "messages": []}
    for chunk in compiled_graph.stream(state):
        # Each chunk maps a node name to the partial state that node returned.
        for update in chunk.values():
            if not isinstance(update, dict):
                continue
            messages = update.get("messages")
            if messages:
                return messages[-1]["content"]
    return "No response generated."
# Web UI: one text box in, one text box out, backed by gradio_agent.
iface = gr.Interface(
    fn=gradio_agent,
    inputs="text",
    outputs="text",
    title="LangGraph Medical Agent",
    description="Ask a question, and the agent will classify and answer appropriately.",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    iface.launch()