File size: 6,504 Bytes
ed9acde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from __future__ import annotations

import asyncio
import json
import os
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, List

from google import genai


# ----------------------------
# Session state
# ----------------------------

@dataclass
class ToolCallAwaiter:
    """Wrapper around the future that one pending tool call is waiting on."""

    # Awaited by gemini_chat_turn() after it emits a "function_called" event;
    # resolved by deliver_function_result() with the value the client sends back.
    fut: asyncio.Future


@dataclass
class SessionState:
    """Mutable state kept for one chat session."""

    # Conversation turns as role/parts dicts; gemini_chat_turn() sends this list
    # as `contents` and appends to it as the conversation progresses.
    history: List[dict] = field(default_factory=list)
    # Function name -> schema dict (Scratch-provided; the format is defined by
    # this project and converted by _scratch_schema_to_gemini_decl()).
    functions: Dict[str, dict] = field(default_factory=dict)
    # call_id -> awaiter for tool calls that were emitted to the client and
    # have not been answered yet.
    pending_calls: Dict[str, ToolCallAwaiter] = field(default_factory=dict)


# Process-wide registry of live sessions, keyed by session id.
SESSIONS: Dict[str, SessionState] = {}


def get_session(session_id: str) -> SessionState:
    """Return the state for ``session_id``, creating an empty one on first use.

    The created state is stored in ``SESSIONS`` so later calls with the same id
    return the same object.
    """
    try:
        return SESSIONS[session_id]
    except KeyError:
        created = SessionState()
        SESSIONS[session_id] = created
        return created


# ----------------------------
# Gemini client
# ----------------------------

def _get_genai_client() -> genai.Client:
    """Build a Gemini client from the ``GEMINI_API_KEY`` environment variable.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    key = os.environ.get("GEMINI_API_KEY")
    if key:
        return genai.Client(api_key=key)
    raise RuntimeError("Missing GEMINI_API_KEY env var.")


def _scratch_schema_to_gemini_decl(name: str, schema: dict) -> dict:
    """
    Turn a Scratch-side function schema into a Gemini function declaration dict.

    The Scratch schema is expected to look like this (example):
    {
      "description": "Open the settings page",
      "parameters": {
        "type": "object",
        "properties": {
          "tab": {"type":"string", "description":"Which tab to open"}
        },
        "required": ["tab"]
      }
    }

    A missing/empty schema yields an empty description, and missing/empty
    "parameters" fall back to an object schema with no properties. A non-empty
    "parameters" dict is passed through as the same object (not copied).
    """
    source = schema if schema else {}

    declaration = {
        "name": name,
        "description": source.get("description", ""),
    }

    parameters = source.get("parameters")
    if not parameters:
        # No usable parameter schema: declare a function that takes no arguments.
        parameters = {"type": "object", "properties": {}}
    declaration["parameters"] = parameters

    return declaration


def _split_response_parts(resp: Any) -> tuple[List[str], List[dict]]:
    """Return ``(text_chunks, tool_calls)`` read from the first candidate of ``resp``.

    Every attribute is read defensively with ``getattr`` so a response without
    candidates, content or parts simply yields two empty lists. Each tool call
    is a ``{"name": str, "args": ...}`` dict; string arguments are decoded as
    JSON and kept under ``"_raw"`` when they are not valid JSON.
    """
    candidates = getattr(resp, "candidates", None) or []
    first = candidates[0] if candidates else None
    content = getattr(first, "content", None)
    parts = getattr(content, "parts", None) or []

    text_chunks: List[str] = []
    tool_calls: List[dict] = []
    for part in parts:
        text = getattr(part, "text", None)
        if text:
            text_chunks.append(text)

        call = getattr(part, "function_call", None)
        if not call:
            continue
        args = getattr(call, "args", None)
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except ValueError:  # json.JSONDecodeError is a ValueError
                args = {"_raw": args}
        tool_calls.append(
            {
                "name": getattr(call, "name", None) or "unknown_function",
                "args": args or {},
            }
        )
    return text_chunks, tool_calls


async def gemini_chat_turn(
    *,
    session_id: str,
    user_text: str,
    emit_event,  # async fn(dict) -> None  (send to ws client)
    model: str = "gemini-2.0-flash",
) -> str:
    """
    Send one user turn to Gemini and return the assistant's final text.

    Tool calls requested by the model are bounced to the WS client: for each
    call a ``function_called`` event is emitted through ``emit_event`` and this
    coroutine waits until ``deliver_function_result`` resolves the matching
    pending call. The model is then queried again with the tool results, and
    this repeats until it answers without requesting tools.

    Args:
        session_id: Key of the session whose history and functions are used.
        user_text: The user's message for this turn.
        emit_event: Async callable taking one dict; used to send events to the
            client.
        model: Gemini model name.

    Returns:
        The assistant text (stripped), or "(No response.)" when the model
        returned no usable text. The returned text is also appended to the
        session history as a "model" turn.

    Raises:
        RuntimeError: If the Gemini API key is not configured.
    """
    s = get_session(session_id)
    client = _get_genai_client()

    # Build tool declarations from session functions.
    tool_decls = [
        _scratch_schema_to_gemini_decl(fname, fschema)
        for fname, fschema in s.functions.items()
    ]

    # Moderate sampling temperature; tools are only declared when there are any.
    config: Dict[str, Any] = {"temperature": 0.6}
    if tool_decls:
        config["tools"] = [{"function_declarations": tool_decls}]

    # google-genai accepts "contents" as a list of role/parts dicts.
    s.history.append({"role": "user", "parts": [{"text": user_text}]})

    loop = asyncio.get_running_loop()

    # We run a loop because Gemini might call tools and then continue.
    while True:
        # Use the SDK's async client so the event loop (and with it the
        # WebSocket that delivers tool results) is not blocked during the call.
        resp = await client.aio.models.generate_content(
            model=model,
            contents=s.history,
            config=config,
        )
        text_chunks, tool_calls = _split_response_parts(resp)

        if not tool_calls:
            # Plain answer (or nothing at all): finish the turn. A blank answer
            # falls back to the placeholder so no empty text part is stored.
            assistant_text = "".join(text_chunks).strip() or "(No response.)"
            s.history.append({"role": "model", "parts": [{"text": assistant_text}]})
            return assistant_text

        # Run every requested tool through the WS client, collecting responses.
        response_parts = []
        for tc in tool_calls:
            call_id = str(uuid.uuid4())
            fut = loop.create_future()
            s.pending_calls[call_id] = ToolCallAwaiter(fut=fut)
            try:
                await emit_event(
                    {
                        "type": "function_called",
                        "call_id": call_id,
                        "name": tc["name"],
                        "arguments": tc["args"],
                    }
                )
                # Wait for Scratch to respond with function_result.
                result = await fut
            finally:
                # Never leave a stale entry behind if sending fails or the
                # wait is cancelled.
                s.pending_calls.pop(call_id, None)

            response_parts.append(
                {
                    "function_response": {
                        "name": tc["name"],
                        "response": {"result": result},
                    }
                }
            )

        # Record the model's own function-call turn followed by ONE tool turn
        # holding all responses, so every function response is preceded by the
        # call it answers. Both turns are appended only after all results have
        # arrived, which keeps the history free of half-finished tool rounds.
        # NOTE(review): parts are rebuilt as plain dicts to keep the history
        # homogeneous; any extra per-part metadata returned by the SDK is not
        # carried over.
        model_parts = [{"text": chunk} for chunk in text_chunks]
        model_parts.extend(
            {"function_call": {"name": tc["name"], "args": tc["args"]}}
            for tc in tool_calls
        )
        s.history.append({"role": "model", "parts": model_parts})
        s.history.append({"role": "tool", "parts": response_parts})
        # Loop continues to let Gemini produce final text after tools.


def deliver_function_result(session_id: str, call_id: str, result: Any) -> bool:
    """Hand ``result`` to the tool call ``call_id`` pending in ``session_id``.

    Args:
        session_id: Session the call belongs to.
        call_id: Identifier that was sent with the ``function_called`` event.
        result: Value to resolve the waiting future with.

    Returns:
        True if a pending call with that id existed (its entry is removed),
        False if the session or the call id is unknown.
    """
    # Look the session up directly: get_session() would create and retain an
    # empty SessionState for every unknown id a client happens to send.
    session = SESSIONS.get(session_id)
    if session is None:
        return False

    awaiter = session.pending_calls.pop(call_id, None)
    if awaiter is None:
        return False

    # The future may already be finished (e.g. the waiting task was cancelled);
    # set_result() on a finished future raises InvalidStateError.
    if not awaiter.fut.done():
        awaiter.fut.set_result(result)
    return True