| |
| from __future__ import annotations |
|
|
| import json |
| import time |
| from typing import Any |
| from urllib.parse import urlparse |
|
|
| from gradio import ChatMessage |
|
|
| from apis.country_profile import get_country_profiles |
| from apis.exa import search_immigration |
| from apis.firecrawl import crawl_site, scrape_page |
| from ui.globe_commands import apply_update_globe |
|
|
| from .messages import assistant_message_dict |
| from .traces import record_tool_trace |
|
|
|
|
def truncate(text: str, limit: int = 4000) -> str:
    """Cap ``text`` at ``limit`` characters, marking any cut.

    Text that already fits is returned unchanged. Longer text is sliced to
    its first ``limit`` characters and a newline plus "… (truncated)" is
    appended, so the returned string can be longer than ``limit``.
    """
    if len(text) > limit:
        return "".join((text[:limit], "\n… (truncated)"))
    return text
|
|
|
|
def run_tool(
    name: str,
    arguments: str,
    *,
    globe_state: dict[str, Any] | None = None,
) -> tuple[str, dict[str, Any] | None]:
    """Run one tool call and return its result encoded as JSON text.

    Args:
        name: Tool name, forwarded to ``_run_tool``.
        arguments: JSON text with the tool arguments; empty text is treated
            as ``{}``.
        globe_state: Current globe state, forwarded to the tool.

    Returns:
        ``(result_json, globe_state)``. Failures never propagate: any
        exception is reported as ``{"error": ...}`` and the incoming
        ``globe_state`` is returned unchanged.
    """
    try:
        args = json.loads(arguments or "{}")
        if not isinstance(args, dict):
            # Tools index arguments by key; reject lists/scalars up front
            # instead of surfacing an opaque TypeError from deeper down.
            return (
                json.dumps({"error": "Tool arguments must be a JSON object"}),
                globe_state,
            )
        result, updated_globe = _run_tool(name, args, globe_state=globe_state)
        return json.dumps(result, default=str), updated_globe
    except Exception as exc:
        # Top-level boundary: the failure is handed back as data. Some
        # exceptions carry no message; fall back to the class name so the
        # "error" value is never empty (an empty value reads as success to
        # code that checks it for truthiness).
        message = str(exc) or type(exc).__name__
        return json.dumps({"error": message}), globe_state
|
|
|
|
| def _run_tool( |
| name: str, |
| args: dict[str, Any], |
| *, |
| globe_state: dict[str, Any] | None = None, |
| ) -> tuple[Any, dict[str, Any] | None]: |
| if name == "think": |
| return {"ok": True}, globe_state |
| if name == "get_country_profile": |
| return get_country_profiles(args["countries"]), globe_state |
| if name == "search_immigration_info": |
| return ( |
| search_immigration( |
| query=args["query"], |
| num_results=min(args.get("num_results", 8), 15), |
| country=args.get("country"), |
| include_domains=args.get("include_domains"), |
| ), |
| globe_state, |
| ) |
| if name == "scrape_web_page": |
| return ( |
| scrape_page( |
| url=args["url"], |
| country=args.get("country"), |
| ), |
| globe_state, |
| ) |
| if name == "crawl_web_site": |
| return ( |
| crawl_site( |
| url=args["url"], |
| limit=min(args.get("limit", 10), 20), |
| include_paths=args.get("include_paths"), |
| country=args.get("country"), |
| ), |
| globe_state, |
| ) |
| if name == "update_globe": |
| if globe_state is None: |
| return {"error": "Globe state is unavailable"}, None |
| return apply_update_globe(globe_state, args) |
|
|
| return {"error": f"Unknown tool: {name}"}, globe_state |
|
|
|
|
| def _parse_arguments(arguments: str) -> dict[str, Any]: |
| try: |
| parsed = json.loads(arguments or "{}") |
| return parsed if isinstance(parsed, dict) else {} |
| except json.JSONDecodeError: |
| return {} |
|
|
|
|
| def _load_result(result: str) -> dict[str, Any]: |
| try: |
| parsed = json.loads(result) |
| return parsed if isinstance(parsed, dict) else {} |
| except json.JSONDecodeError: |
| return {} |
|
|
|
|
| def _normalized_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]: |
| """Keep replayed assistant tool calls valid for the chat-completion API.""" |
| normalized = [] |
| for tool_call in tool_calls: |
| function = tool_call.get("function") or {} |
| normalized.append( |
| { |
| **tool_call, |
| "function": { |
| **function, |
| "arguments": json.dumps( |
| _parse_arguments(str(function.get("arguments") or "")) |
| ), |
| }, |
| } |
| ) |
| return normalized |
|
|
|
|
| def _join(items: list[str], fallback: str) -> str: |
| clean = [str(item) for item in items if item] |
| return ", ".join(clean) if clean else fallback |
|
|
|
|
| def _url_host_or_path(url: str, *, limit: int = 80) -> str: |
| raw = str(url or "").strip() |
| if not raw: |
| return "page" |
| try: |
| parsed = urlparse(raw if "://" in raw else f"https://{raw}") |
| host = parsed.netloc |
| path = parsed.path.strip("/") |
| if host and path: |
| return truncate(f"{host}/{path}", limit) |
| return truncate(host or path or raw, limit) |
| except Exception: |
| return truncate(raw, limit) |
|
|
|
|
def tool_display_title(tool_name: str, args: dict[str, Any]) -> str:
    """Return the short title shown for a tool call in the chat UI.

    Known tools get a "<Verb> · <subject>" title built from their
    arguments; any other tool name gives "Using <tool_name>".
    """
    if tool_name in ("get_country_profile", "update_globe"):
        names = _join(args.get("countries") or [], "selected countries")
        if tool_name == "get_country_profile":
            return f"Profiles · {names}"
        return f"Globe · {names}"

    if tool_name == "search_immigration_info":
        query_text = str(args.get("query") or "immigration sources")
        query = truncate(query_text, 80)
        return f"Search · {query}"

    if tool_name in ("scrape_web_page", "crawl_web_site"):
        url_text = str(args.get("url") or "")
        if tool_name == "scrape_web_page":
            return f"Read · {_url_host_or_path(url_text)}"
        # Crawl titles use a tighter label limit than page reads.
        return f"Crawl · {_url_host_or_path(url_text, limit=60)}"

    return f"Using {tool_name}"
|
|
|
|
def _pending_tool_message(tool_name: str, args: dict[str, Any]) -> tuple[str, str]:
    """Return ``(title, message)`` for a tool call that is about to run.

    The title comes from ``tool_display_title``; the message is a
    one-sentence description of the work in progress. Missing or empty
    arguments are replaced by generic wording.
    """
    title = tool_display_title(tool_name, args)
    if tool_name == "get_country_profile":
        countries = _join(args.get("countries") or [], "selected countries")
        return title, f"Looking up country metadata for {countries}."
    if tool_name == "search_immigration_info":
        query = truncate(str(args.get("query") or "immigration sources"), 180)
        return title, f"Searching for official immigration information: {query}"
    if tool_name == "scrape_web_page":
        # `or` rather than a .get() default: a url key that is present but
        # null/empty must not render as "Reading None.".
        target = args.get("url") or "the selected page"
        return title, f"Reading {target}."
    if tool_name == "crawl_web_site":
        target = args.get("url") or "the selected site"
        return title, f"Crawling related pages from {target}."
    if tool_name == "update_globe":
        countries = _join(args.get("countries") or [], "the selected countries")
        return title, f"Showing {countries} on the globe."
    return title, f"Running `{tool_name}`."
|
|
|
|
def _format_log_result(result: str) -> Any:
    """Prepare a tool result for the UI log entry.

    The result is decoded with ``_load_result``. When its indented JSON
    form fits in 1500 characters the decoded dict is returned; otherwise
    the first 1500 characters of that JSON text are returned with a
    truncation marker appended.
    """
    payload = _load_result(result)
    pretty = json.dumps(payload, default=str, indent=2)
    if len(pretty) > 1500:
        return pretty[:1500] + "\n… (truncated)"
    return payload
|
|
|
|
def _tool_log_metadata(
    tool_name: str,
    parsed_args: dict[str, Any],
    result: str,
) -> dict[str, Any]:
    """Assemble the log entry attached to a finished tool message.

    The entry holds the tool name, its parsed arguments, and the result as
    prepared by ``_format_log_result``.
    """
    entry: dict[str, Any] = {"tool": tool_name, "arguments": parsed_args}
    entry["result"] = _format_log_result(result)
    return entry
|
|
|
|
def should_emit_reasoning(
    reasoning: str,
    tool_calls: list[dict[str, Any]] | None,
) -> bool:
    """Decide whether the reasoning text should be shown as its own message.

    Blank reasoning is never shown. Reasoning is also suppressed when the
    turn consists of exactly one tool call and that call is ``think``; in
    every other case it is shown.
    """
    if not reasoning.strip():
        return False
    if not tool_calls or len(tool_calls) != 1:
        return True
    only_call = tool_calls[0].get("function") or {}
    return only_call.get("name") != "think"
|
|
|
|
def emit_thinking_message(
    ui_messages: list[ChatMessage],
    reasoning: str,
    globe_state: dict[str, Any],
):
    """Append a "Thinking" message for ``reasoning`` and yield the UI state.

    This is a generator. With blank reasoning it yields nothing. Otherwise
    it appends one assistant message (title "Thinking", status "done") to
    ``ui_messages`` in place and yields ``(ui_messages, globe_state)`` once.
    """
    text = reasoning.strip()
    if text:
        bubble = ChatMessage(
            role="assistant",
            content=text,
            metadata={
                "title": "Thinking",
                "status": "done",
            },
        )
        ui_messages.append(bubble)
        yield ui_messages, globe_state
|
|
|
|
def _country_entries(parsed: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the dict entries of ``parsed["countries"]``, or ``[]`` if it is not a list."""
    entries = parsed.get("countries")
    if not isinstance(entries, list):
        return []
    return [entry for entry in entries if isinstance(entry, dict)]


def _done_tool_message(tool_name: str, args: dict[str, Any], result: str) -> str:
    """Summarize a finished tool call in one user-facing sentence.

    Args:
        tool_name: Name of the tool that ran.
        args: Parsed arguments the tool was called with.
        result: JSON text returned by the tool.

    Returns:
        An error notice when the decoded result has a truthy ``error``
        field, otherwise a tool-specific summary. Unknown tools get a
        generic "completed" message.
    """
    parsed = _load_result(result)
    if parsed.get("error"):
        return f"Tool returned an issue: {parsed['error']}"

    if tool_name == "get_country_profile":
        countries = [country.get("name", "") for country in _country_entries(parsed)]
        return f"Found country metadata for {_join(countries, 'the selected countries')}."

    if tool_name == "search_immigration_info":
        count = parsed.get("num_results", 0)
        official = parsed.get("official_results", 0)
        hints = parsed.get("official_domain_hints") or []
        hint_text = f" Suggested official domains: {_join(hints, 'none')}." if hints else ""
        return f"Found {count} search results, including {official} likely official source(s).{hint_text}"

    if tool_name == "scrape_web_page":
        title = parsed.get("title") or args.get("url") or "the page"
        # Fall back to the requested URL, and drop the clause entirely when
        # nothing is known, so the message never reads "Source: None".
        source = parsed.get("source_url") or parsed.get("url") or args.get("url")
        message = f"Extracted official page content from {title}."
        return f"{message} Source: {source}" if source else message

    if tool_name == "crawl_web_site":
        pages = parsed.get("pages_found", 0)
        return f"Collected {pages} related page(s) from the official site."

    if tool_name == "update_globe":
        # Use the ISO code when the name is missing or empty.
        countries = [
            country.get("name") or country.get("iso2", "")
            for country in _country_entries(parsed)
        ]
        return f"Updated the globe with {_join(countries, 'the selected countries')}."

    return f"`{tool_name}` completed."
|
|
|
|
def _opening_tool_bubble(
    tool_name: str,
    parsed_args: dict[str, Any],
) -> tuple[str, str, dict[str, Any]]:
    """Return ``(title, content, metadata)`` for the message shown before a tool runs."""
    if tool_name == "think":
        thought = str(parsed_args.get("thought") or "").strip() or "Planning next steps."
        # A thought is complete as soon as it is shown, so it starts "done".
        return "Thinking", thought, {"title": "Thinking", "status": "done"}
    title, pending_message = _pending_tool_message(tool_name, parsed_args)
    metadata = {
        "title": title,
        "status": "pending",
        "log": {
            "tool": tool_name,
            "arguments": parsed_args,
        },
    }
    return title, pending_message, metadata


def execute_tool_calls(
    api_messages: list[dict[str, Any]],
    ui_messages: list[ChatMessage],
    tool_calls: list[dict[str, Any]],
    content: str,
    globe_state: dict[str, Any],
):
    """Run the model's tool calls, streaming UI updates as a generator.

    The normalized tool calls are first recorded in ``api_messages`` as one
    assistant message. Then, for each call in order:

    1. an opening message is appended to ``ui_messages`` and the state is
       yielded;
    2. the tool runs via ``run_tool`` and a trace is recorded;
    3. the opening message is replaced by the finished one (with duration
       and log) and the state is yielded again;
    4. a ``tool`` role message carrying the result is appended to
       ``api_messages``.

    ``api_messages`` and ``ui_messages`` are mutated in place.

    Args:
        api_messages: Chat-completion message history.
        ui_messages: Messages displayed in the chat UI.
        tool_calls: Tool calls as produced by the model.
        content: Assistant text that accompanied the tool calls.
        globe_state: Globe state; replaced by whatever each tool returns.

    Yields:
        ``(ui_messages, globe_state)`` twice per tool call.
    """
    tool_calls = _normalized_tool_calls(tool_calls)
    api_messages.append(assistant_message_dict(content, tool_calls))

    for tool_call in tool_calls:
        tool_name = tool_call["function"]["name"]
        tool_args = tool_call["function"]["arguments"]
        parsed_args = _parse_arguments(tool_args)
        # NOTE: the clock starts before the first yield, so the reported
        # duration also covers the time the consumer holds the generator
        # while handling the opening update.
        started = time.monotonic()

        title, opening_content, opening_metadata = _opening_tool_bubble(
            tool_name, parsed_args
        )
        ui_messages.append(
            ChatMessage(
                role="assistant",
                content=opening_content,
                metadata=opening_metadata,
            )
        )
        yield ui_messages, globe_state

        result, globe_state = run_tool(
            tool_name,
            tool_args,
            globe_state=globe_state,
        )
        duration = time.monotonic() - started
        record_tool_trace(
            tool_name=tool_name,
            arguments=tool_args,
            result=result,
            duration=duration,
        )

        # A thought keeps its text; other tools get a result summary.
        if tool_name == "think":
            closing_content = opening_content
        else:
            closing_content = _done_tool_message(tool_name, parsed_args, result)

        # NOTE(review): Gradio appears to document the metadata "log" value
        # as a string; confirm that a dict renders as intended.
        ui_messages[-1] = ChatMessage(
            role="assistant",
            content=closing_content,
            metadata={
                "title": title,
                "status": "done",
                "duration": duration,
                "log": _tool_log_metadata(tool_name, parsed_args, result),
            },
        )
        yield ui_messages, globe_state

        api_messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "name": tool_name,
                "content": result,
            }
        )
|
|