Spaces:
Running
Running
| # workflow.py | |
| import time | |
| from datetime import datetime | |
| from typing import Dict, Any, Sequence | |
| from langchain_core.messages import AIMessage, HumanMessage, ToolMessage | |
| from langgraph.graph import END, StateGraph | |
| from langgraph.graph.message import add_messages | |
| from typing_extensions import TypedDict, Annotated | |
| from processor import EnhancedCognitiveProcessor | |
| from config import ResearchConfig | |
| import logging | |
| logger = logging.getLogger(__name__) | |
# Define the state schema
class AgentState(TypedDict):
    """Shared state dict passed between LangGraph workflow nodes."""
    # Conversation history; `add_messages` merges by appending new messages
    # rather than replacing the list on each node's return.
    messages: Annotated[Sequence[AIMessage | HumanMessage | ToolMessage], add_messages]
    # Workflow context: raw query, domain, retrieved documents, refinement bookkeeping.
    context: Dict[str, Any]
    # Run metadata (e.g. ingestion timestamp, error status).
    metadata: Dict[str, Any]
class ResearchWorkflow:
    """
    A multi-step research workflow employing Retrieval-Augmented Generation (RAG)
    with an additional verification step.

    This workflow supports multiple domains (e.g., Biomedical, Legal, Environmental,
    Competitive Programming, Social Sciences) and integrates domain-specific prompts,
    iterative refinement, and a final verification to reduce hallucinations.

    Graph topology:
        ingest -> retrieve -> analyze -> {valid: validate, invalid: refine}
        validate -> verify -> enhance -> END
        refine -> retrieve   (loop, capped at MAX_REFINEMENTS iterations)
    """

    # Cap on refinement iterations (was hard-coded as 3 in three places).
    MAX_REFINEMENTS = 3

    def __init__(self) -> None:
        """Build and compile the LangGraph state machine."""
        self.processor = EnhancedCognitiveProcessor()
        self.workflow = StateGraph(AgentState)
        self._build_workflow()
        self.app = self.workflow.compile()

    def _build_workflow(self) -> None:
        """Register all nodes and edges of the research pipeline."""
        self.workflow.add_node("ingest", self.ingest_query)
        self.workflow.add_node("retrieve", self.retrieve_documents)
        self.workflow.add_node("analyze", self.analyze_content)
        self.workflow.add_node("validate", self.validate_output)
        self.workflow.add_node("refine", self.refine_results)
        # Verify node cross-checks the validated output to catch hallucinations.
        self.workflow.add_node("verify", self.verify_output)
        # Extended node for multi-modal enhancement (images / code artifacts).
        self.workflow.add_node("enhance", self.enhance_analysis)
        self.workflow.set_entry_point("ingest")
        self.workflow.add_edge("ingest", "retrieve")
        self.workflow.add_edge("retrieve", "analyze")
        self.workflow.add_conditional_edges(
            "analyze",
            self._quality_check,
            {"valid": "validate", "invalid": "refine"}
        )
        self.workflow.add_edge("validate", "verify")
        self.workflow.add_edge("refine", "retrieve")
        self.workflow.add_edge("verify", "enhance")
        self.workflow.add_edge("enhance", END)

    @staticmethod
    def _response_text(response: Dict) -> str:
        """Safely extract the assistant text from an OpenAI-style response dict.

        Returns '' when the expected choices/message/content path is absent.
        """
        return response.get('choices', [{}])[0].get('message', {}).get('content', '')

    def ingest_query(self, state: Dict) -> Dict:
        """Normalize the incoming query and domain and seed the workflow context.

        Returns a state delta with an acknowledgement message, the initialized
        context, and an ingestion timestamp in metadata.
        """
        try:
            query = state["messages"][-1].content
            # Normalize the domain string; default to 'biomedical research'.
            domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
            # FIX: previously the context was replaced wholesale, which dropped
            # caller-supplied keys such as 'images'/'code' that enhance_analysis
            # reads later. Copy the incoming context and layer bookkeeping on top.
            new_context = dict(state.get("context", {}))
            new_context.update({
                "raw_query": query,
                "domain": domain,
                "refine_count": 0,
                "refinement_history": [],
            })
            logger.info(f"Query ingested. Domain: {domain}")
            return {
                "messages": [AIMessage(content="Query ingested successfully")],
                "context": new_context,
                "metadata": {"timestamp": datetime.now().isoformat()}
            }
        except Exception as e:
            logger.exception("Error during query ingestion.")
            return self._error_state(f"Ingestion Error: {str(e)}")

    def retrieve_documents(self, state: Dict) -> Dict:
        """Retrieve documents for the raw query.

        Placeholder implementation: always returns an empty document list,
        which pushes analyze_content into its RAG fallback mode.
        """
        try:
            query = state["context"]["raw_query"]
            # Placeholder retrieval: currently returns an empty list (simulate no documents).
            docs = []
            logger.info(f"Retrieved {len(docs)} documents for query.")
            # FIX: carry the full existing context forward. The old code rebuilt
            # the dict field-by-field and silently dropped 'raw_query', which
            # broke the no-documents fallback in analyze_content.
            new_context = dict(state["context"])
            new_context.update({
                "documents": docs,
                "retrieval_time": time.time(),
            })
            return {
                "messages": [AIMessage(content=f"Retrieved {len(docs)} documents")],
                "context": new_context
            }
        except Exception as e:
            logger.exception("Error during document retrieval.")
            return self._error_state(f"Retrieval Error: {str(e)}")

    def analyze_content(self, state: Dict) -> Dict:
        """Analyze retrieved documents (or the raw query if none) with a domain prompt."""
        try:
            domain = state["context"].get("domain", "biomedical research").strip().lower()
            docs = state["context"].get("documents", [])
            if docs:
                docs_text = "\n\n".join([d.page_content for d in docs])
            else:
                # RAG fallback: with no retrieved documents, synthesize directly
                # from the original query text.
                docs_text = state["context"].get("raw_query", "")
                logger.info("No documents retrieved; switching to dynamic synthesis (RAG mode).")
            # Use the domain-specific prompt; fall back to a generic instruction.
            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(
                domain, "Provide an analysis based on the provided context.")
            full_prompt = (
                f"Domain: {state['context'].get('domain', 'Biomedical Research')}\n"
                f"{domain_prompt}\n\n"
                + ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
            )
            response = self.processor.process_query(full_prompt)
            if "error" in response:
                logger.error("Backend response error during analysis.")
                return self._error_state(response["error"])
            logger.info("Content analysis completed using RAG approach.")
            return {
                "messages": [AIMessage(content=self._response_text(response))],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during content analysis.")
            return self._error_state(f"Analysis Error: {str(e)}")

    def validate_output(self, state: Dict) -> Dict:
        """Ask the backend to validate the latest analysis; append its verdict.

        The appended 'VALID: ...' / 'INVALID: ...' verdict is what
        _quality_check inspects on the next conditional edge.
        """
        try:
            analysis = state["messages"][-1].content
            validation_prompt = (
                f"Validate the following analysis for accuracy and domain-specific relevance:\n{analysis}\n\n"
                "Criteria:\n"
                "1. Factual and technical accuracy\n"
                "2. For legal research: inclusion of relevant precedents and statutory interpretations; "
                "for other domains: appropriate domain insights\n"
                "3. Logical consistency\n"
                "4. Methodological soundness\n\n"
                "Respond with 'VALID: [justification]' or 'INVALID: [justification]'."
            )
            response = self.processor.process_query(validation_prompt)
            logger.info("Output validation completed.")
            return {
                "messages": [AIMessage(content=analysis + f"\n\nValidation: {self._response_text(response)}")],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during output validation.")
            return self._error_state(f"Validation Error: {str(e)}")

    def verify_output(self, state: Dict) -> Dict:
        """Cross-check the validated analysis against external references.

        Appends the verification feedback to the analysis so downstream steps
        (and the final answer) carry the fact-check trail.
        """
        try:
            analysis = state["messages"][-1].content
            domain = state["context"].get("domain", "biomedical research")
            # FIX: the prompt previously referenced "legal databases" for every
            # domain; it now references the active domain instead.
            verification_prompt = (
                f"Verify the following analysis by comparing it with established "
                f"external {domain} databases and reference texts:\n{analysis}\n\n"
                "Identify any discrepancies or hallucinations and provide a brief correction if necessary."
            )
            response = self.processor.process_query(verification_prompt)
            logger.info("Output verification completed.")
            # Merge the verification feedback with the analysis.
            verified_analysis = analysis + "\n\nVerification Feedback: " + self._response_text(response)
            return {
                "messages": [AIMessage(content=verified_analysis)],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during output verification.")
            return self._error_state(f"Verification Error: {str(e)}")

    def refine_results(self, state: Dict) -> Dict:
        """Iteratively refine the analysis, or meta-synthesize after the cap.

        Each pass appends the current analysis to 'refinement_history'. Once
        MAX_REFINEMENTS passes have run, the history is collapsed into a final
        synthesis instead of refining again (the quality gate then forces
        'valid', ending the loop).
        """
        try:
            current_count = state["context"].get("refine_count", 0)
            state["context"]["refine_count"] = current_count + 1
            refinement_history = state["context"].setdefault("refinement_history", [])
            current_analysis = state["messages"][-1].content
            refinement_history.append(current_analysis)
            # Difficulty decreases toward 0 as refinement iterations accumulate.
            difficulty_level = max(0, self.MAX_REFINEMENTS - state["context"]["refine_count"])
            domain = state["context"].get("domain", "biomedical research")
            logger.info(f"Refinement iteration: {state['context']['refine_count']}, Difficulty level: {difficulty_level}")
            if state["context"]["refine_count"] >= self.MAX_REFINEMENTS:
                # Cap reached: synthesize the whole refinement history once.
                meta_prompt = (
                    f"Domain: {domain}\n"
                    "You are given the following series of refinement outputs:\n" +
                    "\n---\n".join(refinement_history) +
                    "\n\nSynthesize these into a final, concise analysis report with improved accuracy and verifiable details."
                )
                meta_response = self.processor.process_query(meta_prompt)
                logger.info("Meta-refinement completed.")
                return {
                    "messages": [AIMessage(content=self._response_text(meta_response))],
                    "context": state["context"]
                }
            else:
                refinement_prompt = (
                    f"Domain: {domain}\n"
                    f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
                    "Identify and correct any weaknesses or hallucinations in the analysis, providing verifiable details."
                )
                response = self.processor.process_query(refinement_prompt)
                logger.info("Refinement completed.")
                return {
                    "messages": [AIMessage(content=self._response_text(response))],
                    "context": state["context"]
                }
        except Exception as e:
            logger.exception("Error during refinement.")
            return self._error_state(f"Refinement Error: {str(e)}")

    def _quality_check(self, state: Dict) -> str:
        """Route the analyze node: 'valid' -> validate, 'invalid' -> refine."""
        refine_count = state["context"].get("refine_count", 0)
        if refine_count >= self.MAX_REFINEMENTS:
            logger.warning("Refinement limit reached. Forcing valid outcome.")
            return "valid"
        content = state["messages"][-1].content
        # FIX: check 'INVALID' first — the old `"VALID" in content` test also
        # matched the substring of 'INVALID', so invalid verdicts were routed
        # to the valid branch and the refine loop never triggered.
        if "INVALID" in content:
            quality = "invalid"
        elif "VALID" in content:
            quality = "valid"
        else:
            quality = "invalid"
        logger.info(f"Quality check returned: {quality}")
        return quality

    def _error_state(self, message: str) -> Dict:
        """Log an error and return a terminal error state delta."""
        logger.error(message)
        return {
            "messages": [AIMessage(content=f"❌ {message}")],
            "context": {"error": True},
            "metadata": {"status": "error"}
        }

    def enhance_analysis(self, state: Dict) -> Dict:
        """Append multi-modal sections (images, code) from context to the analysis."""
        try:
            analysis = state["messages"][-1].content
            enhanced = f"{analysis}\n\n## Multi-Modal Insights\n"
            if "images" in state["context"]:
                enhanced += "### Visual Evidence\n"
                for img in state["context"]["images"]:
                    # FIX: the original f-string (`f"\n"`) had no placeholder and
                    # emitted nothing; render each image as a markdown image link.
                    enhanced += f"![Visual evidence]({img})\n"
            if "code" in state["context"]:
                enhanced += "### Code Artifacts\n```python\n"
                for code in state["context"]["code"]:
                    enhanced += f"{code}\n"
                enhanced += "```"
            return {
                "messages": [AIMessage(content=enhanced)],
                "context": state["context"]
            }
        except Exception as e:
            logger.exception("Error during multi-modal enhancement.")
            return self._error_state(f"Enhancement Error: {str(e)}")