alda / agent.py
jameszokah's picture
feat: Implement core gazette parsing, data extraction, and foundational AI auditor system components.
f0c339c
# LangGraph Chat Agent for DocuTrace AI Auditor
"""
A LangGraph-based agent that uses database query tools to answer
high-precision questions about Kuwait Gazette data.
"""
import os
import json
import logging
from typing import Optional, List, Dict, Any, Literal, Annotated
from datetime import datetime
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain.tools import tool
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import InMemorySaver
from sqlalchemy.orm import Session
from sqlalchemy import func
from config import get_available_models, get_default_model, ModelProvider
from database import get_db_session
from models import (
GazetteIssue, Tender, CommercialAgency, GeneralAssembly,
BankruptcyCase, Decree, CompanyAnnouncement
)
# Module-wide logging: configure the root logger at INFO (basicConfig does
# nothing if the root logger already has handlers), then fetch this module's
# named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
# ============================================================================
# TOOL SCHEMAS (Pydantic models for structured input)
# ============================================================================
class CountEntitiesInput(BaseModel):
    """Input schema for counting entities."""

    # Used as the args_schema of the count_entities tool, so these field
    # descriptions are presumably surfaced to the LLM as argument help text.
    entity_type: Literal["tender", "agency", "decree", "assembly", "bankruptcy", "announcement"] = Field(
        ..., description="Type of entity to count"
    )
    ministry: Optional[str] = Field(None, description="Filter by ministry (for tenders/decrees)")
    country: Optional[str] = Field(None, description="Filter by country (for agencies)")
class ListEntitiesInput(BaseModel):
    """Input schema for listing entities."""

    # Used as the args_schema of the list_entities tool; every field except
    # entity_type is optional.
    entity_type: Literal["tender", "agency", "decree", "assembly", "bankruptcy", "announcement"] = Field(
        ..., description="Type of entity to list"
    )
    limit: int = Field(10, description="Maximum number of results to return")
    ministry: Optional[str] = Field(None, description="Filter by ministry")
    country: Optional[str] = Field(None, description="Filter by country (for agencies)")
    search_term: Optional[str] = Field(None, description="Search keyword in text/subject")
class SearchTendersInput(BaseModel):
    """Input schema for searching tenders."""

    # Used as the args_schema of the search_tenders tool; all criteria are optional.
    keyword: Optional[str] = Field(None, description="Keyword to search in tender subject")
    ministry: Optional[str] = Field(None, description="Filter by ministry name")
    is_postponed: Optional[bool] = Field(None, description="Filter by postponement status")
    limit: int = Field(10, description="Maximum results to return")
class SearchAgenciesInput(BaseModel):
    """Input schema for searching commercial agencies."""

    # Used as the args_schema of the search_agencies tool; all criteria are optional.
    country: Optional[str] = Field(None, description="Filter by principal's country")
    agent_name: Optional[str] = Field(None, description="Search by agent company name")
    activity: Optional[str] = Field(None, description="Search by activity description")
    limit: int = Field(10, description="Maximum results to return")
# ============================================================================
# DATABASE TOOLS
# ============================================================================
def get_model_map():
    """Return a new dict mapping tool ``entity_type`` names to SQLAlchemy models.

    The keys are the same literals accepted by the ``entity_type`` field of the
    count/list tool input schemas.
    """
    return dict(
        tender=Tender,
        agency=CommercialAgency,
        decree=Decree,
        assembly=GeneralAssembly,
        bankruptcy=BankruptcyCase,
        announcement=CompanyAnnouncement,
    )
@tool(args_schema=CountEntitiesInput)
def count_entities(
    entity_type: str,
    ministry: Optional[str] = None,
    country: Optional[str] = None
) -> str:
    """
    Count the number of entities of a specific type in the database.
    Use this for questions like "How many tenders are there?" or "Count agencies from Turkey".
    Returns a single count value.
    """
    # The docstring above is the tool description @tool hands to the LLM, so its
    # wording is kept stable; implementation notes live in these comments.
    model = get_model_map().get(entity_type)
    if model is None:
        return f"Unknown entity type: {entity_type}"

    # (requested value, model column attribute, label used in the response)
    requested_filters = [
        (ministry, "ministry", "ministry"),
        (country, "principal_country", "country"),
    ]

    session = get_db_session()
    try:
        query = session.query(func.count(model.id))
        applied = []
        ignored = []
        for value, attr, label in requested_filters:
            if not value:
                continue
            if hasattr(model, attr):
                # Case-insensitive substring match.
                query = query.filter(getattr(model, attr).ilike(f"%{value}%"))
                applied.append(f"{label} containing '{value}'")
            else:
                # This entity type has no such column. Report the filter as ignored
                # instead of presenting an unfiltered count as a filtered one.
                ignored.append(label)

        count = query.scalar()

        if applied:
            result = f"COUNT_RESULT: {count} {entity_type}(s) matching {', '.join(applied)}"
        else:
            result = f"COUNT_RESULT: {count} {entity_type}(s) total in the database"
        if ignored:
            result += (
                f" (note: {', '.join(ignored)} filter not supported for "
                f"{entity_type} and was ignored)"
            )
        return result
    finally:
        session.close()
@tool(args_schema=ListEntitiesInput)
def list_entities(
    entity_type: str,
    limit: int = 10,
    ministry: Optional[str] = None,
    country: Optional[str] = None,
    search_term: Optional[str] = None
) -> str:
    """
    List entities of a specific type from the database.
    Use this for questions like "Show me all tenders" or "List agencies from Germany".
    Returns a JSON table of results.
    """
    # The docstring above is the tool description @tool hands to the LLM, so its
    # wording is kept stable; implementation notes live in these comments.
    from datetime import date, time
    from decimal import Decimal
    from enum import Enum
    from uuid import UUID

    def to_json_value(value):
        """Convert one column value into a type json.dumps can serialize."""
        # datetime is a subclass of date, so DateTime columns are covered too.
        # Plain Date/Time, Numeric, Enum and UUID values would otherwise make
        # json.dumps raise TypeError.
        if isinstance(value, (date, time)):
            return value.isoformat()
        if isinstance(value, Enum):
            return value.value
        if isinstance(value, (Decimal, UUID)):
            return str(value)
        return value

    model = get_model_map().get(entity_type)
    if model is None:
        return f"Unknown entity type: {entity_type}"

    # Large free-text columns are left out to keep the tool output compact.
    skipped_columns = {"full_text", "raw_content"}

    session = get_db_session()
    try:
        query = session.query(model)

        # Filters apply only when the model has the matching column.
        if ministry and hasattr(model, 'ministry'):
            query = query.filter(model.ministry.ilike(f"%{ministry}%"))
        if country and hasattr(model, 'principal_country'):
            query = query.filter(model.principal_country.ilike(f"%{country}%"))
        if search_term:
            # Search the first text column this model has, in priority order.
            if hasattr(model, 'subject'):
                query = query.filter(model.subject.ilike(f"%{search_term}%"))
            elif hasattr(model, 'activity_description'):
                query = query.filter(model.activity_description.ilike(f"%{search_term}%"))
            elif hasattr(model, 'title'):
                query = query.filter(model.title.ilike(f"%{search_term}%"))

        results = query.limit(limit).all()

        entities = [
            {
                column.name: to_json_value(getattr(row, column.name))
                for column in row.__table__.columns
                if column.name not in skipped_columns
            }
            for row in results
        ]

        if not entities:
            return f"TABLE_RESULT: No {entity_type}(s) found matching your criteria."
        # ensure_ascii=False keeps Arabic text as-is instead of \uXXXX escapes,
        # which are much longer and unreadable for the LLM.
        return f"TABLE_RESULT: {json.dumps(entities, indent=2, ensure_ascii=False)}"
    finally:
        session.close()
@tool(args_schema=SearchTendersInput)
def search_tenders(
    keyword: Optional[str] = None,
    ministry: Optional[str] = None,
    is_postponed: Optional[bool] = None,
    limit: int = 10
) -> str:
    """
    Search for tenders and practices with specific criteria.
    Use this for questions like "Find construction tenders" or "Show postponed tenders".
    Returns a JSON table of matching tenders.
    """
    # The docstring above is the tool description @tool hands to the LLM, so its
    # wording is kept stable; implementation notes live in these comments.
    session = get_db_session()
    try:
        query = session.query(Tender)
        if keyword:
            query = query.filter(Tender.subject.ilike(f"%{keyword}%"))
        if ministry:
            query = query.filter(Tender.ministry.ilike(f"%{ministry}%"))
        if is_postponed is not None:
            # Explicit None check: False is a meaningful filter ("not postponed").
            query = query.filter(Tender.is_postponed == is_postponed)

        tenders = [
            {
                "id": t.id,
                "tender_number": t.tender_number,
                "tender_type": t.tender_type,
                # Subject is truncated to 100 characters to keep the output compact.
                "subject": t.subject[:100] if t.subject else None,
                "ministry": t.ministry,
                "is_postponed": t.is_postponed,
                "closing_date": t.closing_date.isoformat() if t.closing_date else None
            }
            for t in query.limit(limit).all()
        ]

        if not tenders:
            return "TABLE_RESULT: No tenders found matching your criteria."
        # ensure_ascii=False keeps Arabic text as-is instead of \uXXXX escapes.
        return f"TABLE_RESULT: {json.dumps(tenders, indent=2, ensure_ascii=False)}"
    finally:
        session.close()
@tool(args_schema=SearchAgenciesInput)
def search_agencies(
    country: Optional[str] = None,
    agent_name: Optional[str] = None,
    activity: Optional[str] = None,
    limit: int = 10
) -> str:
    """
    Search for commercial agency registrations.
    Use this for questions like "Show agencies from Turkey" or "Find agencies for construction".
    Returns a JSON table of matching agencies.
    """
    # The docstring above is the tool description @tool hands to the LLM, so its
    # wording is kept stable; implementation notes live in these comments.
    session = get_db_session()
    try:
        query = session.query(CommercialAgency)
        # Each provided criterion narrows the result with a case-insensitive
        # substring match (ilike).
        if country:
            query = query.filter(CommercialAgency.principal_country.ilike(f"%{country}%"))
        if agent_name:
            query = query.filter(CommercialAgency.agent_name.ilike(f"%{agent_name}%"))
        if activity:
            query = query.filter(CommercialAgency.activity_description.ilike(f"%{activity}%"))

        agencies = [
            {
                "id": a.id,
                "registration_number": a.registration_number,
                "agent_name": a.agent_name,
                "principal_name": a.principal_name,
                "principal_country": a.principal_country,
                # Activity is truncated to 100 characters to keep the output compact.
                "activity": a.activity_description[:100] if a.activity_description else None,
                "start_date": a.start_date.isoformat() if a.start_date else None,
                "end_date": a.end_date.isoformat() if a.end_date else None
            }
            for a in query.limit(limit).all()
        ]

        if not agencies:
            return "TABLE_RESULT: No agencies found matching your criteria."
        # ensure_ascii=False keeps Arabic text as-is instead of \uXXXX escapes.
        return f"TABLE_RESULT: {json.dumps(agencies, indent=2, ensure_ascii=False)}"
    finally:
        session.close()
@tool
def get_database_stats() -> str:
    """
    Get overall statistics about the gazette database.
    Use this for questions like "What's in the database?" or "Give me an overview".
    Returns a text summary of all entity counts.
    """
    # The docstring above is the tool description @tool hands to the LLM, so it is
    # kept verbatim. (stats key, model) pairs are counted in this order.
    counted_models = (
        ("gazette_issues", GazetteIssue),
        ("tenders", Tender),
        ("commercial_agencies", CommercialAgency),
        ("decrees", Decree),
        ("general_assemblies", GeneralAssembly),
        ("bankruptcy_cases", BankruptcyCase),
        ("company_announcements", CompanyAnnouncement),
    )
    session = get_db_session()
    try:
        stats = {
            key: session.query(func.count(model.id)).scalar()
            for key, model in counted_models
        }
        # Gazette issues are reported on their own line but not added to the entity total.
        total = sum(n for key, n in stats.items() if key != "gazette_issues")
        return f"""TEXT_RESULT: Database Statistics:
📰 Gazette Issues: {stats['gazette_issues']}
📋 Tenders & Practices: {stats['tenders']}
🤝 Commercial Agencies: {stats['commercial_agencies']}
📜 Decrees: {stats['decrees']}
🏢 General Assemblies: {stats['general_assemblies']}
⚖️ Bankruptcy Cases: {stats['bankruptcy_cases']}
📢 Company Announcements: {stats['company_announcements']}
━━━━━━━━━━━━━━━━━━━━━
Total Entities: {total}"""
    finally:
        session.close()
# ============================================================================
# LANGGRAPH AGENT
# ============================================================================
# Every tool the agent may call: bound to the LLM and executed by the graph's
# ToolNode inside create_agent_graph().
ALL_TOOLS = [count_entities, list_entities, search_tenders, search_agencies, get_database_stats]
def create_agent_graph(model_config=None):
    """
    Build and compile the tool-calling LangGraph agent.

    The graph runs ``agent -> tools -> agent`` until the LLM replies without
    requesting a tool. It is compiled with an ``InMemorySaver`` checkpointer,
    so conversation state is kept per ``thread_id`` for the life of this process.

    Args:
        model_config: Optional model configuration from config.py. When ``None``,
            ``get_default_model()`` is used.

    Returns:
        Compiled LangGraph graph

    Raises:
        ValueError: If no model configuration is available.
    """
    cfg = model_config
    if cfg is None:
        cfg = get_default_model()
    if cfg is None:
        raise ValueError("No model configured. Set GOOGLE_API_KEY or OPENROUTER_API_KEY")

    # Both providers go through the OpenAI-compatible client. Only the endpoint
    # (and, for OpenRouter, attribution headers) differs.
    llm_options = {
        "model": cfg.model_id,
        "api_key": cfg.get_api_key(),
        "temperature": 0.1,
        "max_tokens": 2000,
    }
    if cfg.provider == ModelProvider.GEMINI:
        llm_options["base_url"] = "https://generativelanguage.googleapis.com/v1beta/openai/"
    else:
        # OpenRouter
        llm_options["base_url"] = cfg.base_url
        llm_options["default_headers"] = {
            "HTTP-Referer": "https://doctrace.ai",
            "X-Title": "DocuTrace AI Auditor"
        }
    tool_llm = ChatOpenAI(**llm_options).bind_tools(ALL_TOOLS)

    system_prompt = """You are Alda, an AI assistant specialized in analyzing Kuwait Al Youm gazette documents.
You help users understand government decrees, tenders, commercial agencies, bankruptcy cases, and company announcements.
IMPORTANT INSTRUCTIONS:
1. Use the provided tools to query the database for accurate information.
2. For counting questions (e.g., "how many"), use the count_entities tool.
3. For listing questions (e.g., "show me", "list"), use list_entities or specific search tools.
4. For overview questions, use get_database_stats tool.
5. Always call a tool when the question is about data in the gazette.
6. Provide clear, concise answers based on the tool results.
7. If a tool returns TABLE_RESULT, format the data nicely in your response.
8. If a tool returns COUNT_RESULT, present the number clearly.
9. If a tool returns TEXT_RESULT, present the summary directly.
Available entity types: tender, agency, decree, assembly, bankruptcy, announcement"""

    def agent_node(state: MessagesState):
        """Ask the tool-bound LLM for the next step given the conversation so far."""
        # The system prompt is prepended on every call rather than stored in the
        # graph state, so it never accumulates in the checkpointed history.
        response = tool_llm.invoke([{"role": "system", "content": system_prompt}, *state["messages"]])
        return {"messages": [response]}

    builder = StateGraph(MessagesState)
    builder.add_node("agent", agent_node)
    builder.add_node("tools", ToolNode(ALL_TOOLS))
    builder.add_edge(START, "agent")
    # tools_condition sends the flow to "tools" when the last AI message requests
    # tool calls, otherwise to END.
    builder.add_conditional_edges("agent", tools_condition, {"tools": "tools", END: END})
    # Tool output always goes back to the LLM so it can use the results.
    builder.add_edge("tools", "agent")

    compiled = builder.compile(checkpointer=InMemorySaver())
    logger.info(f"Created LangGraph agent with model: {cfg.display_name}")
    return compiled
def parse_tool_result(content: str) -> Dict[str, Any]:
    """
    Classify a tool's output by its result prefix and extract the payload.

    Recognized prefixes are ``COUNT_RESULT:`` (-> ``"value"``), ``TABLE_RESULT:``
    (-> ``"table"`` when the payload is valid JSON, else ``"text"``) and
    ``TEXT_RESULT:`` (-> ``"text"``). Only the leading prefix is removed, so
    payloads that themselves contain a prefix string are left intact.

    Args:
        content: Tool message content. Non-string content (e.g. a list of
            content blocks) is returned unchanged as ``"text"``.

    Returns:
        Dict with 'type' (value/table/text) and 'data'
    """
    if not isinstance(content, str):
        return {"type": "text", "data": content}

    if content.startswith("COUNT_RESULT:"):
        text = content[len("COUNT_RESULT:"):].strip()
        return {"type": "value", "data": text}

    if content.startswith("TABLE_RESULT:"):
        text = content[len("TABLE_RESULT:"):].strip()
        try:
            return {"type": "table", "data": json.loads(text)}
        except json.JSONDecodeError:
            # e.g. "No tenders found matching your criteria."
            return {"type": "text", "data": text}

    if content.startswith("TEXT_RESULT:"):
        text = content[len("TEXT_RESULT:"):].strip()
        return {"type": "text", "data": text}

    return {"type": "text", "data": content}
class AgentExecutor:
    """
    High-level interface for the LangGraph agent.

    Wraps a graph compiled by ``create_agent_graph`` and turns each query into a
    plain dict: the final answer, the tools called during that query, and
    structured data parsed from those tools' output.
    """

    def __init__(self, model_config=None):
        """Build the agent graph and start on the "default" conversation thread."""
        self.graph = create_agent_graph(model_config)
        # NOTE(review): every query shares the "default" thread, and therefore its
        # conversation memory, until set_thread() is called.
        self.thread_id = "default"
        self.model_config = model_config or get_default_model()

    def set_thread(self, thread_id: str):
        """Set the conversation thread ID (the checkpointer key used for memory)."""
        self.thread_id = thread_id

    @staticmethod
    def _current_turn(messages: List[Any]) -> List[Any]:
        """Return the messages that come after the most recent user message."""
        # The checkpointed state holds the whole thread history. Tool calls and
        # results from earlier turns must not be attributed to this query.
        for index in reversed(range(len(messages))):
            if getattr(messages[index], "type", None) == "human":
                return messages[index + 1:]
        return messages

    def query(self, user_message: str) -> Dict[str, Any]:
        """
        Send a query to the agent and get a response.

        Args:
            user_message: The user's question

        Returns:
            Dict with:
            - response: The final text response
            - tool_calls: Tools called while answering THIS message
            - output_type: 'value', 'table', or 'text'
            - data: Structured data from this message's tool results, if any
        """
        config = {"configurable": {"thread_id": self.thread_id}}
        try:
            result = self.graph.invoke(
                {"messages": [{"role": "user", "content": user_message}]},
                config
            )
            messages = result["messages"]
            turn = self._current_turn(messages)

            tool_calls = []
            tool_results = []
            for msg in turn:
                if getattr(msg, "tool_calls", None):
                    tool_calls.extend(
                        {"name": tc["name"], "args": tc["args"]} for tc in msg.tool_calls
                    )
                if getattr(msg, "type", None) == "tool":
                    tool_results.append(msg.content)

            final_response = messages[-1].content if messages else "No response generated."

            # The first table result wins; otherwise the last count result is used.
            output_type = "text"
            structured_data = None
            for result_content in tool_results:
                parsed = parse_tool_result(result_content)
                if parsed["type"] == "table" and isinstance(parsed["data"], list):
                    output_type = "table"
                    structured_data = parsed["data"]
                    break
                elif parsed["type"] == "value":
                    output_type = "value"
                    structured_data = parsed["data"]

            return {
                "response": final_response,
                "tool_calls": tool_calls,
                "output_type": output_type,
                "data": structured_data
            }
        except Exception as e:
            # Top-level boundary: log the traceback and return an error payload.
            logger.exception("Agent query error: %s", e)
            return {
                "response": f"I encountered an error processing your request: {str(e)}",
                "tool_calls": [],
                "output_type": "text",
                "data": None
            }
# Process-wide agent instance, created lazily by get_agent().
_agent_instance: Optional[AgentExecutor] = None


def get_agent(model_config=None) -> AgentExecutor:
    """Return the shared agent, building it on the first call.

    ``model_config`` is only used by the call that creates the instance; later
    calls return the existing agent unchanged (``reset_agent`` replaces it).
    """
    global _agent_instance
    if _agent_instance is not None:
        return _agent_instance
    _agent_instance = AgentExecutor(model_config)
    return _agent_instance
def reset_agent(model_config=None):
    """Replace the shared agent with a freshly built one and return it.

    Useful when changing models. The new instance builds a new graph with its own
    in-memory checkpointer, so earlier conversation history is not carried over.
    """
    global _agent_instance
    fresh_agent = AgentExecutor(model_config)
    _agent_instance = fresh_agent
    return fresh_agent