| from __future__ import annotations |
|
|
| import asyncio |
| import json |
| import os |
| import uuid |
| from dataclasses import dataclass, field |
| from typing import Any, Dict, Optional, List |
|
|
| from google import genai |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ToolCallAwaiter:
    """Wraps the future that a pending client-side tool call will resolve."""

    # Completed by deliver_function_result() with the tool's result payload.
    fut: asyncio.Future
|
|
|
|
@dataclass
class SessionState:
    """Per-session chat state held in the module-level SESSIONS registry."""

    # Gemini-style conversation turns: {"role": ..., "parts": [...]}.
    history: List[dict] = field(default_factory=list)

    # Client-registered function schemas, keyed by function name.
    functions: Dict[str, dict] = field(default_factory=dict)

    # Tool calls awaiting a result from the client, keyed by call_id.
    pending_calls: Dict[str, ToolCallAwaiter] = field(default_factory=dict)
|
|
|
|
| SESSIONS: Dict[str, SessionState] = {} |
|
|
|
|
def get_session(session_id: str) -> SessionState:
    """Return the state for *session_id*, creating it on first use."""
    state = SESSIONS.get(session_id)
    if state is None:
        state = SessionState()
        SESSIONS[session_id] = state
    return state
|
|
|
|
| |
| |
| |
|
|
def _get_genai_client() -> genai.Client:
    """Build a Gemini client from the GEMINI_API_KEY environment variable.

    Raises:
        RuntimeError: if the environment variable is unset or empty.
    """
    key = os.environ.get("GEMINI_API_KEY")
    if not key:
        raise RuntimeError("Missing GEMINI_API_KEY env var.")
    return genai.Client(api_key=key)
|
|
|
|
| def _scratch_schema_to_gemini_decl(name: str, schema: dict) -> dict: |
| """ |
| Convert a Scratch-side function schema into a Gemini-compatible function declaration. |
| |
| Expected Scratch schema (example): |
| { |
| "description": "Open the settings page", |
| "parameters": { |
| "type": "object", |
| "properties": { |
| "tab": {"type":"string", "description":"Which tab to open"} |
| }, |
| "required": ["tab"] |
| } |
| } |
| """ |
| desc = (schema or {}).get("description", "") |
| params = (schema or {}).get("parameters") or {"type": "object", "properties": {}} |
|
|
| return { |
| "name": name, |
| "description": desc, |
| "parameters": params, |
| } |
|
|
|
|
async def gemini_chat_turn(
    *,
    session_id: str,
    user_text: str,
    emit_event,
    model: str = "gemini-2.0-flash",
) -> str:
    """
    Run one user turn against Gemini Flash (text), bouncing any tool calls
    to the WS client via ``emit_event`` and waiting for their results.

    Args:
        session_id: Key into the in-memory session registry.
        user_text: The user's message for this turn.
        emit_event: Async callable that sends an event dict to the client.
        model: Gemini model name.

    Returns:
        The assistant's final text reply for this turn.
    """
    s = get_session(session_id)
    client = _get_genai_client()

    # Advertise the session's client-registered functions to Gemini.
    tool_decls = [
        _scratch_schema_to_gemini_decl(fname, fschema)
        for fname, fschema in s.functions.items()
    ]

    s.history.append({"role": "user", "parts": [{"text": user_text}]})

    while True:
        resp = client.models.generate_content(
            model=model,
            contents=s.history,
            config={
                "tools": [{"function_declarations": tool_decls}] if tool_decls else None,
                "temperature": 0.6,
            },
        )

        # Defensively unwrap the first candidate's parts; any missing layer
        # degrades to "no parts" rather than raising.
        cand = (getattr(resp, "candidates", None) or [None])[0]
        content = getattr(cand, "content", None) if cand else None
        parts = (getattr(content, "parts", None) if content else None) or []

        tool_calls = []
        text_chunks = []

        for p in parts:
            tx = getattr(p, "text", None)
            if tx:
                text_chunks.append(tx)

            fc = getattr(p, "function_call", None)
            if fc:
                name = getattr(fc, "name", None)
                args = getattr(fc, "args", None)
                # Some transports deliver args as a JSON string; fall back to
                # a raw wrapper if it isn't valid JSON.
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except Exception:
                        args = {"_raw": args}
                tool_calls.append({"name": name, "args": args or {}})

        # Pure text response: record it and finish the turn.
        if text_chunks and not tool_calls:
            assistant_text = "".join(text_chunks).strip()
            s.history.append({"role": "model", "parts": [{"text": assistant_text}]})
            return assistant_text

        if tool_calls:
            # BUGFIX: the model's function-call turn must be recorded in
            # history *before* the function responses, otherwise the
            # follow-up request sends a "tool" turn with no matching model
            # call and the API rejects or misreads the conversation.
            s.history.append(
                {
                    "role": "model",
                    "parts": [
                        {
                            "function_call": {
                                "name": tc["name"] or "unknown_function",
                                "args": tc["args"] or {},
                            }
                        }
                        for tc in tool_calls
                    ],
                }
            )

            for tc in tool_calls:
                fname = tc["name"] or "unknown_function"
                fargs = tc["args"] or {}

                call_id = str(uuid.uuid4())
                # BUGFIX: inside a coroutine, get_running_loop() is the
                # supported API; get_event_loop() is deprecated here.
                fut = asyncio.get_running_loop().create_future()
                s.pending_calls[call_id] = ToolCallAwaiter(fut=fut)

                try:
                    await emit_event(
                        {
                            "type": "function_called",
                            "call_id": call_id,
                            "name": fname,
                            "arguments": fargs,
                        }
                    )
                    # Blocks until deliver_function_result() resolves it.
                    result = await fut
                finally:
                    # BUGFIX: don't leak the awaiter if emit_event or the
                    # await fails/cancels; pop is a no-op if already removed.
                    s.pending_calls.pop(call_id, None)

                s.history.append(
                    {
                        "role": "tool",
                        "parts": [
                            {
                                "function_response": {
                                    "name": fname,
                                    "response": {"result": result},
                                }
                            }
                        ],
                    }
                )

            # Re-enter the loop so the model can consume the tool results.
            continue

        # Neither text nor tool calls: return a sane placeholder.
        assistant_text = "(No response.)"
        s.history.append({"role": "model", "parts": [{"text": assistant_text}]})
        return assistant_text
|
|
|
|
def deliver_function_result(session_id: str, call_id: str, result: Any) -> bool:
    """Resolve a pending tool call with the client-supplied result.

    Returns True if a matching pending call existed, False otherwise.
    """
    state = get_session(session_id)
    awaiter = state.pending_calls.pop(call_id, None)
    if awaiter is None:
        return False
    if not awaiter.fut.done():
        awaiter.fut.set_result(result)
    return True
|
|