"""FastAPI webhook that relays chat messages to an LLM endpoint with web tools.

Each incoming message is appended to an in-memory, per-chat history and sent
to ``MODAL_ENDPOINT``. When the model answers with a JSON tool call, the tool
is executed and its output is fed back, for at most ``MAX_TOOL_ROUNDS`` rounds.
"""

import asyncio
import hmac
import ipaddress
import json
import logging
import os
import socket
import threading
import time
from typing import Any, Optional
from urllib.parse import urljoin, urlparse

import requests
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)

app = FastAPI()

MODAL_ENDPOINT = os.environ["MODAL_ENDPOINT"]
MODAL_KEY = os.environ.get("MODAL_KEY", "")
MODAL_SECRET = os.environ.get("MODAL_SECRET", "")
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET", "")
MAX_TOOL_ROUNDS = int(os.environ.get("MAX_TOOL_ROUNDS", "5"))
SESSION_TTL = int(os.environ.get("SESSION_TTL", "3600"))
MAX_SESSIONS = int(os.environ.get("MAX_SESSIONS", "1000"))
MAX_HISTORY = int(os.environ.get("MAX_HISTORY", "50"))

# Upper bound on redirect hops followed (and re-validated) by _web_fetch.
_MAX_REDIRECTS = 5

# Error body returned for payloads missing a usable chat_id or message.
_PAYLOAD_ERROR = "chat_id and message required"

# chat_id -> message history, and chat_id -> last-access time (time.time()).
sessions: dict[int, list[dict]] = {}
session_timestamps: dict[int, float] = {}

# Requests are handled in worker threads (asyncio.to_thread), so the two
# registries above are iterated and mutated concurrently. Re-entrant because
# get_session() calls _cleanup_sessions() while already holding the lock.
_sessions_lock = threading.RLock()

SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools. Use them when needed to answer questions accurately.

Available tools:
- web_search: Search the web for information
- web_fetch: Fetch content from a URL
- browser_navigate: Navigate to a URL and return page content

When you need to use a tool, output a JSON tool call in this format:
{"tool": "tool_name", "args": {"arg1": "value1"}}

After receiving tool results, continue reasoning and provide a final answer.
When you have enough information, provide your final answer directly without tool calls."""

BLOCKED_HOSTS = {
    "169.254.169.254",
    "metadata.google.internal",
    "100.100.100.200",
    "localhost",
    "127.0.0.1",
    "0.0.0.0",
}


def _parse_ip(hostname: str):
    """Return the IP address that *hostname* spells, or None for a DNS name.

    Besides canonical IPv4/IPv6 literals this recognises the legacy IPv4
    spellings accepted by ``inet_aton`` (``127.1``, ``2130706433``,
    ``0x7f.0.0.1``), which would otherwise slip past a textual check.
    """
    try:
        return ipaddress.ip_address(hostname)
    except ValueError:
        pass
    try:
        return ipaddress.IPv4Address(socket.inet_aton(hostname))
    except (OSError, ValueError):
        return None


def _is_public_ip(ip) -> bool:
    """Return True only for globally routable unicast addresses."""
    # Judge IPv4-mapped IPv6 (::ffff:a.b.c.d) by the embedded IPv4 address.
    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        ip = mapped
    if (
        ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_reserved
        or ip.is_unspecified
    ):
        return False
    # is_global additionally excludes shared address space (100.64.0.0/10).
    return ip.is_global


def _is_safe_url(url: str) -> bool:
    """Return True if *url* is an http(s) URL whose host is not internal.

    Rejects non-http(s) schemes, empty hosts, the names in ``BLOCKED_HOSTS``,
    ``localhost`` and its subdomains, and any IP literal that is not globally
    routable.

    NOTE: DNS names are not resolved here, so a public name that resolves to
    an internal address is not caught by this check.
    """
    if not isinstance(url, str):
        return False
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname or ""
    except ValueError:
        return False
    if parsed.scheme not in ("http", "https"):
        return False

    # A trailing dot denotes the same host ("localhost." == "localhost").
    hostname = hostname.rstrip(".")
    if not hostname or hostname in BLOCKED_HOSTS:
        return False
    if hostname == "localhost" or hostname.endswith(".localhost"):
        return False

    ip = _parse_ip(hostname)
    if ip is not None:
        return _is_public_ip(ip)
    # A ":" can only appear in an IPv6 literal; one we failed to parse is
    # rejected rather than passed on as if it were a DNS name.
    return ":" not in hostname


def _strip_code_fences(text: str) -> str:
    """Remove one surrounding Markdown code fence from *text*, if present.

    The opening fence line (including any language tag) and a trailing fence
    are dropped; text without fences is returned stripped of outer whitespace.
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        newline = stripped.find("\n")
        if newline == -1:
            # Single-line fence: drop only the backticks, keep the payload.
            stripped = stripped[3:]
        else:
            stripped = stripped[newline + 1:]
    if stripped.endswith("```"):
        stripped = stripped[:-3]
    return stripped.strip()


def execute_tool(name: str, args: dict[str, Any]) -> str:
    """Run the tool called *name* with *args* and return its text output.

    Unknown tool names and malformed arguments produce an explanatory string
    instead of raising, so the result can always be handed back to the model.
    """
    if not isinstance(args, dict):
        # The model controls this value; "args" may be a string or a list.
        return f"Invalid arguments for tool: {name}"
    if name == "web_search":
        return _web_search(args.get("query", ""))
    if name == "web_fetch":
        return _web_fetch(args.get("url", ""))
    if name == "browser_navigate":
        return _browser_navigate(args.get("url", ""))
    return f"Unknown tool: {name}"


def _web_search(query: str) -> str:
    """Query the DuckDuckGo API and return up to five related-topic texts."""
    try:
        resp = requests.get(
            "https://api.duckduckgo.com/",
            params={"q": query, "format": "json"},
            timeout=10,
        )
        data = resp.json()
        results = [
            topic["Text"]
            for topic in data.get("RelatedTopics", [])[:5]
            if isinstance(topic, dict) and "Text" in topic
        ]
        return "\n".join(results) if results else "No results found"
    except Exception:
        # Best-effort tool: report failure to the model, keep details in logs.
        logger.warning("web_search failed", exc_info=True)
        return "Search error"


def _web_fetch(url: str) -> str:
    """Fetch *url* and return the first 5000 characters of the response body.

    Redirects are followed manually, at most ``_MAX_REDIRECTS`` hops, so every
    hop is checked with ``_is_safe_url``; otherwise a public URL could
    redirect the request to an internal address.
    """
    if not _is_safe_url(url):
        return "Error: URL not allowed"
    try:
        current = url
        # A session keeps cookies set by intermediate redirect responses.
        with requests.Session() as session:
            for _ in range(_MAX_REDIRECTS + 1):
                resp = session.get(
                    current,
                    timeout=15,
                    headers={"User-Agent": "Mozilla/5.0"},
                    allow_redirects=False,
                )
                if not resp.is_redirect:
                    return resp.text[:5000]
                target = urljoin(current, resp.headers["Location"])
                resp.close()
                if not _is_safe_url(target):
                    return "Error: URL not allowed"
                current = target
        logger.warning("web_fetch: too many redirects for %s", url)
        return "Fetch error"
    except Exception:
        logger.warning("web_fetch failed for %s", url, exc_info=True)
        return "Fetch error"


def _browser_navigate(url: str) -> str:
    """Load *url* in headless Chromium and return the first 5000 chars of HTML.

    The URL the page ends up on is validated as well, so content reached via
    a redirect to a disallowed address is not returned.
    """
    if not _is_safe_url(url):
        return "Error: URL not allowed"
    try:
        # Imported lazily, as in the original: only this tool needs playwright.
        from playwright.sync_api import sync_playwright

        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            try:
                page = browser.new_page()
                page.goto(url, timeout=30000, wait_until="domcontentloaded")
                if not _is_safe_url(page.url):
                    return "Error: URL not allowed"
                content = page.content()
            finally:
                # Close the browser even when navigation fails.
                browser.close()
        return content[:5000]
    except Exception:
        logger.warning("browser_navigate failed for %s", url, exc_info=True)
        return "Browser error"


def _cleanup_sessions() -> None:
    """Drop every session not accessed within the last ``SESSION_TTL`` seconds."""
    now = time.time()
    with _sessions_lock:
        expired = [
            cid for cid, ts in session_timestamps.items() if now - ts > SESSION_TTL
        ]
        for cid in expired:
            sessions.pop(cid, None)
            session_timestamps.pop(cid, None)


def get_session(chat_id: int) -> list[dict]:
    """Return the history list for *chat_id*, creating it if needed.

    Expired sessions are purged first. If the registry is full and *chat_id*
    is new, the least recently used session is evicted. The access time of
    *chat_id* is refreshed on every call.
    """
    with _sessions_lock:
        _cleanup_sessions()
        if (
            len(sessions) >= MAX_SESSIONS
            and chat_id not in sessions
            and session_timestamps  # min() raises on an empty mapping
        ):
            oldest = min(session_timestamps, key=session_timestamps.get)
            sessions.pop(oldest, None)
            session_timestamps.pop(oldest, None)
        session_timestamps[chat_id] = time.time()
        return sessions.setdefault(chat_id, [])


def call_modal(messages: list[dict]) -> str:
    """POST *messages* to ``MODAL_ENDPOINT`` and return the first choice's content.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        requests.RequestException: on connection problems or timeout.
        KeyError, IndexError: if the response JSON lacks the expected fields.
    """
    headers = {"Content-Type": "application/json"}
    if MODAL_KEY:
        headers["Modal-Key"] = MODAL_KEY
        headers["Modal-Secret"] = MODAL_SECRET
    resp = requests.post(
        MODAL_ENDPOINT,
        headers=headers,
        json={
            "model": "llm",
            "messages": messages,
            "max_tokens": 4096,
        },
        timeout=300,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]


def _parse_tool_call(content: str) -> Optional[dict]:
    """Return the tool-call dict encoded in *content*, or None if there is none."""
    try:
        parsed = json.loads(_strip_code_fences(content))
    except (json.JSONDecodeError, TypeError):
        return None
    if isinstance(parsed, dict) and "tool" in parsed:
        return parsed
    return None


def _run_tool_rounds(messages: list[dict]) -> str:
    """Call the model, executing requested tools, until it gives a final answer.

    *messages* is extended in place with each tool call and its result.
    Returns the model's answer, or a fixed notice once ``MAX_TOOL_ROUNDS``
    rounds have all ended in tool calls.
    """
    for _ in range(MAX_TOOL_ROUNDS):
        content = call_modal(messages)
        if not isinstance(content, str):
            # The endpoint may return null content; treat it as an empty reply.
            content = "" if content is None else str(content)

        call = _parse_tool_call(content)
        if call is None:
            return content

        tool_name = call["tool"]
        result = execute_tool(tool_name, call.get("args", {}))
        messages.append({"role": "assistant", "content": content})
        messages.append(
            {"role": "tool", "content": result[:3000], "name": tool_name}
        )
    return "Max tool rounds reached."


def agentic_loop(chat_id: int, user_message: str) -> str:
    """Answer *user_message* within the session *chat_id* and return the reply.

    The user message and the final reply are stored in the session history,
    which is trimmed to the last ``MAX_HISTORY`` entries before the model is
    called. Intermediate tool traffic is not stored. Exceptions from the model
    call propagate to the caller.
    """
    history = get_session(chat_id)
    user_entry = {"role": "user", "content": user_message}
    history.append(user_entry)
    if len(history) > MAX_HISTORY:
        history[:] = history[-MAX_HISTORY:]

    messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history
    try:
        reply = _run_tool_rounds(messages)
    except Exception:
        # Do not leave an unanswered user turn behind. The identity check
        # avoids removing an entry appended by a concurrent request.
        if history and history[-1] is user_entry:
            history.pop()
        raise

    history.append({"role": "assistant", "content": reply})
    return reply


def _validate_payload(data: Any) -> Optional[str]:
    """Return an error message if *data* is not a usable payload, else None."""
    if not isinstance(data, dict):
        return _PAYLOAD_ERROR
    chat_id = data.get("chat_id")
    if not chat_id or not data.get("message"):
        return _PAYLOAD_ERROR
    if isinstance(chat_id, (dict, list)):
        # chat_id is used as a dictionary key and must be hashable.
        return _PAYLOAD_ERROR
    return None


def _handle_webhook(data: dict) -> dict:
    """Process one webhook payload synchronously.

    Returns ``{"response": ...}`` on success and ``{"error": ...}`` when the
    payload is invalid or processing fails. Never raises.
    """
    problem = _validate_payload(data)
    if problem is not None:
        return {"error": problem}
    try:
        response = agentic_loop(data["chat_id"], data["message"])
        return {"response": response}
    except Exception:
        # Top-level boundary: keep details out of the response, but log them.
        logger.exception("webhook processing failed")
        return {"error": "Internal error"}


@app.post("/webhook")
async def webhook(request: Request):
    """Authenticate, validate and process a chat message.

    Responds 401 on a wrong bearer token (checked only when ``WEBHOOK_SECRET``
    is set), 400 on a malformed or incomplete body, 500 when processing fails,
    and otherwise 200 with ``{"response": ...}``.
    """
    if WEBHOOK_SECRET:
        supplied = request.headers.get("Authorization", "")
        expected = f"Bearer {WEBHOOK_SECRET}"
        # Constant-time comparison; bytes because compare_digest rejects
        # non-ASCII str arguments.
        if not hmac.compare_digest(supplied.encode(), expected.encode()):
            return JSONResponse({"error": "Unauthorized"}, status_code=401)

    try:
        data = await request.json()
    except ValueError:
        # Covers json.JSONDecodeError and undecodable request bytes.
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)

    problem = _validate_payload(data)
    if problem is not None:
        return JSONResponse({"error": problem}, status_code=400)

    # The agent loop does blocking network I/O; keep it off the event loop.
    result = await asyncio.to_thread(_handle_webhook, data)
    if "error" in result:
        return JSONResponse(result, status_code=500)
    return result


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok"}