Spaces:
Sleeping
Sleeping
feat: Implement core gazette parsing, data extraction, and foundational AI auditor system components.
f0c339c | # LangGraph Chat Agent for DocuTrace AI Auditor | |
| """ | |
| A LangGraph-based agent that uses database query tools to answer | |
| high-precision questions about Kuwait Gazette data. | |
| """ | |
| import os | |
| import json | |
| import logging | |
| from typing import Optional, List, Dict, Any, Literal, Annotated | |
| from datetime import datetime | |
| from pydantic import BaseModel, Field | |
| from langchain_openai import ChatOpenAI | |
| from langchain.tools import tool | |
| from langgraph.graph import StateGraph, MessagesState, START, END | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langgraph.checkpoint.memory import InMemorySaver | |
| from sqlalchemy.orm import Session | |
| from sqlalchemy import func | |
| from config import get_available_models, get_default_model, ModelProvider | |
| from database import get_db_session | |
| from models import ( | |
| GazetteIssue, Tender, CommercialAgency, GeneralAssembly, | |
| BankruptcyCase, Decree, CompanyAnnouncement | |
| ) | |
# Module-import side effect: configure the root logger at INFO level
# (logging.basicConfig does nothing if the root logger already has handlers),
# then create the logger used by the functions in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| # ============================================================================ | |
| # TOOL SCHEMAS (Pydantic models for structured input) | |
| # ============================================================================ | |
# Field names mirror the parameters of count_entities().
# NOTE(review): no code in the visible part of this file references this
# schema — confirm whether it is wired to the tool elsewhere.
class CountEntitiesInput(BaseModel):
    """Input schema for counting entities."""
    # Entity kind; the allowed literals are the same strings used as keys in
    # get_model_map().
    entity_type: Literal["tender", "agency", "decree", "assembly", "bankruptcy", "announcement"] = Field(
        description="Type of entity to count"
    )
    # Optional filters; both default to None.
    ministry: Optional[str] = Field(default=None, description="Filter by ministry (for tenders/decrees)")
    country: Optional[str] = Field(default=None, description="Filter by country (for agencies)")
# Field names mirror the parameters of list_entities().
# NOTE(review): no code in the visible part of this file references this
# schema — confirm whether it is wired to the tool elsewhere.
class ListEntitiesInput(BaseModel):
    """Input schema for listing entities."""
    # Entity kind; the allowed literals are the same strings used as keys in
    # get_model_map().
    entity_type: Literal["tender", "agency", "decree", "assembly", "bankruptcy", "announcement"] = Field(
        description="Type of entity to list"
    )
    # Row cap; defaults to 10. No bounds are declared on this field.
    limit: int = Field(default=10, description="Maximum number of results to return")
    # Optional filters; all default to None.
    ministry: Optional[str] = Field(default=None, description="Filter by ministry")
    country: Optional[str] = Field(default=None, description="Filter by country (for agencies)")
    search_term: Optional[str] = Field(default=None, description="Search keyword in text/subject")
# Field names mirror the parameters of search_tenders().
# NOTE(review): no code in the visible part of this file references this
# schema — confirm whether it is wired to the tool elsewhere.
class SearchTendersInput(BaseModel):
    """Input schema for searching tenders."""
    # Optional filters; all default to None.
    keyword: Optional[str] = Field(default=None, description="Keyword to search in tender subject")
    ministry: Optional[str] = Field(default=None, description="Filter by ministry name")
    # Tri-state: None leaves postponement unfiltered, True/False select one side.
    is_postponed: Optional[bool] = Field(default=None, description="Filter by postponement status")
    # Row cap; defaults to 10.
    limit: int = Field(default=10, description="Maximum results to return")
# Field names mirror the parameters of search_agencies().
# NOTE(review): no code in the visible part of this file references this
# schema — confirm whether it is wired to the tool elsewhere.
class SearchAgenciesInput(BaseModel):
    """Input schema for searching commercial agencies."""
    # Optional filters; all default to None.
    country: Optional[str] = Field(default=None, description="Filter by principal's country")
    agent_name: Optional[str] = Field(default=None, description="Search by agent company name")
    activity: Optional[str] = Field(default=None, description="Search by activity description")
    # Row cap; defaults to 10.
    limit: int = Field(default=10, description="Maximum results to return")
| # ============================================================================ | |
| # DATABASE TOOLS | |
| # ============================================================================ | |
def get_model_map():
    """Return a fresh dict mapping each entity-type keyword to its ORM model class."""
    entity_models = dict(
        tender=Tender,
        agency=CommercialAgency,
        decree=Decree,
        assembly=GeneralAssembly,
        bankruptcy=BankruptcyCase,
        announcement=CompanyAnnouncement,
    )
    return entity_models
def count_entities(
    entity_type: str,
    ministry: Optional[str] = None,
    country: Optional[str] = None
) -> str:
    """
    Count the number of entities of a specific type in the database.
    Use this for questions like "How many tenders are there?" or "Count agencies from Turkey".
    Returns a single count value.
    """
    # NOTE: the docstring above is kept verbatim because, when a plain function
    # is bound as a tool, its docstring is what the LLM sees as the tool
    # description. Implementation notes therefore live in comments.
    #
    # Returns "Unknown entity type: ..." for an unmapped entity_type, otherwise
    # a string starting with "COUNT_RESULT:". The session is always closed.
    model = get_model_map().get(entity_type)
    if not model:
        return f"Unknown entity type: {entity_type}"

    session = get_db_session()
    try:
        query = session.query(func.count(model.id))

        # Track which filters were really applied. A filter naming a column the
        # model does not have is skipped; reporting it as "matching ..." would
        # present an unfiltered total as if it were a filtered one.
        applied = []
        ignored = []
        if ministry:
            if hasattr(model, 'ministry'):
                query = query.filter(model.ministry.ilike(f"%{ministry}%"))
                applied.append(f"ministry containing '{ministry}'")
            else:
                ignored.append("ministry")
        if country:
            if hasattr(model, 'principal_country'):
                query = query.filter(model.principal_country.ilike(f"%{country}%"))
                applied.append(f"country '{country}'")
            else:
                ignored.append("country")

        count = query.scalar()

        if applied:
            message = f"COUNT_RESULT: {count} {entity_type}(s) matching {', '.join(applied)}"
        else:
            message = f"COUNT_RESULT: {count} {entity_type}(s) total in the database"
        if ignored:
            # Tell the caller explicitly that the number is NOT restricted by these.
            message += (
                f" (ignored filter(s) not applicable to {entity_type}: "
                f"{', '.join(ignored)})"
            )
        return message
    finally:
        session.close()
def list_entities(
    entity_type: str,
    limit: int = 10,
    ministry: Optional[str] = None,
    country: Optional[str] = None,
    search_term: Optional[str] = None
) -> str:
    """
    List entities of a specific type from the database.
    Use this for questions like "Show me all tenders" or "List agencies from Germany".
    Returns a JSON table of results.
    """
    # NOTE: the docstring above is kept verbatim because, when a plain function
    # is bound as a tool, its docstring is what the LLM sees as the tool
    # description. Implementation notes therefore live in comments.
    #
    # Returns "Unknown entity type: ..." for an unmapped entity_type, otherwise
    # a string starting with "TABLE_RESULT:" followed by a JSON array or a
    # "no results" sentence. The session is always closed.

    # Local import: the module-level import only brings in `datetime`, while
    # plain `date` / `time` column values need the same ISO conversion.
    from datetime import date, time

    model = get_model_map().get(entity_type)
    if not model:
        return f"Unknown entity type: {entity_type}"

    # Bulky text columns that are never included in the listing.
    skipped_columns = ('full_text', 'raw_content')

    session = get_db_session()
    try:
        query = session.query(model)

        # Filters are applied only when the model has the matching column.
        if ministry and hasattr(model, 'ministry'):
            query = query.filter(model.ministry.ilike(f"%{ministry}%"))
        if country and hasattr(model, 'principal_country'):
            query = query.filter(model.principal_country.ilike(f"%{country}%"))
        if search_term:
            # Search the first free-text column the model offers.
            if hasattr(model, 'subject'):
                query = query.filter(model.subject.ilike(f"%{search_term}%"))
            elif hasattr(model, 'activity_description'):
                query = query.filter(model.activity_description.ilike(f"%{search_term}%"))
            elif hasattr(model, 'title'):
                query = query.filter(model.title.ilike(f"%{search_term}%"))

        results = query.limit(limit).all()

        entities = []
        for row in results:
            entity = {}
            for column in row.__table__.columns:
                if column.name in skipped_columns:
                    continue
                value = getattr(row, column.name)
                # datetime is a subclass of date, so this covers datetime,
                # date and time values alike; previously only datetime was
                # converted and a plain date made json.dumps raise TypeError.
                if isinstance(value, (date, time)):
                    value = value.isoformat()
                entity[column.name] = value
            entities.append(entity)

        if not entities:
            return f"TABLE_RESULT: No {entity_type}(s) found matching your criteria."
        # default=str is a deliberate last resort: this output is text for the
        # LLM/UI, so any remaining non-JSON value is rendered as a string
        # instead of failing the whole tool call.
        return f"TABLE_RESULT: {json.dumps(entities, indent=2, default=str)}"
    finally:
        session.close()
def search_tenders(
    keyword: Optional[str] = None,
    ministry: Optional[str] = None,
    is_postponed: Optional[bool] = None,
    limit: int = 10
) -> str:
    """
    Search for tenders and practices with specific criteria.
    Use this for questions like "Find construction tenders" or "Show postponed tenders".
    Returns a JSON table of matching tenders.
    """
    # Docstring kept verbatim: it serves as the tool description for the LLM.
    # Collect the active conditions first, then apply them in order.
    conditions = []
    if keyword:
        conditions.append(Tender.subject.ilike(f"%{keyword}%"))
    if ministry:
        conditions.append(Tender.ministry.ilike(f"%{ministry}%"))
    if is_postponed is not None:
        # Explicit None check so that False is a usable filter value.
        conditions.append(Tender.is_postponed == is_postponed)

    session = get_db_session()
    try:
        query = session.query(Tender)
        for condition in conditions:
            query = query.filter(condition)
        rows = query.limit(limit).all()

        # Project each row to a compact dict; subject is cut to 100 characters.
        tenders = [
            {
                "id": row.id,
                "tender_number": row.tender_number,
                "tender_type": row.tender_type,
                "subject": row.subject[:100] if row.subject else None,
                "ministry": row.ministry,
                "is_postponed": row.is_postponed,
                "closing_date": row.closing_date.isoformat() if row.closing_date else None
            }
            for row in rows
        ]

        if tenders:
            return f"TABLE_RESULT: {json.dumps(tenders, indent=2)}"
        return "TABLE_RESULT: No tenders found matching your criteria."
    finally:
        session.close()
def search_agencies(
    country: Optional[str] = None,
    agent_name: Optional[str] = None,
    activity: Optional[str] = None,
    limit: int = 10
) -> str:
    """
    Search for commercial agency registrations.
    Use this for questions like "Show agencies from Turkey" or "Find agencies for construction".
    Returns a JSON table of matching agencies.
    """
    # Docstring kept verbatim: it serves as the tool description for the LLM.
    # Collect the active substring conditions first, then apply them in order.
    conditions = []
    if country:
        conditions.append(CommercialAgency.principal_country.ilike(f"%{country}%"))
    if agent_name:
        conditions.append(CommercialAgency.agent_name.ilike(f"%{agent_name}%"))
    if activity:
        conditions.append(CommercialAgency.activity_description.ilike(f"%{activity}%"))

    session = get_db_session()
    try:
        query = session.query(CommercialAgency)
        for condition in conditions:
            query = query.filter(condition)
        rows = query.limit(limit).all()

        # Project each row to a compact dict; activity is cut to 100 characters.
        agencies = [
            {
                "id": row.id,
                "registration_number": row.registration_number,
                "agent_name": row.agent_name,
                "principal_name": row.principal_name,
                "principal_country": row.principal_country,
                "activity": row.activity_description[:100] if row.activity_description else None,
                "start_date": row.start_date.isoformat() if row.start_date else None,
                "end_date": row.end_date.isoformat() if row.end_date else None
            }
            for row in rows
        ]

        if agencies:
            return f"TABLE_RESULT: {json.dumps(agencies, indent=2)}"
        return "TABLE_RESULT: No agencies found matching your criteria."
    finally:
        session.close()
def get_database_stats() -> str:
    """
    Get overall statistics about the gazette database.
    Use this for questions like "What's in the database?" or "Give me an overview".
    Returns a text summary of all entity counts.
    """
    # Docstring kept verbatim: it serves as the tool description for the LLM.
    # (label, model) pairs in the order the counts are queried.
    counted_models = (
        ("gazette_issues", GazetteIssue),
        ("tenders", Tender),
        ("commercial_agencies", CommercialAgency),
        ("decrees", Decree),
        ("general_assemblies", GeneralAssembly),
        ("bankruptcy_cases", BankruptcyCase),
        ("company_announcements", CompanyAnnouncement),
    )

    session = get_db_session()
    try:
        stats = {
            label: session.query(func.count(model.id)).scalar()
            for label, model in counted_models
        }
        # Gazette issues are containers, not entities, so they are left out of the total.
        total = sum(n for label, n in stats.items() if label != "gazette_issues")
        return f"""TEXT_RESULT: Database Statistics:
📰 Gazette Issues: {stats['gazette_issues']}
📋 Tenders & Practices: {stats['tenders']}
🤝 Commercial Agencies: {stats['commercial_agencies']}
📜 Decrees: {stats['decrees']}
🏢 General Assemblies: {stats['general_assemblies']}
⚖️ Bankruptcy Cases: {stats['bankruptcy_cases']}
📢 Company Announcements: {stats['company_announcements']}
━━━━━━━━━━━━━━━━━━━━━
Total Entities: {total}"""
    finally:
        session.close()
| # ============================================================================ | |
| # LANGGRAPH AGENT | |
| # ============================================================================ | |
# All tools available to the agent.
# These are the plain functions defined above; create_agent_graph() passes this
# same list both to llm.bind_tools() and to ToolNode().
ALL_TOOLS = [
    count_entities,
    list_entities,
    search_tenders,
    search_agencies,
    get_database_stats
]
def create_agent_graph(model_config=None):
    """
    Build and compile the tool-calling LangGraph agent.

    Args:
        model_config: Optional model configuration from config.py; when
            omitted, get_default_model() is used.

    Returns:
        Compiled LangGraph graph with an in-memory checkpointer.

    Raises:
        ValueError: If no model configuration is available.
    """
    # Resolve the model configuration, falling back to the project default.
    if model_config is None:
        model_config = get_default_model()
    if model_config is None:
        raise ValueError("No model configured. Set GOOGLE_API_KEY or OPENROUTER_API_KEY")

    # Both providers are reached through the OpenAI-compatible client; only the
    # endpoint (and, for OpenRouter, the extra headers) differ.
    is_gemini = model_config.provider == ModelProvider.GEMINI
    llm_kwargs = {
        "model": model_config.model_id,
        "api_key": model_config.get_api_key(),
        "temperature": 0.1,
        "max_tokens": 2000,
    }
    if is_gemini:
        llm_kwargs["base_url"] = "https://generativelanguage.googleapis.com/v1beta/openai/"
    else:
        # OpenRouter
        llm_kwargs["base_url"] = model_config.base_url
        llm_kwargs["default_headers"] = {
            "HTTP-Referer": "https://doctrace.ai",
            "X-Title": "DocuTrace AI Auditor"
        }
    llm = ChatOpenAI(**llm_kwargs)

    # Expose the database tools to the model.
    llm_with_tools = llm.bind_tools(ALL_TOOLS)

    SYSTEM_PROMPT = """You are Alda, an AI assistant specialized in analyzing Kuwait Al Youm gazette documents.
You help users understand government decrees, tenders, commercial agencies, bankruptcy cases, and company announcements.
IMPORTANT INSTRUCTIONS:
1. Use the provided tools to query the database for accurate information.
2. For counting questions (e.g., "how many"), use the count_entities tool.
3. For listing questions (e.g., "show me", "list"), use list_entities or specific search tools.
4. For overview questions, use get_database_stats tool.
5. Always call a tool when the question is about data in the gazette.
6. Provide clear, concise answers based on the tool results.
7. If a tool returns TABLE_RESULT, format the data nicely in your response.
8. If a tool returns COUNT_RESULT, present the number clearly.
9. If a tool returns TEXT_RESULT, present the summary directly.
Available entity types: tender, agency, decree, assembly, bankruptcy, announcement"""

    def agent_node(state: MessagesState):
        """Call the tool-enabled model on the conversation so far."""
        # The system prompt is prepended on every call; it is not written back
        # into the graph state, so it never accumulates in the history.
        prompt = [{"role": "system", "content": SYSTEM_PROMPT}] + state["messages"]
        reply = llm_with_tools.invoke(prompt)
        return {"messages": [reply]}

    # Wire the graph: START -> agent -> (tools -> agent)* -> END.
    graph = StateGraph(MessagesState)
    graph.add_node("agent", agent_node)
    graph.add_node("tools", ToolNode(ALL_TOOLS))
    graph.add_edge(START, "agent")
    # tools_condition routes to "tools" when the last AI message requested a
    # tool call, otherwise to END.
    routes = {"tools": "tools", END: END}
    graph.add_conditional_edges("agent", tools_condition, routes)
    graph.add_edge("tools", "agent")

    # Compile with a checkpointer so conversation state is kept per thread_id.
    compiled_graph = graph.compile(checkpointer=InMemorySaver())
    logger.info(f"Created LangGraph agent with model: {model_config.display_name}")
    return compiled_graph
def parse_tool_result(content: str) -> Dict[str, Any]:
    """
    Parse a tool result string into its output type and payload.

    The tools in this module prefix their output with COUNT_RESULT:,
    TABLE_RESULT: or TEXT_RESULT:. Only that leading marker is removed; any
    later occurrence of the same text inside the payload is left intact.

    Args:
        content: Raw tool output. Non-string content is passed through
            unchanged as type 'text'.

    Returns:
        Dict with 'type' (value/table/text) and 'data'. A TABLE_RESULT whose
        payload is not valid JSON (e.g. the "no results" sentence) is returned
        as type 'text'.
    """
    if not isinstance(content, str):
        # Nothing to parse; avoid AttributeError on .startswith below.
        return {"type": "text", "data": content}

    prefix_types = (
        ("COUNT_RESULT:", "value"),
        ("TABLE_RESULT:", "table"),
        ("TEXT_RESULT:", "text"),
    )
    for prefix, result_type in prefix_types:
        if not content.startswith(prefix):
            continue
        # Slice off the leading marker only (str.replace would also delete
        # the marker text wherever it appears inside the payload).
        text = content[len(prefix):].strip()
        if result_type == "table":
            try:
                return {"type": "table", "data": json.loads(text)}
            except json.JSONDecodeError:
                return {"type": "text", "data": text}
        return {"type": result_type, "data": text}

    # No recognised marker: treat the whole content as plain text.
    return {"type": "text", "data": content}
class AgentExecutor:
    """
    High-level interface for the LangGraph agent.
    Manages the agent lifecycle and provides a simple query interface.
    """

    def __init__(self, model_config=None):
        """Build the compiled graph and start on the "default" thread."""
        self.graph = create_agent_graph(model_config)
        self.thread_id = "default"
        self.model_config = model_config or get_default_model()

    def set_thread(self, thread_id: str):
        """Set the conversation thread ID for memory."""
        self.thread_id = thread_id

    @staticmethod
    def _current_turn_messages(messages):
        """Return the messages that follow the most recent human message.

        The graph is compiled with a checkpointer, so the state returned by
        invoke() contains the whole thread history. Only the tail after the
        last human message belongs to the query being answered. If no human
        message is found, the full list is returned.
        """
        for index in range(len(messages) - 1, -1, -1):
            if getattr(messages[index], "type", None) == "human":
                return messages[index + 1:]
        return messages

    def query(self, user_message: str) -> Dict[str, Any]:
        """
        Send a query to the agent and get a response.

        Args:
            user_message: The user's question

        Returns:
            Dict with:
            - response: The final text response
            - tool_calls: List of tools that were called for this query
            - output_type: 'value', 'table', or 'text'
            - data: Structured data if applicable

            Any exception is logged and converted into a 'text' response
            describing the error; nothing is raised to the caller.
        """
        config = {"configurable": {"thread_id": self.thread_id}}
        try:
            result = self.graph.invoke(
                {"messages": [{"role": "user", "content": user_message}]},
                config
            )
            messages = result["messages"]

            # Restrict tool inspection to this turn; otherwise tool calls and
            # tables from earlier questions on the same thread would be
            # reported again for every later query.
            tool_calls = []
            tool_results = []
            for msg in self._current_turn_messages(messages):
                if hasattr(msg, 'tool_calls') and msg.tool_calls:
                    for tc in msg.tool_calls:
                        tool_calls.append({
                            "name": tc["name"],
                            "args": tc["args"]
                        })
                if hasattr(msg, 'type') and msg.type == "tool":
                    tool_results.append(msg.content)

            # The last message in the state is the final AI answer.
            final_response = messages[-1].content if messages else "No response generated."

            # A table result wins as soon as one is seen; otherwise the last
            # count result is used; otherwise the output is plain text.
            output_type = "text"
            structured_data = None
            for result_content in tool_results:
                parsed = parse_tool_result(result_content)
                if parsed["type"] == "table" and isinstance(parsed["data"], list):
                    output_type = "table"
                    structured_data = parsed["data"]
                    break
                elif parsed["type"] == "value":
                    output_type = "value"
                    structured_data = parsed["data"]

            return {
                "response": final_response,
                "tool_calls": tool_calls,
                "output_type": output_type,
                "data": structured_data
            }
        except Exception as e:
            # Top-level boundary: log with traceback and return a safe payload.
            logger.exception(f"Agent query error: {e}")
            return {
                "response": f"I encountered an error processing your request: {str(e)}",
                "tool_calls": [],
                "output_type": "text",
                "data": None
            }
# Process-wide agent singleton; stays None until get_agent() or reset_agent()
# builds it.
_agent_instance: Optional[AgentExecutor] = None


def get_agent(model_config=None) -> AgentExecutor:
    """Return the shared AgentExecutor, building it on first use.

    model_config is only consulted when the instance does not exist yet; once
    an agent has been created the argument is ignored.
    """
    global _agent_instance
    if _agent_instance is not None:
        return _agent_instance
    _agent_instance = AgentExecutor(model_config)
    return _agent_instance
def reset_agent(model_config=None):
    """Replace the shared agent with a newly built one and return it.

    Intended for switching models. The replacement gets its own compiled
    graph from AgentExecutor(model_config).
    """
    global _agent_instance
    replacement = AgentExecutor(model_config)
    _agent_instance = replacement
    return replacement