Spaces:
Running
Running
| import os | |
| import json | |
| import time | |
| import logging | |
| import uuid | |
| import asyncio | |
| import sys | |
| from typing import Dict, Any, List, Optional, Set | |
| from textwrap import dedent | |
| from datetime import datetime | |
| # Load environment variables from .env file | |
| from dotenv import load_dotenv | |
| load_dotenv(os.path.join(os.path.dirname(__file__), '..', '.env')) | |
| # FastAPI imports for custom tenant-aware endpoint | |
| from fastapi import FastAPI, HTTPException, Body, Depends, Request | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from backend.core.auth import get_current_user, AuthUser | |
| # Updated imports for comprehensive tracking | |
| from agno.db.sqlite import SqliteDb # Changed from InMemoryDb for persistence | |
| from agno.agent import Agent | |
| from agno.models.nvidia import Nvidia | |
| from agno.os import AgentOS | |
| from agno.run import RunContext | |
| from agno.run.agent import RunEvent | |
| # Import the new multi-tenant toolkit | |
| from backend.SQL_Agent.data_sources_sql_toolkit import DataSourcesSQLToolkit | |
| from backend.SQL_Agent.tenant_file_toolkit import TenantFileToolkit | |
| # Configuration for data sources API | |
# Data Sources API settings, read once from the environment at import time.
DATA_SOURCES_API_BASE_URL = os.getenv("DATA_SOURCES_API_BASE_URL", "http://127.0.0.1:8000")
# Optional key for authenticated requests; None when the variable is unset.
DATA_SOURCES_API_KEY = os.getenv("DATA_SOURCES_API_KEY")

# Startup banner: report the target URL and whether a key is present
# (the key value itself is never printed).
print(f"📡 Data Sources API URL: {DATA_SOURCES_API_BASE_URL}")
print(
    "🔑 Data Sources API Key configured."
    if DATA_SOURCES_API_KEY
    else " No Data Sources API Key configured (optional)"
)

# Configure root logging at DEBUG and create this module's logger.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.DEBUG)
logger = logging.getLogger(__name__)
def _get_billing_redis():
    """Build a Redis client for billing counters from environment settings.

    ``REDIS_URL`` takes precedence. Otherwise ``REDIS_HOST`` (with optional
    ``REDIS_PORT``, ``REDIS_DB`` and ``REDIS_PASSWORD``) is used. A new client
    object is constructed on every call.

    Returns:
        A client created with ``decode_responses=True``, or ``None`` when
        neither variable is set or when anything in the setup raises (in
        which case a warning is logged).
    """
    env = os.environ
    try:
        # Imported lazily so a missing redis package lands in the handler below.
        import redis

        url = env.get("REDIS_URL")
        if url:
            return redis.from_url(url, decode_responses=True)

        host = env.get("REDIS_HOST")
        # Port and db are parsed before the host check, so a malformed value is
        # reported through the warning path even when no host is configured.
        connection_kwargs = {
            "port": int(env.get("REDIS_PORT", "6379")),
            "db": int(env.get("REDIS_DB", "0")),
            "password": env.get("REDIS_PASSWORD"),
            "decode_responses": True,
        }
        return redis.Redis(host=host, **connection_kwargs) if host else None
    except Exception as exc:
        logger.warning(f"Billing Redis unavailable: {exc}")
        return None
def record_tenant_billing(tenant_id: str, input_tokens: int, output_tokens: int) -> None:
    """Add one run's token usage and estimated cost to the tenant's billing hash.

    Increments the ``input_tokens``, ``output_tokens``, ``total_tokens`` and
    ``estimated_cost_usd`` fields of the Redis hash ``tenant_billing:<tenant_id>``.
    All four increments are sent in a single MULTI/EXEC pipeline so the
    counters cannot drift apart if the connection drops part-way through.

    Recording is best-effort: nothing happens when ``tenant_id`` is empty or
    when no billing Redis is configured, and a Redis failure is logged rather
    than raised so that a billing outage does not break the caller's response.

    Args:
        tenant_id: Tenant whose hash is updated; falsy values are ignored.
        input_tokens: Prompt tokens consumed; ``None`` counts as 0.
        output_tokens: Completion tokens produced; ``None`` counts as 0.
    """
    if not tenant_id:
        return

    billing_redis = _get_billing_redis()
    if billing_redis is None:
        return

    input_tokens = int(input_tokens or 0)
    output_tokens = int(output_tokens or 0)
    total_tokens = input_tokens + output_tokens
    # Estimated pricing: $0.15 per 1M input tokens, $0.60 per 1M output tokens.
    est_cost = (input_tokens / 1_000_000) * 0.15 + (output_tokens / 1_000_000) * 0.60

    billing_key = f"tenant_billing:{tenant_id}"
    try:
        # The client connects lazily, so connection errors surface here on the
        # first command rather than when the client object was created.
        with billing_redis.pipeline(transaction=True) as pipe:
            pipe.hincrby(billing_key, "input_tokens", input_tokens)
            pipe.hincrby(billing_key, "output_tokens", output_tokens)
            pipe.hincrby(billing_key, "total_tokens", total_tokens)
            pipe.hincrbyfloat(billing_key, "estimated_cost_usd", float(f"{est_cost:.6f}"))
            pipe.execute()
    except Exception as exc:
        # Log the figures so the missed usage can be reconciled later.
        logger.error(
            "Failed to record billing for tenant %s (input=%s, output=%s): %s",
            tenant_id,
            input_tokens,
            output_tokens,
            exc,
        )
| # NEW: Enhanced Tool Hook for Complete Logging | |
def comprehensive_logging_hook(
    run_context: RunContext,
    function_name: str,
    function_call,
    arguments: Dict[str, Any]
) -> Any:
    """
    Execute a tool call and record it in the run's session state.

    One record per call is appended to ``session_state["tool_execution_log"]``
    whether the tool succeeds or raises. Each record holds the tool name, its
    arguments, start and end ISO timestamps, the duration in milliseconds, a
    status of "success" or "failed", and either the result (stringified and
    cut to 1000 characters) or the error message.

    Args:
        run_context: Run context whose ``session_state`` receives the log.
        function_name: Name of the tool being invoked.
        function_call: Callable that runs the tool; invoked as
            ``function_call(**arguments)``.
        arguments: Keyword arguments for the tool.

    Returns:
        Whatever ``function_call`` returns.

    Raises:
        Exception: Anything raised by the tool is logged, recorded and
            re-raised unchanged.
    """
    # Create a dict only when none exists; replacing an existing (empty) dict
    # would detach it from any other holder of the same object.
    if run_context.session_state is None:
        run_context.session_state = {}
    session_state = run_context.session_state
    session_state.setdefault("tool_execution_log", [])

    execution_start = datetime.now()
    # Durations come from a monotonic clock so wall-clock adjustments cannot
    # produce negative or inflated values; datetime is kept for the timestamps.
    clock_start = time.perf_counter()
    execution_record = {
        "tool_name": function_name,
        "arguments": arguments,
        "timestamp": execution_start.isoformat(),
        "execution_id": f"{function_name}_{execution_start.timestamp()}"
    }
    logger.info("🔧 Executing tool: %s with args: %s", function_name, arguments)

    def _completion_fields() -> Dict[str, Any]:
        # Fields shared by the success and failure records.
        return {
            "duration_ms": (time.perf_counter() - clock_start) * 1000,
            "completed_at": datetime.now().isoformat(),
        }

    try:
        result = function_call(**arguments)
        execution_record.update({
            "result": str(result)[:1000],  # Truncate long results
            "status": "success",
            **_completion_fields(),
        })
        logger.info(
            "✅ Tool %s completed successfully in %.2fms",
            function_name,
            execution_record["duration_ms"],
        )
    except Exception as e:
        execution_record.update({
            "error": str(e),
            "status": "failed",
            **_completion_fields(),
        })
        logger.error("❌ Tool %s failed: %s", function_name, e)
        raise  # Let the agent framework see the original exception
    finally:
        # Look the list up again: the tool may have replaced or removed it
        # while it ran, and a KeyError here would mask the tool's own error.
        session_state.setdefault("tool_execution_log", []).append(execution_record)
    return result
# System prompt for the SQL agent; passed as `instructions` when the Agent is
# constructed further down in this file.
# The text is XML-structured: persona, critical directives, a four-phase
# workflow (Initialization, Discovery, Execution, Synthesis), per-tool usage
# rules, worked scenarios, file/hybrid directives and output formatting.
# It is sent to the model verbatim, so any edit here changes agent behavior.
# NOTE(review): the free-standing sentence between </file_and_hybrid_workflow>
# and <output_formatting> sits outside every tag and is informally worded;
# confirm that is intended.
# NOTE(review): the "Complex" scenario uses DATE('now', '-1 month'), which looks
# SQLite-specific, while Phase 3 asks for Standard ANSI SQL; confirm which one
# the agent should follow.
system_prompt = dedent("""
<system_configuration>
<persona>
<name>Sirus</name>
<creator>PhobosQ</creator>
<role>Sirus The Data Scientist & Strategist</role>
<mission>Bridge the gap between raw database rows and high-level business strategy.</mission>
<voice>Professional, energetic, precise, and helpful. You speak in Markdown.</voice>
</persona>
<critical_directives>
<directive id="1" name="The Invisible Wall">
The user CANNOT see your tool calls, JSON outputs, or SQL code.
You MUST translate every tool result into a natural language sentence.
NEVER end a turn with a tool call. ALWAYS end with a text response.
</directive>
<directive id="2" name="Broad Search Protocol">
Your semantic search is strict. When searching for tables, you MUST expand keywords.
- If user asks: "How many users?" -> Search: ['users', 'accounts', 'customers', 'profiles','people','members'etc...]
- If user asks: "Sales?" -> Search: ['sales', 'orders', 'transactions', 'revenue', 'invoices','bookings']
</directive>
<directive id="3" name="The Schema Fallback">
If `find_relevant_tables` returns 0 matches, you MUST NOT give up.
You MUST immediately call `get_available_sources_and_schema` to pull the full database map.
Then, manually find the table and execute the query.
</directive>
<directive id="4" name="Safety & Read-Only">
NEVER execute INSERT, UPDATE, DELETE, DROP, or ALTER.
ALWAYS use `LIMIT 100` on list queries to prevent token overflows.
</directive>
</critical_directives>
<workflow_engine>
<phase id="1" name="Initialization">
<check>Do I have the `source_instructions` in my context?</check>
<action>If NO: Call `list_sources`, select the most relevant one, then `get_source_instructions`.</action>
<action>If YES: Skip to Phase 2.</action>
</phase>
<phase id="2" name="Discovery">
<action>Call `find_relevant_tables(question, concepts)`.</action>
<logic>Use broad concepts. If the user asks a "Why" question, search for fact tables (orders, logs) AND dimension tables (users, products).</logic>
<fallback>If matches == 0: Call `get_available_sources_and_schema(tenant_id)`.</fallback>
</phase>
<phase id="3" name="Execution">
<action>Call `execute_sql_query(sql_query)` OR `save_query_to_tenant_csv(sql_query)`.</action>
<logic>
1. Write Standard ANSI SQL.
2. Use the exact table names found in Phase 2.
3. If the user asks "Why" or "Trend", run aggregations (GROUP BY).
4. **CRITICAL ML RULE:** If the user asks to "save", "export", "analyze in pandas", or "prepare for ML", you MUST use `save_query_to_tenant_csv`.
</logic>
<recovery>If SQL fails: Read error -> Correct Syntax -> Retry Query.</recovery>
</phase>
<phase id="4" name="Synthesis">
<action>Convert JSON list to Text.</action>
<template>
1. **The Answer:** Direct answer to the question (e.g., "Total revenue is $5M").
2. **The Context:** (Optional) "This is based on 500 records from the 'orders' table."
3. **The Strategy:** (Only for complex questions) "To improve this, consider..."
4. **Suggested Questions:** ALWAYS end your response with exactly 3 highly relevant follow-up questions formatted as a bulleted list under the exact heading `### Suggested Questions`.
</template>
</phase>
</workflow_engine>
<tool_usage_guide>
<tool name="list_sources">
<trigger>Start of conversation or when switching databases.</trigger>
<purpose>Finds the tenant_id and source_name.</purpose>
</tool>
<tool name="get_source_instructions">
<trigger>Immediately after picking a source.</trigger>
<purpose>Gets the "Manual" for the database (SQL dialect, special column rules).</purpose>
</tool>
<tool name="find_relevant_tables">
<trigger>Every user question.</trigger>
<input_strategy>
Argument `concepts` must be a list of broad synonyms.
Example: User="Churn rate?" -> concepts=["churn", "status", "active", "cancelled", "users"]
</input_strategy>
</tool>
<tool name="get_available_sources_and_schema">
<trigger>ONLY when `find_relevant_tables` fails (returns []).</trigger>
<purpose>The "Nuclear Option". Dumps the whole schema so you can find tables manually.</purpose>
</tool>
<tool name="execute_sql_query">
<trigger>Once you have table names and a clear intent for a simple data pull or counting.</trigger>
<rules>
- No Markdown in the SQL string.
- Dates should be handled dynamically (e.g., `CURRENT_DATE`).
- Always handle NULLs in math operations (`COALESCE`).
</rules>
</tool>
<tool name="save_query_to_tenant_csv">
<trigger>When a user asks to export data, prepare it for Machine Learning, or save it.</trigger>
<purpose>Executes SQL but completely bypasses standard memory limits by saving directly to the MinIO cluster.</purpose>
</tool>
<tool name="list_tenant_assets">
<trigger>When a user asks what files, reports, or datasets they have in their workspace.</trigger>
</tool>
<tool name="load_tenant_file_to_dataframe">
<trigger>When a user asks you to analyze a specific CSV file in their workspace.</trigger>
</tool>
</tool_usage_guide>
<exemplar_scenarios>
<scenario type="Easy" description="Simple Count">
<user_input>How many users are there?</user_input>
<chain_of_thought>
1. **Init**: I have instructions.
2. **Search**: `find_relevant_tables` for "users", "accounts".
3. **Result**: Found table `public.users`.
4. **Execute**: `SELECT COUNT(*) as count FROM public.users`.
5. **Result**: `[{'count': 5420}]`.
6. **Speak**: "There are currently **5,420 users** in your database. Is there a specific segment you'd like to analyze?"
</chain_of_thought>
</scenario>
<scenario type="Complex" description="Trend Analysis & Strategy">
<user_input>Why did sales drop last month?</user_input>
<chain_of_thought>
1. **Search**: `find_relevant_tables` for "sales", "orders", "revenue".
2. **Result**: Found `orders`.
3. **Reasoning**: To find "why", I need to look at categories or regions over time. A simple sum isn't enough.
4. **Execute**: `SELECT category, SUM(amount) as total FROM orders WHERE date >= DATE('now', '-1 month') GROUP BY category ORDER BY total ASC`.
5. **Result**: `[{'category': 'Tech', 'total': 500}, {'category': 'Home', 'total': 50000}]`.
6. **Speak**: "I analyzed the sales drop. It appears **Tech** products severely underperformed ($500) compared to Home Goods.
**Strategic Insight:** This often indicates a stockout or a lack of promotion in that category.
**Recommendation:** I suggest checking inventory levels for Tech products immediately."
</chain_of_thought>
</scenario>
<scenario type="Failure_Recovery" description="Search returns Zero Matches">
<user_input>Show me the invoices.</user_input>
<chain_of_thought>
1. **Search**: `find_relevant_tables(concepts=['invoices', 'billing'])`.
2. **Result**: `Matches: []` (Zero tables found).
3. **Pivot**: The vector search failed. I must see the schema.
4. **Tool**: `get_available_sources_and_schema(tenant_id=...)`.
5. **Result**: Full Schema JSON. I read it. I see a table named `billing_ledgers`.
6. **Execute**: `SELECT * FROM billing_ledgers LIMIT 5`.
7. **Speak**: "I couldn't find a table explicitly named 'invoices', but I found `billing_ledgers` which contains billing data. Here are the top 5 records..."
</chain_of_thought>
</scenario>
<scenario type="File_Analysis" description="User asks about an Excel or CSV file">
<user_input>Analyze my demographic data file.</user_input>
<chain_of_thought>
1. **Search**: I don't know the exact file name. I need to list them.
2. **Tool**: `list_tenant_assets()`.
3. **Result**: `[{"asset_id": "123", "filename": "Demographic Data.xlsx"}]`.
4. **Tool**: `load_tenant_file_to_dataframe(asset_id="123", instructions="Summarize the key demographics.")`.
5. **Result**: Pandas output showing age distributions and counts.
6. **Speak**: "I've analyzed your demographic dataset. The majority of the entries fall into..."
</chain_of_thought>
</scenario>
<scenario type="Hybrid_Workflow" description="Combining Database and File Insights">
<user_input>Compare the regional sales in the DB with the targets in the targets.xlsx file.</user_input>
<chain_of_thought>
1. **DB Step**: I need regional sales from the database.
2. **Tool**: `find_relevant_tables(concepts=['sales', 'regions', 'geography'])`.
3. **Execute**: `SELECT region, SUM(amount) FROM regional_sales GROUP BY region`.
4. **File Step**: Now I need the targets.
5. **Tool**: `list_tenant_assets()`.
6. **Tool**: `load_tenant_file_to_dataframe(filename="targets.xlsx", instructions="Extract the targets by region.")`.
7. **Synthesize**: I combine the SQL results and Pandas results in my reasoning.
8. **Speak**: "Comparing your database records to the Excel file, the North region hit its target ($50k vs $45k expected), but..."
</chain_of_thought>
</scenario>
<scenario type="Common_Mistake_Correction" description="Assuming DB tables for File requests">
<user_input>What does the file users.csv say?</user_input>
<chain_of_thought>
1. **BAD PATH**: Calling `execute_sql_query('SELECT * FROM "users.csv"')`. (This is WRONG! It's a file, not a table).
2. **CORRECT PATH**: The user specifically said "file" and "csv".
3. **Tool**: `list_tenant_assets()` to verify it exists.
4. **Tool**: `load_tenant_file_to_dataframe(filename="users.csv", ...)`.
5. **Speak**: "I loaded the users.csv file and found 300 entries..."
</chain_of_thought>
</scenario>
<scenario type="Ambiguous_Request" description="User is vague; check both DB and files">
<user_input>Can you show me some insights?</user_input>
<chain_of_thought>
1. **Recognize Ambiguity**: The user never said "table" or "file". I must sample both data sources.
2. **DB Probe**: `find_relevant_tables(concepts=['users','sales','orders','activity','logs'])` to surface likely tables.
3. **File Probe**: `list_tenant_assets()` to see if any CSV/Excel assets exist that look relevant (recent uploads, names with "report", "data", etc.).
4. **Pick One of Each (Lightweight)**: Grab a tiny preview: `execute_sql_query('SELECT * FROM <top_table> LIMIT 5')` and `load_tenant_file_to_dataframe(asset_id=<id>, instructions="Give me a quick summary")`.
5. **Synthesize**: Combine the quick peeks and present the clearest starting point. Offer options: continue with DB analysis, or dive into the file.
6. **Speak**: "I checked both your database and uploaded files. From the database, I saw a table with recent activity; from files, there's a recent report.xlsx. Which one should I dig into further?"
</chain_of_thought>
</scenario>
</exemplar_scenarios>
<file_and_hybrid_workflow>
<directive id="5" name="File Tool Prioritization">
When a user explicitly mentions a "file", "csv", "excel", "dataset", or "xlsx", you MUST prioritize file-based tools (`list_tenant_assets`, `load_tenant_file_to_dataframe`).
DO NOT try to query these files using standard SQL tools unless explicitly attached as a temporary table (which they are not). Files live in a separate blob storage; databases live in SQL.
</directive>
<directive id="6" name="Hybrid Analytics Protocol">
If the user asks a question that spans both their database AND an uploaded file:
1. Extract the DB information first using Phase 2 (Discovery) and Phase 3 (Execution).
2. Extract the File information second by locating the file with `list_tenant_assets` and querying it with `load_tenant_file_to_dataframe`.
3. Synthesize the findings using your own reasoning to combine the disparate data sources.
</directive>
<directive id="7" name="File Tool Self-Correction">
If a tool call to `load_tenant_file_to_dataframe` fails with "File not found" or "NoSuchKey", DO NOT confidently report that the data is missing. Instead, ALWAYS call `list_tenant_assets` to double-check the exact spelling, path, or `asset_id` of the available files and try again using the exact identifier.
</directive>
</file_and_hybrid_workflow>
if u encounter any errors , kindly rectify them and proceed with the task at hand. if still its an server error or something , just say that kindly neatly.
<output_formatting>
- Use **Bold** for numbers and key entities.
- Use Tables for lists of data.
- Be concise but friendly.
- CRITICAL: You MUST ALWAYS append exactly 3 follow-up questions under the exact markdown heading `### Suggested Questions` at the very end of your response.
</output_formatting>
</system_configuration>
""")
print("✅ Configuration set. Initializing enhanced agent with comprehensive logging...")

# Identifiers for the AgentOS instance and the agent; each can be overridden
# through its environment variable.
DEFAULT_AGENT_OS_ID = os.environ.get("SQL_AGENT_OS_ID", "sql-agent-os")
DEFAULT_AGENT_ID = os.environ.get("SQL_AGENT_ID", "sirus-sql-agent")

# True when this module is imported inside a pytest process.
IS_PYTEST = "pytest" in sys.modules or "PYTEST_CURRENT_TEST" in os.environ

# Runtime singletons; all four stay None when running under pytest.
agent_db = data_sources_sql_toolkit = tenant_file_toolkit = gemini_sql_agent = None

if IS_PYTEST:
    logger.info("Running under pytest: skipping heavy SQL agent runtime initialization")
else:
    # SQLite-backed session store; the file path is relative to the working directory.
    agent_db = SqliteDb(db_file="agent_sessions.db")
    data_sources_sql_toolkit = DataSourcesSQLToolkit(
        api_key=DATA_SOURCES_API_KEY,
        api_base_url=DATA_SOURCES_API_BASE_URL,
    )
    tenant_file_toolkit = TenantFileToolkit()

    # NOTE(review): add_session_state_to_context is enabled, and the per-request
    # session_state built elsewhere in this file carries `supabase_jwt`; confirm
    # the token is not forwarded to the model provider as part of the context.
    # NOTE(review): debug_mode is hard-coded to True; confirm that is wanted
    # outside development.
    gemini_sql_agent = Agent(
        # Model (alternative id previously noted here: nvidia/nemotron-3-super-120b-a12b)
        model=Nvidia(
            id="stepfun-ai/step-3.7-flash",
            temperature=0.2,
            top_p=0.95,
            max_tokens=32768,
        ),
        instructions=system_prompt,
        markdown=True,
        # Tools and the hook that logs every tool call
        tools=[data_sources_sql_toolkit, tenant_file_toolkit],
        tool_hooks=[comprehensive_logging_hook],
        tool_call_limit=100,
        # Persistence and conversation history
        db=agent_db,
        add_history_to_context=True,
        num_history_runs=3,
        read_chat_history=True,
        # Default session state and what is exposed in the context
        session_state={
            "tool_execution_log": [],
            "user_context": {},
            "analysis_metadata": {},
        },
        add_session_state_to_context=True,
        add_datetime_to_context=True,
        # Retry behaviour
        exponential_backoff=True,
        delay_between_retries=10,
        # Diagnostics
        debug_mode=True,
        telemetry=False,
    )
    gemini_sql_agent.id = DEFAULT_AGENT_ID

    # Hand the agent to the toolkit (see the log message below for the stated purpose).
    data_sources_sql_toolkit.set_agent_ref(gemini_sql_agent)
    logger.info("Agent reference set in toolkit - session_state injection enabled")
| # Define Pydantic model for tenant-aware API requests | |
# Define Pydantic model for tenant-aware API requests
class TenantRunRequest(BaseModel):
    """
    Request body for the custom tenant-aware run endpoint.

    Carries the user message together with the tenant context that the
    endpoint copies into the run's session_state. The endpoint does not trust
    these values blindly: the tenant is resolved from the authenticated user,
    and a non-matching `tenant_id` here is rejected with HTTP 403.
    Supports multi-source agent auto-detection when `available_sources` is provided.
    """
    # User prompt passed to the agent.
    message: str
    # JWT for backend API auth; a Bearer token in the Authorization header,
    # when present, takes precedence over this value.
    supabase_jwt: str
    # Tenant the caller claims to act for; must equal the tenant in the
    # authenticated user's claims.
    tenant_id: str
    # Default/primary source for query execution.
    source_name: str
    # Existing session to continue; a UUID is generated when omitted.
    session_id: Optional[str] = None
    # Fallback actor id, used only when the authenticated user has no id.
    user_id: Optional[str] = None
    # All sources available for agent auto-detection; treated as [] when omitted.
    available_sources: Optional[list] = None
    # When True the endpoint takes its server-sent-events streaming branch.
    stream: bool = False
    # NOTE(review): not referenced in the visible part of this file; confirm usage.
    background: bool = False
| # Define the tenant-aware endpoint function (will be added to AgentOS app later) | |
| async def run_tenant_agent( | |
| agent_id: str, | |
| request: TenantRunRequest, | |
| auth_user: AuthUser = Depends(get_current_user), | |
| http_request: Request = None, | |
| ): | |
| """ | |
| Custom endpoint to run an agent with tenant_id, source_name, and supabase_jwt | |
| injected directly into the session_state. | |
| This is the PRIMARY endpoint for multi-tenant agent execution. | |
| It ensures proper tenant isolation and security by: | |
| 1. Accepting all tenant context in the request body | |
| 2. Injecting it into session_state (not shared between requests) | |
| 3. Using the JWT for data source API authentication | |
| Args: | |
| agent_id: The ID of the agent to run (e.g., "sirus-sql-agent") | |
| request: TenantRunRequest containing all tenant context | |
| Returns: | |
| StreamingResponse (if stream=True) or direct JSON response | |
| """ | |
| # Get agent from the global agent we created | |
| agent = gemini_sql_agent if agent_id == DEFAULT_AGENT_ID else None | |
| if not agent: | |
| raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found.") | |
| # Resolve tenant/user context from validated JWT claims. | |
| resolved_tenant_id = (auth_user.tenant_id or "").strip() | |
| if not resolved_tenant_id: | |
| raise HTTPException(status_code=401, detail="Missing tenant_id in JWT claims") | |
| if request.tenant_id and request.tenant_id != resolved_tenant_id: | |
| logger.warning( | |
| f"Rejecting tenant-run due to tenant mismatch. body={request.tenant_id} jwt={resolved_tenant_id}" | |
| ) | |
| raise HTTPException(status_code=403, detail="tenant_id mismatch with authenticated user") | |
| resolved_actor_user_id = auth_user.id or request.user_id | |
| resolved_session_owner_id = resolved_tenant_id | |
| resolved_jwt = request.supabase_jwt | |
| if http_request is not None: | |
| auth_header = http_request.headers.get("Authorization", "") | |
| if auth_header.startswith("Bearer "): | |
| resolved_jwt = auth_header.replace("Bearer ", "", 1).strip() or resolved_jwt | |
| # CRITICAL: This is the state that will be loaded *for this run only*. | |
| # This is the correct, request-safe way to handle per-run context. | |
| # Each request gets its own isolated session_state. | |
| initial_state = { | |
| "supabase_jwt": resolved_jwt, # JWT for backend API auth | |
| "tenant_id": resolved_tenant_id, # Tenant context for toolkit | |
| "source_name": request.source_name, | |
| "user_id": resolved_session_owner_id, | |
| "actor_user_id": resolved_actor_user_id, | |
| "available_sources": request.available_sources or [], # All sources for agent auto-detection | |
| "tool_execution_log": [], | |
| "user_context": {}, | |
| "analysis_metadata": {} | |
| } | |
| # Generate a session ID if not provided | |
| session_id = request.session_id or str(uuid.uuid4()) | |
| logger.info(f"🚀 Starting tenant run for tenant_id={resolved_tenant_id}, source={request.source_name}, session={session_id}") | |
| if request.stream: | |
| # Handle streaming response for real-time agent output | |
| async def stream_generator(): | |
| try: | |
| logger.info(f"🎬 Starting streaming for session {session_id}, message: {request.message[:50]}...") | |
| # Emit a canonical start event so frontend can persist session_id and retain memory across turns. | |
| run_started_payload = { | |
| "event": "RunStarted", | |
| "session_id": session_id, | |
| "agent_id": agent_id, | |
| } | |
| yield f"event: RunStarted\ndata: {json.dumps(run_started_payload)}\n\n" | |
| logger.info(f" ✅ Yielded RunStarted with session_id={session_id}") | |
| # agent.run returns a generator in stream mode | |
| response_generator = agent.run( | |
| request.message, | |
| stream=True, | |
| stream_events=True, # Enable full event streaming for tool calls | |
| session_id=session_id, | |
| user_id=resolved_session_owner_id, # Tag session with owner so get_session(user_id=) works | |
| session_state=initial_state | |
| ) | |
| chunk_count = 0 | |
| for chunk in response_generator: | |
| chunk_count += 1 | |
| # Handle RunEvent types for proper streaming | |
| if hasattr(chunk, 'event'): | |
| logger.info(f" [Chunk {chunk_count}] Event: {chunk.event}") | |
| if chunk.event == RunEvent.run_content: | |
| # Model text response | |
| event_data = {"content": chunk.content if hasattr(chunk, 'content') else str(chunk)} | |
| sse_event = f"event: RunContent\ndata: {json.dumps(event_data)}\n\n" | |
| yield sse_event | |
| logger.info(f" ✅ Yielded RunContent event") | |
| elif chunk.event == RunEvent.tool_call_started: | |
| # Tool starting - send full tool object for frontend | |
| tool_obj = chunk.tool if hasattr(chunk, 'tool') else None | |
| tool_call_id = getattr(tool_obj, 'tool_call_id', None) or str(uuid.uuid4()) | |
| tool_name = getattr(tool_obj, 'tool_name', 'unknown') | |
| tool_args = {} | |
| if hasattr(tool_obj, 'tool_args') and tool_obj.tool_args: | |
| tool_args = tool_obj.tool_args if isinstance(tool_obj.tool_args, dict) else {} | |
| event_data = { | |
| "tool": { | |
| "tool_call_id": tool_call_id, | |
| "tool_name": tool_name, | |
| "tool_args": tool_args, | |
| "role": "tool", | |
| "tool_call_error": False, | |
| "content": None, | |
| "metrics": {"time": 0}, | |
| "created_at": int(time.time()) | |
| }, | |
| "status": "started" | |
| } | |
| sse_event = f"event: ToolCallStarted\ndata: {json.dumps(event_data, default=str)}\n\n" | |
| yield sse_event | |
| logger.info(f" ✅ Yielded ToolCallStarted: {tool_name} (id: {tool_call_id})") | |
| elif chunk.event == RunEvent.tool_call_completed: | |
| # Tool finished - send full tool object with result | |
| tool_obj = chunk.tool if hasattr(chunk, 'tool') else None | |
| tool_call_id = getattr(tool_obj, 'tool_call_id', None) or str(uuid.uuid4()) | |
| tool_name = getattr(tool_obj, 'tool_name', 'unknown') | |
| tool_args = {} | |
| if hasattr(tool_obj, 'tool_args') and tool_obj.tool_args: | |
| tool_args = tool_obj.tool_args if isinstance(tool_obj.tool_args, dict) else {} | |
| # Get the ACTUAL tool result as a raw object (not pre-serialized) | |
| # This ensures proper JSON encoding when we serialize event_data | |
| content = None | |
| content_source = "none" | |
| if tool_obj: | |
| # Try to get the actual result from tool_obj | |
| if hasattr(tool_obj, 'result') and tool_obj.result is not None: | |
| result = tool_obj.result | |
| content_source = "tool_obj.result" | |
| # Keep as raw object for proper serialization | |
| if isinstance(result, (dict, list)): | |
| content = result # Raw object - will be serialized by outer json.dumps | |
| elif isinstance(result, str): | |
| # Try to parse if it's already JSON | |
| try: | |
| content = json.loads(result) | |
| except: | |
| content = result # Keep as string | |
| else: | |
| content = str(result) | |
| elif hasattr(tool_obj, 'content') and tool_obj.content is not None: | |
| tc = tool_obj.content | |
| content_source = "tool_obj.content" | |
| if isinstance(tc, (dict, list)): | |
| content = tc | |
| elif isinstance(tc, str): | |
| try: | |
| content = json.loads(tc) | |
| except: | |
| content = tc | |
| else: | |
| content = str(tc) | |
| # Last fallback - use chunk.content (formatted message) | |
| if content is None and hasattr(chunk, 'content') and chunk.content: | |
| content = str(chunk.content)[:2000] | |
| content_source = "chunk.content" | |
| tool_error = getattr(tool_obj, 'tool_call_error', False) if tool_obj else False | |
| exec_time = getattr(tool_obj, 'metrics', {}) | |
| if hasattr(exec_time, 'time'): | |
| exec_time = exec_time.time | |
| elif isinstance(exec_time, dict): | |
| exec_time = exec_time.get('time', 0) | |
| else: | |
| exec_time = 0 | |
| event_data = { | |
| "tool": { | |
| "tool_call_id": tool_call_id, | |
| "tool_name": tool_name, | |
| "tool_args": tool_args, | |
| "role": "tool", | |
| "tool_call_error": tool_error, | |
| "content": content, # Raw object - properly serialized by json.dumps below | |
| "metrics": {"time": exec_time}, | |
| "created_at": int(time.time()) | |
| }, | |
| "status": "completed" | |
| } | |
| sse_event = f"event: ToolCallCompleted\ndata: {json.dumps(event_data, default=str)}\n\n" | |
| yield sse_event | |
| logger.info(f" ✅ Yielded ToolCallCompleted: {tool_name} (id: {tool_call_id}) content_source: {content_source} content_type: {type(content).__name__}") | |
| elif chunk.event == RunEvent.run_completed: | |
| # Run completed - send metrics | |
| metrics_data = {} | |
| if hasattr(chunk, 'metrics') and chunk.metrics: | |
| # MiniMax M2.5 Estimated Pricing (e.g., $0.15/1M in, $0.60/1M out) | |
| input_tokens = getattr(chunk.metrics, 'input_tokens', 0) | |
| output_tokens = getattr(chunk.metrics, 'output_tokens', 0) | |
| total_tokens = getattr(chunk.metrics, 'total_tokens', 0) | |
| est_cost = (input_tokens / 1_000_000) * 0.15 + (output_tokens / 1_000_000) * 0.60 | |
| metrics_data = { | |
| "input_tokens": input_tokens, | |
| "output_tokens": output_tokens, | |
| "total_tokens": total_tokens, | |
| "time_to_first_token": getattr(chunk.metrics, 'time_to_first_token', 0), | |
| "tokens_per_second": getattr(chunk.metrics, 'tokens_per_second', 0), | |
| "estimated_cost_usd": float(f"{est_cost:.6f}") | |
| } | |
| # Log the comprehensive cost tracking | |
| logger.info(f"💰 [COST TRACKING] Session: {session_id} | Tenant: {resolved_tenant_id} | " | |
| f"Tokens: {input_tokens} In, {output_tokens} Out, {total_tokens} Total | " | |
| f"Est. Cost: ${est_cost:.6f}") | |
| record_tenant_billing(resolved_tenant_id, input_tokens, output_tokens) | |
| # Save to session state analysis metadata | |
| initial_state["analysis_metadata"]["final_cost_metrics"] = metrics_data | |
| event_data = {"metrics": metrics_data, "session_id": session_id} | |
| sse_event = f"event: RunCompleted\ndata: {json.dumps(event_data)}\n\n" | |
| yield sse_event | |
| logger.info(f" ✅ Yielded RunCompleted with metrics: {metrics_data}") | |
| else: | |
| # Other event types | |
| logger.info(f" ⚠️ Unhandled event type: {chunk.event}") | |
| await asyncio.sleep(0.001) | |
| continue | |
| # Fallback for dict-based chunks | |
| if isinstance(chunk, dict): | |
| event = chunk.get("event") | |
| data = chunk.get("data") | |
| if event: | |
| sse_event = f"event: {event}\ndata: {json.dumps(data)}\n\n" | |
| else: | |
| sse_event = f"data: {json.dumps(chunk)}\n\n" | |
| yield sse_event | |
| logger.info(f" ✅ Yielded event: {event or 'data-only'}") | |
| # Small delay to ensure chunk is flushed before next one | |
| await asyncio.sleep(0.001) | |
| else: | |
| # Handle Pydantic objects or other objects | |
| try: | |
| logger.info(f"Processing chunk type: {type(chunk)}") | |
| # Try multiple serialization methods | |
| chunk_dict = None | |
| # Method 1: Pydantic v2 model_dump() | |
| if hasattr(chunk, 'model_dump'): | |
| try: | |
| chunk_dict = chunk.model_dump() | |
| logger.info(f"✅ Serialized with model_dump()") | |
| except Exception as e: | |
| logger.info(f"model_dump() failed: {e}") | |
| # Method 2: Pydantic v1 dict() | |
| if chunk_dict is None and hasattr(chunk, 'dict'): | |
| try: | |
| chunk_dict = chunk.dict() | |
| logger.info(f"✅ Serialized with dict()") | |
| except Exception as e: | |
| logger.info(f"dict() failed: {e}") | |
| # Method 3: Check if it's a Pydantic BaseModel | |
| if chunk_dict is None: | |
| try: | |
| # Try to import and check | |
| from pydantic import BaseModel | |
| if isinstance(chunk, BaseModel): | |
| chunk_dict = chunk.model_dump() | |
| logger.info(f"✅ Serialized BaseModel with model_dump()") | |
| except Exception as e: | |
| logger.info(f"BaseModel check failed: {e}") | |
| # Method 4: Fall back to __dict__ | |
| if chunk_dict is None and hasattr(chunk, '__dict__'): | |
| chunk_dict = chunk.__dict__ | |
| logger.info(f"✅ Serialized with __dict__") | |
| # Method 5: Last resort - convert to string | |
| if chunk_dict is None: | |
| logger.warning(f"Could not serialize chunk, converting to string: {type(chunk)}") | |
| chunk_dict = {"content": str(chunk)} | |
| # Extract event type if present | |
| event_type = chunk_dict.get("event") | |
| if event_type: | |
| logger.info(f"Sending event: {event_type}") | |
| # Debug: Show content for ReasoningStep events | |
| if event_type == "ReasoningStep": | |
| logger.info(f" ReasoningStep content: reasoning={chunk_dict.get('reasoning')}, content={chunk_dict.get('content')}, result={chunk_dict.get('result')}") | |
| logger.info(f" Full ReasoningStep dict keys: {list(chunk_dict.keys())}") | |
| # Use custom serializer that properly handles nested objects | |
| def serialize_value(obj): | |
| """Recursively serialize objects, converting to strings only when necessary""" | |
| if isinstance(obj, dict): | |
| return {k: serialize_value(v) for k, v in obj.items()} | |
| elif isinstance(obj, (list, tuple)): | |
| return [serialize_value(v) for v in obj] | |
| elif hasattr(obj, 'model_dump'): | |
| return serialize_value(obj.model_dump()) | |
| elif hasattr(obj, '__dict__') and not isinstance(obj, (str, int, float, bool, type(None))): | |
| return serialize_value(obj.__dict__) | |
| else: | |
| return obj | |
| serialized_dict = serialize_value(chunk_dict) | |
| # Special handling for ReasoningStep: convert content object to string | |
| if event_type == "ReasoningStep" and isinstance(serialized_dict.get("content"), dict): | |
| # Content is a reasoning object - serialize it as string for frontend | |
| reasoning_obj = serialized_dict.pop("content") | |
| serialized_dict["reasoning_content"] = json.dumps(reasoning_obj, default=str, ensure_ascii=False) | |
| logger.info(f" ✅ Converted ReasoningStep content to reasoning_content string") | |
| sse_event = f"event: {event_type}\ndata: {json.dumps(serialized_dict, default=str, ensure_ascii=False)}\n\n" | |
| else: | |
| logger.info(f"Sending data without event type") | |
| def serialize_value(obj): | |
| """Recursively serialize objects, converting to strings only when necessary""" | |
| if isinstance(obj, dict): | |
| return {k: serialize_value(v) for k, v in obj.items()} | |
| elif isinstance(obj, (list, tuple)): | |
| return [serialize_value(v) for v in obj] | |
| elif hasattr(obj, 'model_dump'): | |
| return serialize_value(obj.model_dump()) | |
| elif hasattr(obj, '__dict__') and not isinstance(obj, (str, int, float, bool, type(None))): | |
| return serialize_value(obj.__dict__) | |
| else: | |
| return obj | |
| serialized_dict = serialize_value(chunk_dict) | |
| sse_event = f"data: {json.dumps(serialized_dict, default=str, ensure_ascii=False)}\n\n" | |
| yield sse_event | |
| logger.info(f" ✅ Yielded event: {event_type or 'data-only'}") | |
| # Small delay to ensure chunk is flushed before next one | |
| await asyncio.sleep(0.001) | |
| except Exception as e: | |
| logger.error(f"Failed to serialize chunk: {e}, chunk type: {type(chunk)}", exc_info=True) | |
| yield f"data: {json.dumps({'error': str(e), 'content': str(chunk)}, default=str)}\n\n" | |
| await asyncio.sleep(0.001) | |
| logger.info(f"✅ Streaming run completed for session {session_id} - sent {chunk_count} chunks") | |
| except Exception as e: | |
| logger.error(f"❌ Error during stream generation for session {session_id}: {e}", exc_info=True) | |
| error_data = {"error": str(e), "code": "STREAM_ERROR"} | |
| yield f"event: error\ndata: {json.dumps(error_data)}\n\n" | |
| return StreamingResponse(stream_generator(), media_type="text/event-stream") | |
| else: | |
| # Handle non-streaming (blocking) response | |
| try: | |
| response = agent.run( | |
| request.message, | |
| stream=False, | |
| session_id=session_id, | |
| user_id=resolved_session_owner_id, | |
| session_state=initial_state | |
| ) | |
| # Non-streaming Cost tracking | |
| metrics_data = {} | |
| if hasattr(response, 'metrics') and response.metrics: | |
| input_tokens = getattr(response.metrics, 'input_tokens', 0) | |
| output_tokens = getattr(response.metrics, 'output_tokens', 0) | |
| total_tokens = getattr(response.metrics, 'total_tokens', 0) | |
| est_cost = (input_tokens / 1_000_000) * 0.15 + (output_tokens / 1_000_000) * 0.60 | |
| logger.info(f"💰 [COST TRACKING] Session: {session_id} | Tenant: {resolved_tenant_id} | " | |
| f"Tokens: {input_tokens} In, {output_tokens} Out, {total_tokens} Total | " | |
| f"Est. Cost: ${est_cost:.6f}") | |
| record_tenant_billing(resolved_tenant_id, input_tokens, output_tokens) | |
| logger.info(f"✅ Non-streaming run completed for session {session_id}") | |
| # The final response from agent.run is the message content | |
| return { | |
| "session_id": session_id, | |
| "tenant_id": resolved_tenant_id, | |
| "response": response | |
| } | |
| except Exception as e: | |
| logger.error(f"❌ Error during non-streaming agent run for session {session_id}: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
# Build the ASGI application served by this module.
# With an initialized agent, the AgentOS-generated app is used and extended
# with the tenant-aware run route; otherwise a bare FastAPI app is exposed.
if gemini_sql_agent is None:
    agent_os = None
    app = FastAPI()
else:
    agent_os = AgentOS(
        description="Multi-tenant SQL Agent for querying data sources across tenants.",
        agents=[gemini_sql_agent],
    )
    agentOS_app = agent_os.get_app()
    # Attach the custom tenant endpoint to the app AgentOS produced.
    agentOS_app.add_api_route(
        "/tenant-run/{agent_id}",
        run_tenant_agent,
        name="run_tenant_agent",
        methods=["POST"],
    )
    app = agentOS_app
| # ============================================================================ | |
| # Chat CRUD Endpoints (from agentOS_crud.md) | |
| # ============================================================================ | |
| def _serialize_session_obj(session_obj: Any) -> Dict[str, Any]: | |
| if hasattr(session_obj, "model_dump"): | |
| data = session_obj.model_dump() | |
| if isinstance(data, dict): | |
| return data | |
| if isinstance(session_obj, dict): | |
| return session_obj | |
| if hasattr(session_obj, "__dict__"): | |
| return dict(session_obj.__dict__) | |
| return {"value": str(session_obj)} | |
| def _extract_session_id(session_payload: Dict[str, Any]) -> str: | |
| return str( | |
| session_payload.get("session_id") | |
| or session_payload.get("id") | |
| or session_payload.get("sessionId") | |
| or "" | |
| ) | |
def _ensure_agent_runtime_ready() -> None:
    """Guard for endpoints that need the agent runtime.

    Raises:
        HTTPException: 503 when either ``agent_db`` or ``gemini_sql_agent``
            is ``None``.
    """
    runtime_ready = agent_db is not None and gemini_sql_agent is not None
    if not runtime_ready:
        raise HTTPException(status_code=503, detail="Agent runtime is not initialized")
def _session_belongs_to_user(session_id: str, user_id: str) -> bool:
    """Return True when ``session_id`` is considered owned by ``user_id``.

    The check runs in two passes:
      1. User-scoped lookup ``get_session(session_id, user_id)``. A hit is
         accepted when the returned payload carries the same session id.
      2. Session-only lookup ``get_session(session_id)``. Accepted when the
         payload's session id matches and its stored ``user_id`` either equals
         ``user_id`` or is blank (legacy sessions saved without a user id).

    Returns False for empty arguments, an uninitialized agent, or when any
    exception is raised during the lookup (the exception is logged).
    """
    if not (session_id and user_id) or gemini_sql_agent is None:
        return False

    try:
        # Pass 1: ideal path, the session was stored together with its user_id.
        scoped = gemini_sql_agent.get_session(session_id=session_id, user_id=user_id)
        if scoped is not None:
            found_id = _extract_session_id(_serialize_session_obj(scoped))
            return bool(found_id) and found_id == session_id

        # Pass 2: look the session up by id only and verify the owner ourselves.
        unscoped = gemini_sql_agent.get_session(session_id=session_id)
        if unscoped is None:
            return False

        payload = _serialize_session_obj(unscoped)
        found_id = _extract_session_id(payload)
        if not found_id or found_id != session_id:
            return False

        # A blank owner is tolerated on purpose (legacy / pre-fix sessions).
        owner = str(payload.get("user_id") or "").strip()
        return owner in ("", user_id)
    except Exception as exc:
        logger.error(f"Failed ownership check for session {session_id}: {exc}")
        return False
| def _serialize_chat_message_sql(m) -> Dict[str, Any]: | |
| """Serialize an Agno Message to a rich dict for frontend turn reconstruction. | |
| Returns role, content, tool_calls (LLM call requests on assistant msgs), | |
| tool_call_id / tool_name (on tool-result msgs), and created_at. The | |
| frontend uses these to rebuild the streaming-equivalent ChatMessage | |
| structure (tool_calls + sqlExecutions) from DB history. | |
| """ | |
| role = str(getattr(m, "role", "user") or "user") | |
| raw_content = getattr(m, "content", None) | |
| if raw_content is None: | |
| content = "" | |
| elif isinstance(raw_content, str): | |
| content = raw_content | |
| else: | |
| try: | |
| content = json.dumps(raw_content) | |
| except Exception: | |
| content = str(raw_content) | |
| result: Dict[str, Any] = {"role": role, "content": content} | |
| created_at = getattr(m, "created_at", None) | |
| if created_at is not None: | |
| result["created_at"] = created_at | |
| # tool_calls: present on assistant messages that requested tool calls | |
| tool_calls = getattr(m, "tool_calls", None) | |
| if tool_calls: | |
| serialized_tcs = [] | |
| for tc in tool_calls: | |
| try: | |
| if isinstance(tc, dict): | |
| tc_id = str(tc.get("id") or "") | |
| fn = tc.get("function") or {} | |
| fn_name = str(fn.get("name") or "") | |
| fn_args = str(fn.get("arguments") or "{}") | |
| tc_type = str(tc.get("type") or "function") | |
| else: | |
| tc_id = str(getattr(tc, "id", "") or "") | |
| fn_obj = getattr(tc, "function", None) | |
| fn_name = str(getattr(fn_obj, "name", "") if fn_obj else "") | |
| fn_args = str(getattr(fn_obj, "arguments", "{}") if fn_obj else "{}") | |
| tc_type = str(getattr(tc, "type", "function") or "function") | |
| serialized_tcs.append({"id": tc_id, "type": tc_type, "function": {"name": fn_name, "arguments": fn_args}}) | |
| except Exception: | |
| continue | |
| if serialized_tcs: | |
| result["tool_calls"] = serialized_tcs | |
| # tool_call_id + tool_name: on tool-role messages (the result) | |
| tool_call_id = getattr(m, "tool_call_id", None) | |
| if tool_call_id: | |
| result["tool_call_id"] = str(tool_call_id) | |
| name = getattr(m, "name", None) | |
| if name: | |
| result["tool_name"] = str(name) | |
| return result | |
async def list_user_sessions(user_id: str, auth_user: AuthUser = Depends(get_current_user)):
    """List the stored sessions of ``user_id`` for the default agent.

    The caller's JWT ``tenant_id`` must be present (else 401) and equal to
    ``user_id`` (else 403). Up to 200 sessions are fetched from ``agent_db``
    and each one is returned with normalised ``session_id``, ``name`` and
    ``created_at`` fields for the frontend sidebar.

    Returns:
        ``{"sessions": [...]}``; an empty list when the lookup fails (the
        failure is logged rather than raised).
    """
    _ensure_agent_runtime_ready()

    caller_tenant = (auth_user.tenant_id or "").strip()
    if not caller_tenant:
        raise HTTPException(status_code=401, detail="Missing tenant_id in JWT claims")
    if caller_tenant != user_id:
        raise HTTPException(status_code=403, detail="Forbidden: tenant_id mismatch")

    def _with_sidebar_fields(payload: Dict[str, Any]) -> Dict[str, Any]:
        # Overlay the normalised fields on a copy of the serialized session.
        sid = _extract_session_id(payload)
        entry = dict(payload)
        entry.update(
            session_id=sid,
            name=payload.get("session_name") or payload.get("name") or f"Chat {sid[:8]}",
            created_at=payload.get("created_at"),
        )
        return entry

    try:
        found = agent_db.get_sessions(user_id=user_id, component_id=DEFAULT_AGENT_ID, limit=200)
        payloads = list(map(_serialize_session_obj, found or []))
        return {"sessions": [_with_sidebar_fields(p) for p in payloads]}
    except Exception as e:
        logger.error(f"Failed to list sessions for user {user_id}: {e}")
        return {"sessions": []}
async def get_chat(user_id: str, session_id: str, auth_user: AuthUser = Depends(get_current_user)):
    """Return the chat history of a session, including tool-call details.

    Raises:
        HTTPException: 401 when the JWT has no ``tenant_id``; 403 when it
            differs from ``user_id``; 404 when the session does not belong to
            ``user_id``; 500 when the history cannot be loaded.

    Returns:
        ``{"messages": [...], "status": "completed"}`` where each message is
        produced by ``_serialize_chat_message_sql``; ``messages`` is empty
        when the agent returns no history.
    """
    _ensure_agent_runtime_ready()

    caller_tenant = (auth_user.tenant_id or "").strip()
    if not caller_tenant:
        raise HTTPException(status_code=401, detail="Missing tenant_id in JWT claims")
    if caller_tenant != user_id:
        raise HTTPException(status_code=403, detail="Forbidden: tenant_id mismatch")

    owned = _session_belongs_to_user(session_id=session_id, user_id=user_id)
    if not owned:
        raise HTTPException(status_code=404, detail="Chat session not found")

    try:
        history = gemini_sql_agent.get_chat_history(session_id=session_id)
        messages = [_serialize_chat_message_sql(msg) for msg in history] if history else []
        return {"messages": messages, "status": "completed"}
    except Exception as e:
        logger.error(f"Failed to get chat for session {session_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve chat history")
async def delete_chat(session_id: str, auth_user: AuthUser = Depends(get_current_user)):
    """Delete a session on behalf of the authenticated tenant.

    Raises:
        HTTPException: 401 when the JWT has no ``tenant_id``; 404 when the
            session does not belong to that tenant; 500 when deletion fails.

    Returns:
        ``{"status": "deleted"}`` on success.
    """
    _ensure_agent_runtime_ready()

    caller_tenant = (auth_user.tenant_id or "").strip()
    if not caller_tenant:
        raise HTTPException(status_code=401, detail="Missing tenant_id in JWT claims")

    owned = _session_belongs_to_user(session_id=session_id, user_id=caller_tenant)
    if not owned:
        raise HTTPException(status_code=404, detail="Chat session not found")

    try:
        gemini_sql_agent.delete_session(session_id=session_id, user_id=caller_tenant)
    except Exception as e:
        logger.error(f"Failed to delete session {session_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to delete session")
    return {"status": "deleted"}
async def rename_chat(session_id: str, name: str = Body(..., embed=True), auth_user: AuthUser = Depends(get_current_user)):
    """Rename a session on behalf of the authenticated tenant.

    ``name`` is read from the JSON body (embedded under the ``name`` key).

    Raises:
        HTTPException: 401 when the JWT has no ``tenant_id``; 404 when the
            session does not belong to that tenant; 500 when renaming fails.

    Returns:
        ``{"status": "renamed"}`` on success.
    """
    _ensure_agent_runtime_ready()

    caller_tenant = (auth_user.tenant_id or "").strip()
    if not caller_tenant:
        raise HTTPException(status_code=401, detail="Missing tenant_id in JWT claims")

    owned = _session_belongs_to_user(session_id=session_id, user_id=caller_tenant)
    if not owned:
        raise HTTPException(status_code=404, detail="Chat session not found")

    try:
        gemini_sql_agent.set_session_name(session_id=session_id, session_name=name)
    except Exception as e:
        logger.error(f"Failed to rename session {session_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to rename session")
    return {"status": "renamed"}
async def cancel_run(session_id: str, run_id: str, auth_user: AuthUser = Depends(get_current_user)):
    """Request cancellation of a running agent run.

    The session must belong to the authenticated tenant. Cancellation is
    best-effort: when the agent object has no ``cancel_run`` attribute, or the
    call fails, the endpoint still answers 200 with ``cancelled: False``.

    Raises:
        HTTPException: 401 when the JWT has no ``tenant_id``; 404 when the
            session does not belong to that tenant.

    Returns:
        ``{"cancelled": <bool>}``; on failure also an ``"error"`` key with the
        exception text.
    """
    _ensure_agent_runtime_ready()
    requester_tenant_id = (auth_user.tenant_id or "").strip()
    if not requester_tenant_id:
        raise HTTPException(status_code=401, detail="Missing tenant_id in JWT claims")
    if not _session_belongs_to_user(session_id=session_id, user_id=requester_tenant_id):
        raise HTTPException(status_code=404, detail="Chat session not found")
    # NOTE(review): only session ownership is verified; nothing here checks
    # that run_id actually belongs to session_id - confirm whether it should.
    try:
        # Some versions of Agno support cancel_run.
        success = False
        if hasattr(gemini_sql_agent, 'cancel_run'):
            # Coerce so the response flag is always a real boolean.
            success = bool(gemini_sql_agent.cancel_run(run_id))
        return {"cancelled": success}
    except Exception as e:
        # Log like the sibling handlers do instead of failing silently.
        logger.error(f"Failed to cancel run {run_id} for session {session_id}: {e}")
        return {"cancelled": False, "error": str(e)}
| # DEPRECATED FUNCTIONS - Replaced by the /tenant-run API endpoint | |
| # The following functions are kept for backward compatibility and local testing only. | |
| # For production API usage, use the /tenant-run/{agent_id} endpoint instead. | |
if __name__ == "__main__":
    import uvicorn

    # Bind address and port come from the environment, with local defaults.
    # Override with SQL_AGENT_PORT=8000 for unified deployments.
    host = os.getenv("SQL_AGENT_HOST", "0.0.0.0")
    port = int(os.getenv("SQL_AGENT_PORT", "5559"))

    rule = "=" * 80
    base_url = f"http://{host}:{port}"

    # Startup banner.
    print("\n" + rule)
    print("🚀 STARTING SQL AGENT OS SERVER (with custom /tenant-run endpoint)")
    print(rule)
    print(f"Host: {host}")
    print(f"Port: {port}")
    print(f"Agent ID: {DEFAULT_AGENT_ID}")
    print(f"AgentOS ID: {DEFAULT_AGENT_OS_ID}")
    print(rule + "\n")

    # Endpoint overview.
    print("\n🎯 CUSTOM TENANT ENDPOINT:")
    print(f" POST {base_url}/tenant-run/{DEFAULT_AGENT_ID}")
    print("\n📚 STANDARD AGENTOS ENDPOINTS:")
    print(f" GET {base_url}/config")
    print(f" GET {base_url}/agents")
    print(f" POST {base_url}/agents/{DEFAULT_AGENT_ID}/runs")
    print(rule + "\n")

    # server_header=False turns off uvicorn's default "Server" response header;
    # interface="auto" lets uvicorn detect the application interface itself.
    uvicorn.run(
        app,
        host=host,
        port=port,
        server_header=False,
        interface="auto",
    )