| import asyncio |
| import json |
| import os |
| import time |
| from typing import Any |
| from urllib.parse import urlparse |
|
|
| import requests |
| from fastapi import FastAPI, Request |
| from fastapi.responses import JSONResponse |
|
|
# ASGI application object; the route decorators in this module attach to it.
app = FastAPI()

# --- Runtime configuration, read once from the environment at import time ---

# Required: importing this module raises KeyError when it is not set.
MODAL_ENDPOINT = os.environ["MODAL_ENDPOINT"]
# Optional credentials; auth headers are only sent when MODAL_KEY is non-empty.
MODAL_KEY = os.getenv("MODAL_KEY", "")
MODAL_SECRET = os.getenv("MODAL_SECRET", "")
# Bearer token expected on /webhook; an empty value disables the check.
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET", "")
# Upper bound on model calls made while serving a single user message.
MAX_TOOL_ROUNDS = int(os.getenv("MAX_TOOL_ROUNDS", "5"))
# Seconds a session may go unused before it is discarded.
SESSION_TTL = int(os.getenv("SESSION_TTL", "3600"))
# Maximum number of sessions kept in memory at the same time.
MAX_SESSIONS = int(os.getenv("MAX_SESSIONS", "1000"))
# Maximum number of history entries retained per session.
MAX_HISTORY = int(os.getenv("MAX_HISTORY", "50"))

# --- In-memory session state (process-local, lost on restart) ---

# chat_id -> list of chat messages (dicts with "role" and "content").
sessions: dict[int, list[dict]] = {}
# chat_id -> time.time() of the most recent access to that session.
session_timestamps: dict[int, float] = {}
|
|
# System message placed in front of every conversation sent to the model.
# It advertises the three tools and the JSON shape of a tool call; the text
# is runtime data, so edit it with care.
SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools. Use them when needed to answer questions accurately.

Available tools:
- web_search: Search the web for information
- web_fetch: Fetch content from a URL
- browser_navigate: Navigate to a URL and return page content

When you need to use a tool, output a JSON tool call in this format:
{"tool": "tool_name", "args": {"arg1": "value1"}}

After receiving tool results, continue reasoning and provide a final answer.
When you have enough information, provide your final answer directly without tool calls."""
|
|
| BLOCKED_HOSTS = { |
| "169.254.169.254", |
| "metadata.google.internal", |
| "100.100.100.200", |
| "localhost", |
| "127.0.0.1", |
| "0.0.0.0", |
| } |
|
|
|
|
| def _is_safe_url(url: str) -> bool: |
| try: |
| parsed = urlparse(url) |
| if parsed.scheme not in ("http", "https"): |
| return False |
| hostname = parsed.hostname or "" |
| if hostname in BLOCKED_HOSTS: |
| return False |
| if hostname.startswith("10.") or hostname.startswith("172."): |
| parts = hostname.split(".") |
| if len(parts) == 4: |
| try: |
| if parts[0] == "172" and 16 <= int(parts[1]) <= 31: |
| return False |
| if parts[0] == "10": |
| return False |
| except ValueError: |
| pass |
| if hostname.startswith("192.168."): |
| return False |
| if hostname.startswith("fc") or hostname.startswith("fd"): |
| return False |
| return True |
| except Exception: |
| return False |
|
|
|
|
| def _strip_code_fences(text: str) -> str: |
| stripped = text.strip() |
| if stripped.startswith("```"): |
| first_newline = stripped.index("\n") if "\n" in stripped else len(stripped) |
| stripped = stripped[first_newline + 1 :] |
| if stripped.endswith("```"): |
| stripped = stripped[:-3] |
| return stripped.strip() |
|
|
|
|
def execute_tool(name: str, args: dict[str, Any]) -> str:
    """Run the tool called *name* with *args* and return its textual result.

    Args:
        name: One of ``web_search``, ``web_fetch`` or ``browser_navigate``.
        args: Tool arguments; ``query`` is read for ``web_search`` and
            ``url`` for the other two. A missing key defaults to ``""``.

    Returns:
        The tool output, ``"Unknown tool: <name>"`` for an unrecognised
        name, or an ``"Error: ..."`` string when *args* is not a dict.
        Model-produced calls can carry ``null`` or a bare string as
        arguments; reporting that as a tool result instead of raising lets
        the caller feed the problem back to the model.
    """
    arg_keys = {"web_search": "query", "web_fetch": "url", "browser_navigate": "url"}
    # isinstance guard: an unhashable name (e.g. a list) must not blow up the lookup.
    arg_key = arg_keys.get(name) if isinstance(name, str) else None
    if arg_key is None:
        return f"Unknown tool: {name}"
    if not isinstance(args, dict):
        return f"Error: arguments for {name} must be a JSON object"

    value = args.get(arg_key, "")
    if name == "web_search":
        return _web_search(value)
    if name == "web_fetch":
        return _web_fetch(value)
    return _browser_navigate(value)
|
|
|
|
def _web_search(query: str) -> str:
    """Query the DuckDuckGo JSON API and return up to five result snippets.

    Only the first five ``RelatedTopics`` entries are considered, and only
    those that are dicts carrying a ``Text`` field contribute a line.

    Returns:
        The snippets joined by newlines, ``"No results found"`` when none
        qualify, or ``"Search error"`` if anything goes wrong.
    """
    try:
        response = requests.get(
            "https://api.duckduckgo.com/",
            params={"q": query, "format": "json"},
            timeout=10,
        )
        topics = response.json().get("RelatedTopics", [])[:5]
        snippets = [
            topic["Text"]
            for topic in topics
            if isinstance(topic, dict) and "Text" in topic
        ]
        if not snippets:
            return "No results found"
        return "\n".join(snippets)
    except Exception:
        # Best-effort tool: every failure is reported to the model as text.
        return "Search error"
|
|
|
|
def _web_fetch(url: str) -> str:
    """Fetch *url* over HTTP(S) and return the first 5000 characters of the body.

    Redirects are followed by hand so that every hop is checked with
    ``_is_safe_url``; letting ``requests`` follow them automatically would
    allow a public URL to redirect to an internal address and bypass the
    check made on the original URL.

    Returns:
        The (truncated) response text, ``"Error: URL not allowed"`` when the
        URL or any redirect target is rejected, or ``"Fetch error"`` on any
        failure, including a redirect chain longer than five hops.
    """
    from urllib.parse import urljoin

    max_redirects = 5
    try:
        for _ in range(max_redirects + 1):
            if not _is_safe_url(url):
                return "Error: URL not allowed"
            resp = requests.get(
                url,
                timeout=15,
                headers={"User-Agent": "Mozilla/5.0"},
                allow_redirects=False,
            )
            if not resp.is_redirect:
                return resp.text[:5000]
            # Location may be relative; resolve it against the current URL.
            url = urljoin(url, resp.headers["Location"])
        return "Fetch error"
    except Exception:
        # Best-effort tool: every failure is reported to the model as text.
        return "Fetch error"
|
|
|
|
def _browser_navigate(url: str) -> str:
    """Load *url* in headless Chromium and return the first 5000 characters of its HTML.

    The URL is checked with ``_is_safe_url`` before navigation, and the
    page's final URL is checked again afterwards so that content reached
    through a redirect to a disallowed host is not returned.

    NOTE(review): the second check only withholds the content; the browser
    has already issued the request by then. Blocking such requests outright
    would need request interception - confirm whether that is required.

    Returns:
        The (truncated) page HTML, ``"Error: URL not allowed"`` when either
        check fails, or ``"Browser error"`` on any failure (including
        Playwright not being installed).
    """
    if not _is_safe_url(url):
        return "Error: URL not allowed"
    try:
        # Imported lazily so the module still loads without Playwright.
        from playwright.sync_api import sync_playwright

        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            try:
                page = browser.new_page()
                page.goto(url, timeout=30000, wait_until="domcontentloaded")
                if not _is_safe_url(page.url):
                    return "Error: URL not allowed"
                content = page.content()
            finally:
                # Close the browser on every path, including navigation errors.
                browser.close()
        return content[:5000]
    except Exception:
        # Best-effort tool: every failure is reported to the model as text.
        return "Browser error"
|
|
|
|
def _cleanup_sessions() -> None:
    """Drop every session whose last access is more than SESSION_TTL seconds old.

    Both the stored history (``sessions``) and the access time
    (``session_timestamps``) are removed for each expired chat id.
    """
    now = time.time()
    # Iterate over a snapshot: webhook requests are served in worker threads,
    # and a session inserted by another thread while the live dict is being
    # iterated would raise "dictionary changed size during iteration".
    for cid, ts in list(session_timestamps.items()):
        if now - ts > SESSION_TTL:
            sessions.pop(cid, None)
            session_timestamps.pop(cid, None)
|
|
|
|
def get_session(chat_id: int) -> list[dict]:
    """Return the mutable message history for *chat_id*, creating it if needed.

    Expired sessions are purged first. If the store is at MAX_SESSIONS and
    *chat_id* is not yet known, the session with the oldest access time is
    evicted to make room. The access time of *chat_id* is then refreshed.
    """
    _cleanup_sessions()

    store_is_full = len(sessions) >= MAX_SESSIONS
    if store_is_full and chat_id not in sessions:
        # Evict the least recently used session.
        victim = min(session_timestamps, key=lambda cid: session_timestamps[cid])
        for store in (sessions, session_timestamps):
            store.pop(victim, None)

    session_timestamps[chat_id] = time.time()
    return sessions.setdefault(chat_id, [])
|
|
|
|
def call_modal(messages: list[dict]) -> str:
    """POST *messages* to the configured model endpoint and return the reply text.

    The request body carries ``model``, ``messages`` and ``max_tokens``.
    The ``Modal-Key`` / ``Modal-Secret`` headers are attached only when
    MODAL_KEY is set.

    Returns:
        ``choices[0].message.content`` from the JSON response.

    Raises:
        requests.HTTPError: for a non-success status (``raise_for_status``).
        Other ``requests`` exceptions, ``KeyError`` or ``IndexError`` may
        propagate on transport failures or an unexpected response shape.
    """
    payload = {
        "model": "llm",
        "messages": messages,
        "max_tokens": 4096,
    }
    request_headers = {"Content-Type": "application/json"}
    if MODAL_KEY:
        request_headers.update(
            {"Modal-Key": MODAL_KEY, "Modal-Secret": MODAL_SECRET}
        )

    reply = requests.post(
        MODAL_ENDPOINT,
        headers=request_headers,
        json=payload,
        timeout=300,
    )
    reply.raise_for_status()
    first_choice = reply.json()["choices"][0]
    return first_choice["message"]["content"]
|
|
|
|
def _parse_tool_call(content: str) -> "tuple[Any, Any] | None":
    """Return ``(tool_name, raw_args)`` when *content* is a JSON tool call, else None.

    A tool call is a JSON object with a ``"tool"`` key, optionally wrapped
    in a Markdown code fence. ``raw_args`` is whatever sits under
    ``"args"`` (``None`` when absent) and is not validated here.
    """
    cleaned = _strip_code_fences(content)
    try:
        parsed = json.loads(cleaned)
    except (json.JSONDecodeError, TypeError):
        return None
    if isinstance(parsed, dict) and "tool" in parsed:
        return parsed["tool"], parsed.get("args")
    return None


def agentic_loop(chat_id: int, user_message: str) -> str:
    """Answer *user_message* for *chat_id*, letting the model call tools first.

    The user message is appended to the session history (trimmed to the
    last MAX_HISTORY entries) and sent to the model behind the system
    prompt. While the model replies with a tool call, the tool is run and
    its output (first 3000 characters) is fed back, for at most
    MAX_TOOL_ROUNDS model calls. The first non-tool reply is stored in the
    history and returned. Intermediate tool traffic lives only in the
    per-request message list, not in the session history.

    Returns:
        The model's final answer, or ``"Max tool rounds reached."`` when
        every round produced a tool call.

    Raises:
        Whatever ``call_modal`` or a tool raises; the JSON parsing guard
        deliberately does not cover tool execution, so a tool failure is
        not mistaken for "the reply was not a tool call".
    """
    history = get_session(chat_id)
    history.append({"role": "user", "content": user_message})
    if len(history) > MAX_HISTORY:
        # Trim in place: the list object is shared with the session store.
        history[:] = history[-MAX_HISTORY:]

    messages = [{"role": "system", "content": SYSTEM_PROMPT}, *history]

    for _ in range(MAX_TOOL_ROUNDS):
        content = call_modal(messages)

        tool_call = _parse_tool_call(content)
        if tool_call is None:
            history.append({"role": "assistant", "content": content})
            return content

        tool_name, tool_args = tool_call
        if tool_args is None:
            # Missing or null "args" means "no arguments".
            tool_args = {}
        if isinstance(tool_args, dict):
            result = execute_tool(tool_name, tool_args)
        else:
            # Report the malformed call to the model instead of crashing.
            result = "Error: tool args must be a JSON object"

        messages.append({"role": "assistant", "content": content})
        messages.append(
            {"role": "tool", "content": result[:3000], "name": tool_name}
        )

    exhausted = "Max tool rounds reached."
    history.append({"role": "assistant", "content": exhausted})
    return exhausted
|
|
|
|
| def _handle_webhook(data: dict) -> dict: |
| chat_id = data.get("chat_id") |
| message = data.get("message") |
|
|
| if not chat_id or not message: |
| return {"error": "chat_id and message required"} |
|
|
| try: |
| response = agentic_loop(chat_id, message) |
| return {"response": response} |
| except Exception: |
| return {"error": "Internal error"} |
|
|
|
|
@app.post("/webhook")
async def webhook(request: Request):
    """Authenticate, validate and process one chat webhook request.

    Responses:
        401: WEBHOOK_SECRET is set and the ``Authorization`` header is not
            ``Bearer <secret>``.
        400: the body is not valid JSON, is not a JSON object, or lacks a
            non-empty ``chat_id`` / ``message``.
        500: processing failed (``{"error": ...}`` from ``_handle_webhook``).
        200: ``{"response": <answer>}``.
    """
    import hmac

    if WEBHOOK_SECRET:
        supplied = request.headers.get("Authorization", "")
        expected = f"Bearer {WEBHOOK_SECRET}"
        # Constant-time comparison so response timing does not leak the
        # secret; compare bytes because compare_digest rejects non-ASCII str.
        if not hmac.compare_digest(supplied.encode(), expected.encode()):
            return JSONResponse({"error": "Unauthorized"}, status_code=401)

    try:
        data = await request.json()
    except ValueError:
        # Covers malformed JSON and undecodable bytes in the body.
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)

    # Client mistakes are reported as 400 here, before any work is
    # dispatched, so that 500 below is reserved for server-side failures.
    if (
        not isinstance(data, dict)
        or not data.get("chat_id")
        or not data.get("message")
    ):
        return JSONResponse(
            {"error": "chat_id and message required"}, status_code=400
        )

    # The handler blocks on network I/O, so run it off the event loop.
    result = await asyncio.to_thread(_handle_webhook, data)

    if "error" in result:
        return JSONResponse(result, status_code=500)
    return result
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: always answers with a fixed "ok" status payload."""
    payload = {"status": "ok"}
    return payload
|
|