Spaces:
Running
Running
| import os | |
| import json | |
| import time | |
| import logging | |
| import uuid | |
| import asyncio | |
| from typing import Dict, Any, List, Optional, Set | |
| from textwrap import dedent | |
| from datetime import datetime | |
| # Load environment variables from .env file | |
| from dotenv import load_dotenv | |
| load_dotenv(os.path.join(os.path.dirname(__file__), '..', '.env')) | |
| # FastAPI imports for custom tenant-aware endpoint | |
| from fastapi import FastAPI, HTTPException, Body | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| # Updated imports for comprehensive tracking | |
| from agno.db.sqlite import SqliteDb # Changed from InMemoryDb for persistence | |
| from agno.agent import Agent | |
| from agno.models.ollama import Ollama | |
| from agno.os import AgentOS | |
| from agno.run import RunContext | |
| from agno.run.agent import RunEvent | |
| # Import the new multi-tenant toolkit | |
| from backend.SQL_Agent.data_sources_sql_toolkit import DataSourcesSQLToolkit | |
# Data sources API settings, resolved from the environment at import time.
DATA_SOURCES_API_BASE_URL = os.getenv("DATA_SOURCES_API_BASE_URL", "http://127.0.0.1:8000")
DATA_SOURCES_API_KEY = os.getenv("DATA_SOURCES_API_KEY")  # Optional; None when the variable is unset
print(f"π‘ Data Sources API URL: {DATA_SOURCES_API_BASE_URL}")
# Report whether a key is present without ever printing the key itself.
print(
    "π Data Sources API Key configured."
    if DATA_SOURCES_API_KEY
    else " No Data Sources API Key configured (optional)"
)
# Root logger is configured at DEBUG for the whole process.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.DEBUG)
logger = logging.getLogger(__name__)
| # NEW: Enhanced Tool Hook for Complete Logging | |
def comprehensive_logging_hook(
    run_context: RunContext,
    function_name: str,
    function_call,
    arguments: Dict[str, Any]
) -> Any:
    """
    Tool hook that wraps a single tool call and records it in the session state.

    One record per call is appended to ``session_state["tool_execution_log"]``
    holding the tool name, its arguments, start/end timestamps, duration in
    milliseconds, and either a truncated result or the error text.

    Args:
        run_context: Run context whose ``session_state`` receives the log entry.
        function_name: Name of the tool being invoked.
        function_call: Callable that performs the tool call; invoked as
            ``function_call(**arguments)``.
        arguments: Keyword arguments for the tool.

    Returns:
        Whatever ``function_call`` returns.

    Raises:
        Exception: Any exception from the tool is logged, recorded, and re-raised.
    """
    # Make sure there is a session_state dict to write into (Agno v2 API).
    if not run_context.session_state:
        run_context.session_state = {}
    state = run_context.session_state
    state.setdefault("tool_execution_log", [])

    started_at = datetime.now()
    record = dict(
        tool_name=function_name,
        arguments=arguments,
        timestamp=started_at.isoformat(),
        execution_id=f"{function_name}_{started_at.timestamp()}",
    )
    logger.info(f"π§ Executing tool: {function_name} with args: {arguments}")

    try:
        outcome = function_call(**arguments)
        finished_at = datetime.now()
        # Only the first 1000 characters of the result are kept in the log.
        record["result"] = str(outcome)[:1000]
        record["status"] = "success"
        record["duration_ms"] = (finished_at - started_at).total_seconds() * 1000
        record["completed_at"] = finished_at.isoformat()
        logger.info(f"β Tool {function_name} completed successfully in {record['duration_ms']:.2f}ms")
    except Exception as exc:
        finished_at = datetime.now()
        record["error"] = str(exc)
        record["status"] = "failed"
        record["duration_ms"] = (finished_at - started_at).total_seconds() * 1000
        record["completed_at"] = finished_at.isoformat()
        logger.error(f"β Tool {function_name} failed: {str(exc)}")
        raise  # The caller still sees the original exception
    finally:
        # The record is stored on both the success and the failure path.
        state["tool_execution_log"].append(record)
    return outcome
# System prompt for the "Sirus" SQL agent; it is passed to the Agent as `instructions`.
# The text is an XML-style document with sections for persona, critical directives,
# a four-phase workflow, per-tool usage guidance, worked scenarios and output formatting.
# NOTE(review): read-only access is requested here only as a prompt directive; nothing in
# this file enforces it — confirm the toolkit/database rejects write statements.
# NOTE(review): the workflow asks for "Standard ANSI SQL" while the "Complex" scenario uses
# DATE('now', '-1 month'), which looks SQLite-specific — confirm the intended dialect.
# NOTE(review): the free-text line about handling errors sits outside any XML section.
system_prompt = dedent("""
<system_configuration>
<persona>
<name>Sirus</name>
<creator>PhobosQ</creator>
<role>Sirus The Data Scientist & Strategist</role>
<mission>Bridge the gap between raw database rows and high-level business strategy.</mission>
<voice>Professional, energetic, precise, and helpful. You speak in Markdown.</voice>
</persona>
<critical_directives>
<directive id="1" name="The Invisible Wall">
The user CANNOT see your tool calls, JSON outputs, or SQL code.
You MUST translate every tool result into a natural language sentence.
NEVER end a turn with a tool call. ALWAYS end with a text response.
</directive>
<directive id="2" name="Broad Search Protocol">
Your semantic search is strict. When searching for tables, you MUST expand keywords.
- If user asks: "How many users?" -> Search: ['users', 'accounts', 'customers', 'profiles','people','members'etc...]
- If user asks: "Sales?" -> Search: ['sales', 'orders', 'transactions', 'revenue', 'invoices','bookings']
</directive>
<directive id="3" name="The Schema Fallback">
If `find_relevant_tables` returns 0 matches, you MUST NOT give up.
You MUST immediately call `get_available_sources_and_schema` to pull the full database map.
Then, manually find the table and execute the query.
</directive>
<directive id="4" name="Safety & Read-Only">
NEVER execute INSERT, UPDATE, DELETE, DROP, or ALTER.
ALWAYS use `LIMIT 100` on list queries to prevent token overflows.
</directive>
</critical_directives>
<workflow_engine>
<phase id="1" name="Initialization">
<check>Do I have the `source_instructions` in my context?</check>
<action>If NO: Call `list_sources`, select the most relevant one, then `get_source_instructions`.</action>
<action>If YES: Skip to Phase 2.</action>
</phase>
<phase id="2" name="Discovery">
<action>Call `find_relevant_tables(question, concepts)`.</action>
<logic>Use broad concepts. If the user asks a "Why" question, search for fact tables (orders, logs) AND dimension tables (users, products).</logic>
<fallback>If matches == 0: Call `get_available_sources_and_schema(tenant_id)`.</fallback>
</phase>
<phase id="3" name="Execution">
<action>Call `execute_sql_query(sql_query)`.</action>
<logic>
1. Write Standard ANSI SQL.
2. Use the exact table names found in Phase 2.
3. If the user asks "Why" or "Trend", run aggregations (GROUP BY).
</logic>
<recovery>If SQL fails: Read error -> Correct Syntax -> Retry Query.</recovery>
</phase>
<phase id="4" name="Synthesis">
<action>Convert JSON list to Text.</action>
<template>
1. **The Answer:** Direct answer to the question (e.g., "Total revenue is $5M").
2. **The Context:** (Optional) "This is based on 500 records from the 'orders' table."
3. **The Strategy:** (Only for complex questions) "To improve this, consider..."
4. **Next Steps:** "Would you like to break this down by region?"
</template>
</phase>
</workflow_engine>
<tool_usage_guide>
<tool name="list_sources">
<trigger>Start of conversation or when switching databases.</trigger>
<purpose>Finds the tenant_id and source_name.</purpose>
</tool>
<tool name="get_source_instructions">
<trigger>Immediately after picking a source.</trigger>
<purpose>Gets the "Manual" for the database (SQL dialect, special column rules).</purpose>
</tool>
<tool name="find_relevant_tables">
<trigger>Every user question.</trigger>
<input_strategy>
Argument `concepts` must be a list of broad synonyms.
Example: User="Churn rate?" -> concepts=["churn", "status", "active", "cancelled", "users"]
</input_strategy>
</tool>
<tool name="get_available_sources_and_schema">
<trigger>ONLY when `find_relevant_tables` fails (returns []).</trigger>
<purpose>The "Nuclear Option". Dumps the whole schema so you can find tables manually.</purpose>
</tool>
<tool name="execute_sql_query">
<trigger>Once you have table names and a clear intent.</trigger>
<rules>
- No Markdown in the SQL string.
- Dates should be handled dynamically (e.g., `CURRENT_DATE`).
- Always handle NULLs in math operations (`COALESCE`).
</rules>
</tool>
</tool_usage_guide>
<exemplar_scenarios>
<scenario type="Easy" description="Simple Count">
<user_input>How many users are there?</user_input>
<chain_of_thought>
1. **Init**: I have instructions.
2. **Search**: `find_relevant_tables` for "users", "accounts".
3. **Result**: Found table `public.users`.
4. **Execute**: `SELECT COUNT(*) as count FROM public.users`.
5. **Result**: `[{'count': 5420}]`.
6. **Speak**: "There are currently **5,420 users** in your database. Is there a specific segment you'd like to analyze?"
</chain_of_thought>
</scenario>
<scenario type="Complex" description="Trend Analysis & Strategy">
<user_input>Why did sales drop last month?</user_input>
<chain_of_thought>
1. **Search**: `find_relevant_tables` for "sales", "orders", "revenue".
2. **Result**: Found `orders`.
3. **Reasoning**: To find "why", I need to look at categories or regions over time. A simple sum isn't enough.
4. **Execute**: `SELECT category, SUM(amount) as total FROM orders WHERE date >= DATE('now', '-1 month') GROUP BY category ORDER BY total ASC`.
5. **Result**: `[{'category': 'Tech', 'total': 500}, {'category': 'Home', 'total': 50000}]`.
6. **Speak**: "I analyzed the sales drop. It appears **Tech** products severely underperformed ($500) compared to Home Goods.
**Strategic Insight:** This often indicates a stockout or a lack of promotion in that category.
**Recommendation:** I suggest checking inventory levels for Tech products immediately."
</chain_of_thought>
</scenario>
<scenario type="Failure_Recovery" description="Search returns Zero Matches">
<user_input>Show me the invoices.</user_input>
<chain_of_thought>
1. **Search**: `find_relevant_tables(concepts=['invoices', 'billing'])`.
2. **Result**: `Matches: []` (Zero tables found).
3. **Pivot**: The vector search failed. I must see the schema.
4. **Tool**: `get_available_sources_and_schema(tenant_id=...)`.
5. **Result**: Full Schema JSON. I read it. I see a table named `billing_ledgers`.
6. **Execute**: `SELECT * FROM billing_ledgers LIMIT 5`.
7. **Speak**: "I couldn't find a table explicitly named 'invoices', but I found `billing_ledgers` which contains billing data. Here are the top 5 records..."
</chain_of_thought>
</scenario>
</exemplar_scenarios>
if u encounter any errors , kindly rectify them and proceed with the task at hand. if still its an server error or something , just say that kindly neatly.
<output_formatting>
- Use **Bold** for numbers and key entities.
- Use Tables for lists of data.
- Be concise but friendly.
- Always ask a follow-up question.
</output_formatting>
</system_configuration>
""")
# Module-level wiring: session database, data-sources toolkit and the agent itself.
# Everything below runs once, at import time.
print("β Configuration set. Initializing enhanced agent with comprehensive logging...")
# SQLite file used by the agent for persistent session storage
# (relative path, so it resolves against the process working directory).
agent_db = SqliteDb(db_file="agent_sessions.db")
# Toolkit is configured from the DATA_SOURCES_API_* environment values read above.
data_sources_sql_toolkit = DataSourcesSQLToolkit(
    api_base_url=DATA_SOURCES_API_BASE_URL,
    api_key=DATA_SOURCES_API_KEY
)
# Disabled: ReasoningTools setup with overridden instructions, kept for reference.
# FIX: Override default instructions so they don't conflict with Sirus
# custom_reasoning_instructions = """
# Use `think` to plan your approach.
# Use `analyze` to verify that your query result answers the user's specific question.
# CRITICAL: After calling `analyze` with next_action="final_answer", you MUST output a natural language text response to the user.
# The user cannot see your tool outputs - they only see your text replies.
# Never end a conversation on a tool call. Always follow up with a clear, conversational response.
# """
# # Initialize reasoning tools with simplified instructions
# reasoning_tools = ReasoningTools(
# instructions=custom_reasoning_instructions, # <--- OVERRIDE DEFAULTS
# enable_analyze=False,
# enable_think=True
# )
# IDs used for AgentOS; both can be overridden through the environment.
DEFAULT_AGENT_OS_ID = os.getenv("SQL_AGENT_OS_ID", "sql-agent-os")
DEFAULT_AGENT_ID = os.getenv("SQL_AGENT_ID", "sirus-sql-agent")
# The single agent served by this module.
# NOTE(review): the variable is named `gemini_sql_agent` but the model configured below is Ollama.
gemini_sql_agent = Agent(
    model=Ollama(
        id="AgentCPM-Tools", # Custom model with XML tool-call template
        host="http://ollama:11434", # Use Docker container name
        timeout=300, # 5-minute timeout to prevent infinite hangs on complex queries
        options={
            "num_ctx": 32768, # Matches the context set in your Modelfile
            "temperature": 0.0, # CRITICAL: Forces strict adherence to XML tool tags
            "keep_alive": -1 # Keeps the model loaded in VRAM for speed
        }
    ),
    instructions=system_prompt,
    tools=[data_sources_sql_toolkit],
    tool_hooks=[comprehensive_logging_hook],
    tool_call_limit=100,
    # Enable debug mode to see raw XML output in logs if needed
    # NOTE(review): debug_mode is hard-coded on — confirm this is intended outside development.
    debug_mode=True,
    telemetry=False,
    # Database and session management
    db=agent_db,
    add_history_to_context=True,
    num_history_runs=3,
    read_chat_history=True,
    # Default session state; run_tenant_agent passes its own session_state per run.
    session_state={
        "tool_execution_log": [],
        "user_context": {},
        "analysis_metadata": {}
    },
    # NOTE(review): run_tenant_agent puts `supabase_jwt` into session_state; with this flag
    # the state is presumably added to the model context (and stored with the session) —
    # confirm the token is not exposed to the model or persisted in agent_sessions.db.
    add_session_state_to_context=True,
    # Response formatting
    markdown=True,
    add_datetime_to_context=True,
    # Error handling
    exponential_backoff=True,
    delay_between_retries=10
)
# Set agent ID for AgentOS
gemini_sql_agent.id = DEFAULT_AGENT_ID
# Give the toolkit a reference to the agent so it can reach session_state during tool calls.
# NOTE(review): this is one shared agent reference for all requests — confirm the toolkit
# reads per-run state rather than agent-level state, otherwise tenants could cross over.
data_sources_sql_toolkit.set_agent_ref(gemini_sql_agent)
logger.info("Agent reference set in toolkit - session_state injection enabled")
| # Define Pydantic model for tenant-aware API requests | |
class TenantRunRequest(BaseModel):
    """
    Request body for the custom tenant-aware endpoint (`run_tenant_agent`).

    Carries the user message together with all tenant context in a single request.
    Supports multi-source agent auto-detection when available_sources is provided.

    NOTE(review): nothing in this file checks `tenant_id` / `user_id` against the JWT;
    confirm that verification happens upstream or in the toolkit.
    """
    message: str  # User message forwarded to the agent
    supabase_jwt: str  # JWT token for auth
    tenant_id: str  # Expected to match the JWT claims (not verified here)
    source_name: str  # Default/primary source for query execution
    session_id: Optional[str] = None  # A new UUID is generated by the endpoint when omitted
    user_id: Optional[str] = None
    available_sources: Optional[list] = None  # All available sources for agent auto-detection
    stream: bool = False  # True -> server-sent events, False -> single JSON response
| # Define the tenant-aware endpoint function (will be added to AgentOS app later) | |
async def run_tenant_agent(
    agent_id: str,
    request: TenantRunRequest
):
    """
    Run the SQL agent for one tenant-scoped request.

    The tenant context from the request body (JWT, tenant_id, source_name, user_id,
    available_sources) is copied into a session_state dict built for this call and
    passed to ``agent.run(...)``.

    Args:
        agent_id: The ID of the agent to run; only DEFAULT_AGENT_ID is served.
        request: TenantRunRequest containing the message and all tenant context.

    Returns:
        StreamingResponse of server-sent events when ``request.stream`` is true,
        otherwise a dict with ``session_id``, ``tenant_id`` and ``response``.

    Raises:
        HTTPException: 404 when agent_id is unknown; 500 when a non-streaming run fails.
    """
    # Only the module-level agent is available through this endpoint.
    agent = gemini_sql_agent if agent_id == DEFAULT_AGENT_ID else None
    if not agent:
        raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found.")

    # State handed to this run only; a new dict is built for every request.
    # NOTE(review): tenant_id and user_id are taken from the request body as-is; no check
    # against the JWT is visible in this function.
    initial_state = {
        "supabase_jwt": request.supabase_jwt,  # JWT for backend API auth
        "tenant_id": request.tenant_id,  # Tenant context for toolkit
        "source_name": request.source_name,
        "user_id": request.user_id,
        "available_sources": request.available_sources or [],  # All sources for agent auto-detection
        "tool_execution_log": [],
        "user_context": {},
        "analysis_metadata": {}
    }
    # Generate a session ID if not provided
    session_id = request.session_id or str(uuid.uuid4())
    logger.info(f"π Starting tenant run for tenant_id={request.tenant_id}, source={request.source_name}, session={session_id}")

    # NOTE(review): agent.run is synchronous but is called from this async handler, so the
    # event loop is occupied while the model and tools work. Consider Agno's async run API
    # or a worker thread once the shared agent/toolkit are confirmed safe for that.
    if request.stream:
        async def stream_generator():
            """Yield one SSE frame per agent chunk; emit an `error` frame if the run fails."""
            try:
                logger.info(f"π¬ Starting streaming for session {session_id}, message: {request.message[:50]}...")
                # agent.run returns a generator in stream mode
                response_generator = agent.run(
                    request.message,
                    stream=True,
                    stream_events=True,  # Enable full event streaming for tool calls
                    session_id=session_id,
                    session_state=initial_state
                )
                chunk_count = 0
                for chunk in response_generator:
                    chunk_count += 1

                    # Chunks carrying an `event` attribute: only three event types are
                    # forwarded; every other event is logged and skipped.
                    if hasattr(chunk, 'event'):
                        logger.info(f" [Chunk {chunk_count}] Event: {chunk.event}")
                        if chunk.event == RunEvent.run_content:
                            content = chunk.content if hasattr(chunk, 'content') else str(chunk)
                            yield _sse_frame({"content": content}, "RunContent")
                            logger.info(f" β Yielded RunContent event")
                        elif chunk.event == RunEvent.tool_call_started:
                            tool_name = getattr(getattr(chunk, 'tool', None), 'tool_name', 'unknown')
                            yield _sse_frame({"tool": tool_name, "status": "started"}, "ToolCallStarted")
                            logger.info(f" β Yielded ToolCallStarted: {tool_name}")
                        elif chunk.event == RunEvent.tool_call_completed:
                            tool_name = getattr(getattr(chunk, 'tool', None), 'tool_name', 'unknown')
                            result_preview = str(chunk.content)[:200] if hasattr(chunk, 'content') else 'completed'
                            yield _sse_frame(
                                {"tool": tool_name, "status": "completed", "result_preview": result_preview},
                                "ToolCallCompleted",
                            )
                            logger.info(f" β Yielded ToolCallCompleted: {tool_name}")
                        else:
                            logger.info(f" β οΈ Unhandled event type: {chunk.event}")
                        # Brief pause hands control back to the event loop between chunks.
                        await asyncio.sleep(0.001)
                        continue

                    if isinstance(chunk, dict):
                        # Dict chunks: named event with its `data`, or the whole dict as data.
                        event = chunk.get("event")
                        frame = _sse_frame(chunk.get("data") if event else chunk, event)
                        label = event or 'data-only'
                    else:
                        # Any other object is converted to a dict first.
                        try:
                            frame, label = _object_chunk_frame(chunk)
                        except Exception as e:
                            logger.error(f"Failed to serialize chunk: {e}, chunk type: {type(chunk)}", exc_info=True)
                            yield _sse_frame({'error': str(e), 'content': str(chunk)})
                            await asyncio.sleep(0.001)
                            continue
                    yield frame
                    logger.info(f" β Yielded event: {label}")
                    await asyncio.sleep(0.001)
                logger.info(f"β Streaming run completed for session {session_id} - sent {chunk_count} chunks")
            except Exception as e:
                # Headers are already sent once streaming starts, so the failure is
                # reported in-band as an SSE `error` event.
                logger.error(f"β Error during stream generation for session {session_id}: {e}", exc_info=True)
                yield _sse_frame({"error": str(e), "code": "STREAM_ERROR"}, "error")

        return StreamingResponse(stream_generator(), media_type="text/event-stream")

    # Non-streaming (blocking) response
    try:
        response = agent.run(
            request.message,
            stream=False,
            session_id=session_id,
            session_state=initial_state
        )
    except Exception as e:
        logger.error(f"β Error during non-streaming agent run for session {session_id}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    logger.info(f"β Non-streaming run completed for session {session_id}")
    # NOTE(review): `response` is whatever agent.run returns and is handed to FastAPI for
    # serialization — presumably a run-output object, not plain text; confirm clients
    # expect that shape.
    return {
        "session_id": session_id,
        "tenant_id": request.tenant_id,
        "response": response
    }


def _sse_frame(payload: Any, event: Optional[str] = None, *, ensure_ascii: bool = True) -> str:
    """Format one server-sent-events frame; values json cannot encode are stringified."""
    body = json.dumps(payload, default=str, ensure_ascii=ensure_ascii)
    if event:
        return f"event: {event}\ndata: {body}\n\n"
    return f"data: {body}\n\n"


def _to_jsonable(obj: Any) -> Any:
    """Recursively unwrap dicts, lists/tuples, `model_dump` objects and `__dict__` objects."""
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    if hasattr(obj, 'model_dump'):
        return _to_jsonable(obj.model_dump())
    if hasattr(obj, '__dict__') and not isinstance(obj, (str, int, float, bool, type(None))):
        return _to_jsonable(obj.__dict__)
    return obj


def _chunk_to_dict(chunk: Any) -> Dict[str, Any]:
    """Convert a chunk to a dict via model_dump(), dict(), __dict__, or str() as last resort."""
    for method_name in ("model_dump", "dict"):
        if not hasattr(chunk, method_name):
            continue
        try:
            dumped = getattr(chunk, method_name)()
        except Exception as e:
            logger.info(f"{method_name}() failed: {e}")
            continue
        logger.info(f"β Serialized with {method_name}()")
        if dumped is not None:
            return dumped
    if hasattr(chunk, '__dict__'):
        logger.info(f"β Serialized with __dict__")
        return chunk.__dict__
    logger.warning(f"Could not serialize chunk, converting to string: {type(chunk)}")
    return {"content": str(chunk)}


def _object_chunk_frame(chunk: Any) -> tuple:
    """Build the SSE frame for a non-dict chunk; returns (frame, label used for logging)."""
    logger.info(f"Processing chunk type: {type(chunk)}")
    chunk_dict = _chunk_to_dict(chunk)
    event_type = chunk_dict.get("event")
    if not event_type:
        logger.info(f"Sending data without event type")
        return _sse_frame(_to_jsonable(chunk_dict), ensure_ascii=False), 'data-only'

    logger.info(f"Sending event: {event_type}")
    if event_type == "ReasoningStep":
        logger.info(f" ReasoningStep content: reasoning={chunk_dict.get('reasoning')}, content={chunk_dict.get('content')}, result={chunk_dict.get('result')}")
        logger.info(f" Full ReasoningStep dict keys: {list(chunk_dict.keys())}")
    serialized_dict = _to_jsonable(chunk_dict)
    # ReasoningStep: a dict-valued `content` is sent as a JSON string under `reasoning_content`.
    if event_type == "ReasoningStep" and isinstance(serialized_dict.get("content"), dict):
        reasoning_obj = serialized_dict.pop("content")
        serialized_dict["reasoning_content"] = json.dumps(reasoning_obj, default=str, ensure_ascii=False)
        logger.info(f" β Converted ReasoningStep content to reasoning_content string")
    return _sse_frame(serialized_dict, event_type, ensure_ascii=False), event_type
# Build the AgentOS instance around the single SQL agent.
# (No fastapi_app argument is passed; per the original note that parameter
# doesn't exist in the current agno version.)
agent_os = AgentOS(
    agents=[gemini_sql_agent],
    description="Multi-tenant SQL Agent for querying data sources across tenants."
)
# Get the AgentOS app first, then add our custom route to it
agentOS_app = agent_os.get_app()
# Register run_tenant_agent as POST /tenant-run/{agent_id} on the AgentOS app;
# the path parameter is bound to the handler's `agent_id` argument.
agentOS_app.add_api_route(
    "/tenant-run/{agent_id}",
    run_tenant_agent,
    methods=["POST"],
    name="run_tenant_agent"
)
# `app` is the object handed to uvicorn.run() in the __main__ block below.
app = agentOS_app
# DEPRECATED FUNCTIONS - Replaced by the /tenant-run API endpoint
# The following functions are kept for backward compatibility and local testing only.
# For production API usage, use the /tenant-run/{agent_id} endpoint instead.
if __name__ == "__main__":
    import uvicorn

    # Bind address and port come from the environment, with local defaults.
    host = os.getenv("SQL_AGENT_HOST", "0.0.0.0")
    port = int(os.getenv("SQL_AGENT_PORT", "5559"))

    rule = "=" * 80
    base_url = f"http://{host}:{port}"
    # Startup banner: server identity first, then the endpoints that are exposed.
    banner = [
        "\n" + rule,
        "π STARTING SQL AGENT OS SERVER (with custom /tenant-run endpoint)",
        rule,
        f"Host: {host}",
        f"Port: {port}",
        f"Agent ID: {DEFAULT_AGENT_ID}",
        f"AgentOS ID: {DEFAULT_AGENT_OS_ID}",
        rule + "\n",
        "\nπ― CUSTOM TENANT ENDPOINT:",
        f" POST {base_url}/tenant-run/{DEFAULT_AGENT_ID}",
        "\nπ STANDARD AGENTOS ENDPOINTS:",
        f" GET {base_url}/config",
        f" GET {base_url}/agents",
        f" POST {base_url}/agents/{DEFAULT_AGENT_ID}/runs",
        rule + "\n",
    ]
    for line in banner:
        print(line)

    # server_header=False drops uvicorn's default `Server` response header and
    # interface="auto" lets uvicorn pick the application interface.
    # NOTE(review): neither option controls response buffering, despite what the
    # earlier comments here suggested.
    uvicorn.run(
        app,
        host=host,
        port=port,
        server_header=False,
        interface="auto",
    )