from __future__ import annotations
import asyncio
import json
import os
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, List
from google import genai
# ----------------------------
# Session state
# ----------------------------
@dataclass
class ToolCallAwaiter:
    """Wrapper for the future that a WebSocket tool-call reply resolves."""

    # Completed by deliver_function_result() with the client's result payload.
    fut: asyncio.Future
@dataclass
class SessionState:
    """All per-session state: chat history, tool schemas, in-flight tool calls."""

    # Chronological Gemini chat history as {"role": ..., "parts": [...]} dicts.
    history: list[dict] = field(default_factory=list)
    # Function name -> schema dict (Scratch-provided, format controlled by us).
    functions: dict[str, dict] = field(default_factory=dict)
    # call_id -> awaiter for a tool call still waiting on the client's reply.
    pending_calls: dict[str, "ToolCallAwaiter"] = field(default_factory=dict)
# All live sessions, keyed by session id.
SESSIONS: dict[str, SessionState] = {}


def get_session(session_id: str) -> SessionState:
    """Return the state for *session_id*, creating a fresh one on first use."""
    return SESSIONS.setdefault(session_id, SessionState())
# ----------------------------
# Gemini client
# ----------------------------
def _get_genai_client() -> genai.Client:
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise RuntimeError("Missing GEMINI_API_KEY env var.")
return genai.Client(api_key=api_key)
def _scratch_schema_to_gemini_decl(name: str, schema: dict) -> dict:
"""
Convert a Scratch-side function schema into a Gemini-compatible function declaration.
Expected Scratch schema (example):
{
"description": "Open the settings page",
"parameters": {
"type": "object",
"properties": {
"tab": {"type":"string", "description":"Which tab to open"}
},
"required": ["tab"]
}
}
"""
desc = (schema or {}).get("description", "")
params = (schema or {}).get("parameters") or {"type": "object", "properties": {}}
return {
"name": name,
"description": desc,
"parameters": params,
}
async def gemini_chat_turn(
    *,
    session_id: str,
    user_text: str,
    emit_event,  # async fn(dict) -> None (send to ws client)
    model: str = "gemini-2.0-flash",
) -> str:
    """Run one user turn against Gemini, bridging tool calls over the WebSocket.

    Appends the user message to the session history, calls the model, and —
    when the model requests tool calls — forwards each call to the connected
    client via ``emit_event`` and awaits the matching
    ``deliver_function_result`` before letting the model continue.

    Args:
        session_id: Key into SESSIONS; a session is created on first use.
        user_text: The user's message for this turn.
        emit_event: Async callable that sends a JSON-able dict to the client.
        model: Gemini model name.

    Returns:
        The model's final text reply, or "(No response.)" if it produced none.
    """
    s = get_session(session_id)
    client = _get_genai_client()

    # Declare the session's Scratch-registered functions as Gemini tools.
    tool_decls = [
        _scratch_schema_to_gemini_decl(fname, fschema)
        for fname, fschema in s.functions.items()
    ]

    s.history.append({"role": "user", "parts": [{"text": user_text}]})

    # Loop because the model may request tools before producing final text.
    while True:
        # BUG FIX: only include a "tools" entry when there are declarations;
        # the old code passed a literal None value inside the config dict.
        config: dict = {"temperature": 0.6}  # short-ish replies for club usage
        if tool_decls:
            config["tools"] = [{"function_declarations": tool_decls}]

        resp = client.models.generate_content(
            model=model,
            contents=s.history,
            config=config,
        )

        # google-genai response shapes vary across versions; read defensively.
        # Text lives in candidates[0].content.parts[*].text, tool calls in
        # candidates[0].content.parts[*].function_call.
        cand = (getattr(resp, "candidates", None) or [None])[0]
        content = getattr(cand, "content", None) if cand else None
        parts = (getattr(content, "parts", None) if content else None) or []

        tool_calls = []
        text_chunks = []
        for part in parts:
            text = getattr(part, "text", None)
            if text:
                text_chunks.append(text)
            fc = getattr(part, "function_call", None)
            if fc:
                name = getattr(fc, "name", None)
                args = getattr(fc, "args", None)
                if isinstance(args, str):
                    # Some SDK versions hand back JSON text instead of a dict.
                    try:
                        args = json.loads(args)
                    except ValueError:
                        args = {"_raw": args}
                tool_calls.append(
                    {"name": name or "unknown_function", "args": args or {}}
                )

        if tool_calls:
            # BUG FIX: the model's function_call turn must be echoed back into
            # the history ahead of the tool responses — Gemini pairs each
            # function_response with a preceding function_call, and the old
            # code appended only the responses, producing a malformed
            # follow-up request.
            s.history.append(
                {
                    "role": "model",
                    "parts": [
                        {"function_call": {"name": tc["name"], "args": tc["args"]}}
                        for tc in tool_calls
                    ],
                }
            )
            for tc in tool_calls:
                call_id = str(uuid.uuid4())
                # get_running_loop(), not the deprecated get_event_loop():
                # we are inside a coroutine, so a running loop is guaranteed.
                fut = asyncio.get_running_loop().create_future()
                s.pending_calls[call_id] = ToolCallAwaiter(fut=fut)
                await emit_event(
                    {
                        "type": "function_called",
                        "call_id": call_id,
                        "name": tc["name"],
                        "arguments": tc["args"],
                    }
                )
                # Block until the Scratch client replies via
                # deliver_function_result().
                result = await fut
                # Feed the tool result back: role "tool" with a
                # function_response part.
                s.history.append(
                    {
                        "role": "tool",
                        "parts": [
                            {
                                "function_response": {
                                    "name": tc["name"],
                                    "response": {"result": result},
                                }
                            }
                        ],
                    }
                )
            continue  # let the model turn the tool results into final text

        # No tool calls: finish with the model's text (or a fallback).
        assistant_text = (
            "".join(text_chunks).strip() if text_chunks else "(No response.)"
        )
        s.history.append({"role": "model", "parts": [{"text": assistant_text}]})
        return assistant_text
def deliver_function_result(session_id: str, call_id: str, result: Any) -> bool:
    """Resolve a pending tool call with the result sent back by the client.

    Args:
        session_id: The session the result belongs to.
        call_id: The id handed to the client in the "function_called" event.
        result: Arbitrary JSON-able payload from the Scratch side.

    Returns:
        True if a matching pending call existed (it is removed and, unless
        already done, its future is resolved); False otherwise.
    """
    # BUG FIX: look up with SESSIONS.get() instead of get_session() — a result
    # for an unknown session id must not create a phantom SessionState.
    s = SESSIONS.get(session_id)
    if s is None:
        return False
    awaiter = s.pending_calls.pop(call_id, None)
    if awaiter is None:
        return False
    # Guard against double delivery: only the first result resolves the future.
    if not awaiter.fut.done():
        awaiter.fut.set_result(result)
    return True