Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Dict, List, Any, Optional | |
| from langchain_anthropic import ChatAnthropic | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | |
| from typing import TypedDict, List, Optional, Union | |
| import copy | |
| from tools.retriever_tool import DocumentRetriever | |
| from tools.search_tool import WebSearchTool | |
| from tools.pdf_tool import PDFProcessor | |
class AgentState(TypedDict):
    """Shared state dictionary handed from node to node in the agent graph."""

    # Conversation so far; the response node appends one human message and
    # one AI message per completed turn.
    messages: List[Union[AIMessage, HumanMessage]]
    # The user question being handled in the current turn.
    query: str
    # Written by the document-search node.
    csv_results: Union[str, None]
    # Written by the web-search node.
    web_results: Union[str, None]
    # Written by the PDF-search node.
    pdf_results: Union[str, None]
    # Final answer text written by the response node.
    response: Union[str, None]
class MedTranscriptAgent:
    """Conversational medical QA agent driven by a LangGraph workflow.

    Each call to :meth:`chat` runs one pass through the graph: the router
    asks the LLM which sources are relevant, only the selected search nodes
    run, and a final node asks the LLM to answer from the gathered results
    and the conversation history.
    """

    # Names of the search nodes the router may fan out to.
    _SEARCH_NODES = ("document_search", "web_search", "pdf_search")

    def __init__(self, anthropic_api_key: Optional[str] = None, debug: bool = False):
        """Create the LLM client, the three search tools and the graph.

        Args:
            anthropic_api_key: Anthropic API key. When omitted, the
                ``ANTHROPIC_API_KEY`` environment variable is used.
            debug: When true, progress messages are printed to stdout.

        Raises:
            ValueError: If no API key is passed and none is set in the
                environment.
        """
        self.api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
        if not self.api_key:
            raise ValueError("Anthropic API key is required")
        self.llm = ChatAnthropic(
            model="claude-3-7-sonnet-20250219",
            anthropic_api_key=self.api_key,
            temperature=0.1
        )
        self.doc_retriever = DocumentRetriever()
        self.web_search = WebSearchTool(debug=debug)
        self.pdf_processor = PDFProcessor()
        self.debug = debug
        self.memory_store = MemorySaver()
        # thread_id -> list of messages exchanged so far in that thread.
        self.conversation_threads = {}
        self.graph = self._build_graph()

    def _build_graph(self) -> Any:
        """Build and compile the LangGraph workflow for conversational QA.

        Returns:
            The compiled graph (the result of ``StateGraph.compile``), with
            ``self.memory_store`` attached as checkpointer.
        """
        workflow = StateGraph(AgentState)
        workflow.add_node("query_router", self._reset_turn_results)
        workflow.add_node("document_search", self._perform_doc_search)
        workflow.add_node("web_search", self._perform_web_search)
        workflow.add_node("pdf_search", self._perform_pdf_search)
        workflow.add_node("combine_results", self._generate_response)

        workflow.add_edge(START, "query_router")
        # Conditional fan-out: only the nodes named by the router run.
        # Unconditional edges here would run every search on every turn and
        # silently discard the routing decision.
        workflow.add_conditional_edges(
            "query_router",
            self._select_search_nodes,
            list(self._SEARCH_NODES),
        )
        # One edge per search node (not a joined list edge) so the response
        # node does not wait for branches that were never selected.
        for search_node in self._SEARCH_NODES:
            workflow.add_edge(search_node, "combine_results")
        workflow.add_edge("combine_results", END)
        return workflow.compile(checkpointer=self.memory_store)

    def _reset_turn_results(self, state: "AgentState") -> Dict[str, Any]:
        """Clear the previous turn's search results at the start of a turn.

        The checkpointer keeps state per thread, so a search branch skipped
        this turn would otherwise still expose the results it stored during
        an earlier turn.

        Args:
            state: Current graph state (not read).

        Returns:
            State update setting all three result fields to ``None``.
        """
        return {"csv_results": None, "web_results": None, "pdf_results": None}

    def _select_search_nodes(self, state: "AgentState") -> List[str]:
        """Return the names of the search nodes to run for this turn.

        Used as the path function of the conditional edges leaving the
        ``query_router`` node.
        """
        return self._route_query(state)["next"]

    def _route_query(self, state: "AgentState") -> Dict[str, Any]:
        """Ask the LLM which tool(s) should handle the query.

        Args:
            state: Graph state; ``query`` is required, ``messages`` optional.

        Returns:
            ``{"next": [...]}`` listing the selected search node names. When
            the reply mentions none of "document", "web" or "pdf", all three
            nodes are selected.
        """
        query = state["query"]
        messages = state.get("messages", [])
        if self.debug:
            print(f"[Router] Processing query with {len(messages)} existing messages")
        conversation_history = self._format_conversation_history(messages)
        routing_prompt = f"""
        You are a medical query router. Your job is to determine whether a query about medical topics should be:
        1. Answered using document search (for specific patient data or medical transcript information)
        2. Answered using web search (for general medical knowledge)
        3. Answered using PDF search (for detailed medical protocol or research documents)
        Consider the conversation history and the current query when making your decision.
        Conversation history:
        {conversation_history}
        Current query: {query}
        Respond with one or more of: "document", "web", "pdf"
        """
        route = self.llm.invoke(routing_prompt).content.strip().lower()
        if self.debug:
            print(f"[Router] Decision: {route}")
        next_steps = []
        if "document" in route:
            next_steps.append("document_search")
        if "web" in route:
            next_steps.append("web_search")
        if "pdf" in route:
            next_steps.append("pdf_search")
        if not next_steps:
            # Unrecognised reply: fall back to consulting every source.
            next_steps = ["document_search", "web_search", "pdf_search"]
        return {"next": next_steps}

    def _perform_doc_search(self, state: "AgentState") -> Dict[str, Any]:
        """Query the document retriever and store the result as ``csv_results``."""
        query = state["query"]
        if self.debug:
            print(f"[Document Search] Searching for: {query}")
        results = self.doc_retriever.query(query)
        return {"csv_results": results}

    def _perform_web_search(self, state: "AgentState") -> Dict[str, Any]:
        """Run the web search tool and store the result as ``web_results``."""
        query = state["query"]
        if self.debug:
            print(f"[Web Search] Searching for: {query}")
        results = self.web_search.search(query)
        return {"web_results": results}

    def _perform_pdf_search(self, state: "AgentState") -> Dict[str, Any]:
        """Search the PDF processor and store the result as ``pdf_results``."""
        query = state["query"]
        if self.debug:
            print(f"[PDF Search] Searching for: {query}")
        results = self.pdf_processor.search(query)
        return {"pdf_results": results}

    def _generate_response(self, state: "AgentState") -> Dict[str, Any]:
        """Generate the answer from search results and conversation history.

        Args:
            state: Graph state; ``query`` is required, everything else is
                optional.

        Returns:
            State update with the answer text under ``response`` and the
            history extended by this turn's human/AI pair under ``messages``.
        """
        query = state["query"]
        messages = state.get("messages", [])
        if self.debug:
            print(f"[Generate Response] Processing with {len(messages)} messages in history")
        # `or` rather than a .get() default: a key that is present but None
        # (branch skipped this turn) or empty must also use the placeholder.
        csv_results = state.get("csv_results") or "No document results available"
        web_results = state.get("web_results") or "No web results available"
        pdf_results = state.get("pdf_results") or "No PDF results available"
        conversation_history = self._format_conversation_history(messages)
        response_prompt = f"""
        You are a helpful medical assistant answering questions about medical transcripts and general medical knowledge.
        You have access to three types of information sources: medical transcripts (CSV), web search results, and PDF documents.
        Conversation history:
        {conversation_history}
        Current query: {query}
        Document search results: {csv_results}
        Web search results: {web_results}
        PDF search results: {pdf_results}
        Based on all available information and your medical knowledge, provide a helpful, accurate, and compassionate response to the query.
        Make sure to consider the conversation history for context and continuity.
        When citing information, clearly indicate the source (Document, Web, or PDF).
        """
        response = self.llm.invoke(response_prompt).content
        updated_messages = messages + [
            HumanMessage(content=query),
            AIMessage(content=response)
        ]
        if self.debug:
            print(f"[Generate Response] History now has {len(updated_messages)} messages")
        return {
            "response": response,
            "messages": updated_messages
        }

    def _format_conversation_history(self, messages: List) -> str:
        """Format conversation history for inclusion in prompts.

        Messages are read in consecutive pairs: the first of each pair is
        rendered as the human turn and the second as the AI turn. A trailing
        unpaired message gets an empty AI line.

        Args:
            messages: Objects exposing a ``content`` attribute.

        Returns:
            The formatted history, or ``"No previous conversation"`` when
            ``messages`` is empty.
        """
        if not messages:
            return "No previous conversation"
        formatted = []
        for start in range(0, len(messages), 2):
            user_msg = messages[start].content
            ai_msg = messages[start + 1].content if start + 1 < len(messages) else ""
            formatted.append(f"Human: {user_msg}\nAI: {ai_msg}")
        return "\n\n".join(formatted)

    def load_pdf(self, file_path: str) -> str:
        """Load a PDF document into the agent via the PDF processor.

        Args:
            file_path: Path of the PDF file to load.

        Returns:
            Whatever ``PDFProcessor.load_pdf`` returns for the file.
        """
        return self.pdf_processor.load_pdf(file_path)

    def chat(self, message: str, thread_id: str = "default") -> str:
        """Process one user message within a conversation thread.

        Args:
            message: The user's message.
            thread_id: Conversation identifier; history is kept per thread.

        Returns:
            The generated answer. If processing raises, the error is printed
            and an ``"Error processing message: ..."`` string is returned
            instead of propagating the exception.
        """
        if thread_id in self.conversation_threads:
            messages = self.conversation_threads[thread_id]
            if self.debug:
                print(f"[Chat] Retrieved {len(messages)} messages for thread {thread_id}")
        else:
            messages = []
            if self.debug:
                print(f"[Chat] Started new conversation thread {thread_id}")
        # Deep copy so the graph run cannot mutate the stored history.
        state = {
            "query": message,
            "messages": copy.deepcopy(messages)
        }
        if self.debug:
            print(f"[Chat] Processing query with initial state containing {len(state['messages'])} messages")
        try:
            result = self.graph.invoke(
                state,
                config={"configurable": {"thread_id": thread_id}}
            )
            updated_messages = result.get("messages", [])
            self.conversation_threads[thread_id] = copy.deepcopy(updated_messages)
            if self.debug:
                print(f"[Chat] Updated thread {thread_id} with {len(updated_messages)} messages")
            return result["response"]
        except Exception as e:
            # Deliberate best-effort boundary: report the failure as text.
            error_msg = f"Error processing message: {str(e)}"
            print(f"[ERROR] {error_msg}")
            return error_msg