SalexAI committed on
Commit
ed9acde
·
verified ·
1 Parent(s): dccec3c

Create gemini_text.py

Browse files
Files changed (1) hide show
  1. app/gemini_text.py +203 -0
app/gemini_text.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import os
6
+ import uuid
7
+ from dataclasses import dataclass, field
8
+ from typing import Any, Dict, Optional, List
9
+
10
+ from google import genai
11
+
12
+
13
+ # ----------------------------
14
+ # Session state
15
+ # ----------------------------
16
+
17
@dataclass
class ToolCallAwaiter:
    """Handle for one tool call that is waiting on the client for its result."""

    # Future that the chat turn awaits; it is resolved with the tool's result
    # once the client reports it back.
    fut: asyncio.Future
20
+
21
+
22
@dataclass
class SessionState:
    """Mutable state kept for a single chat session.

    Every field defaults to a fresh empty container, so instances never
    share state with each other.
    """

    # Conversation turns in chronological order, stored as role/parts dicts.
    history: List[dict] = field(default_factory=list)
    # Function name -> schema dict (Scratch-provided, you control the format).
    functions: Dict[str, dict] = field(default_factory=dict)
    # call_id -> awaiter for tool calls that have not been answered yet.
    pending_calls: Dict[str, ToolCallAwaiter] = field(default_factory=dict)
29
+
30
+
31
# In-memory registry of every known session, keyed by session id.
SESSIONS: Dict[str, SessionState] = {}


def get_session(session_id: str) -> SessionState:
    """Return the state for *session_id*, creating an empty one on first use."""
    try:
        return SESSIONS[session_id]
    except KeyError:
        # First time this id is seen: register a blank session and hand it out.
        created = SessionState()
        SESSIONS[session_id] = created
        return created
38
+
39
+
40
+ # ----------------------------
41
+ # Gemini client
42
+ # ----------------------------
43
+
44
def _get_genai_client() -> genai.Client:
    """Build a Gemini client from the ``GEMINI_API_KEY`` environment variable.

    Returns:
        A new ``genai.Client`` configured with the key.

    Raises:
        RuntimeError: If the variable is unset or empty.
    """
    key = os.environ.get("GEMINI_API_KEY")
    if key:
        return genai.Client(api_key=key)
    raise RuntimeError("Missing GEMINI_API_KEY env var.")
49
+
50
+
51
def _scratch_schema_to_gemini_decl(name: str, schema: dict) -> dict:
    """
    Convert a Scratch-side function schema into a Gemini-compatible function declaration.

    Expected Scratch schema (example):
    {
        "description": "Open the settings page",
        "parameters": {
            "type": "object",
            "properties": {
                "tab": {"type": "string", "description": "Which tab to open"}
            },
            "required": ["tab"]
        }
    }

    Args:
        name: Function name to expose to the model.
        schema: Scratch-provided schema dict; ``None`` or ``{}`` is treated as
            a function with no description and no parameters.

    Returns:
        A dict with ``name`` and ``description`` (always a string), plus
        ``parameters`` only when the schema declares at least one property.
    """
    source = schema or {}

    # ``or ""`` also covers an explicit ``"description": None`` so the
    # declaration never carries a null description.
    declaration = {
        "name": name,
        "description": source.get("description") or "",
    }

    # An OBJECT schema with empty ``properties`` may be rejected by the Gemini
    # API, so functions that take no arguments simply omit ``parameters``.
    params = source.get("parameters")
    if isinstance(params, dict) and params.get("properties"):
        declaration["parameters"] = params

    return declaration
75
+
76
+
77
def _split_response_parts(resp: Any) -> tuple[List[str], List[dict]]:
    """Return ``(text_chunks, tool_calls)`` from the first candidate of *resp*."""
    # Response shape varies across google-genai versions, so every level is
    # read defensively and a missing level just yields no parts.
    candidates = getattr(resp, "candidates", None) or []
    first = candidates[0] if candidates else None
    content = getattr(first, "content", None)
    parts = getattr(content, "parts", None) or []

    text_chunks: List[str] = []
    tool_calls: List[dict] = []
    for part in parts:
        text = getattr(part, "text", None)
        if text:
            text_chunks.append(text)

        call = getattr(part, "function_call", None)
        if not call:
            continue
        args = getattr(call, "args", None)
        if isinstance(args, str):
            # Some versions hand the arguments over as a JSON string.
            try:
                args = json.loads(args)
            except ValueError:
                args = {"_raw": args}
        tool_calls.append(
            {
                "name": getattr(call, "name", None) or "unknown_function",
                "args": args or {},
            }
        )
    return text_chunks, tool_calls


async def _request_tool_result(s: SessionState, emit_event, name: str, args: dict) -> Any:
    """Send one tool call to the WS client and wait for its result."""
    call_id = str(uuid.uuid4())
    fut = asyncio.get_running_loop().create_future()
    s.pending_calls[call_id] = ToolCallAwaiter(fut=fut)
    try:
        await emit_event(
            {
                "type": "function_called",
                "call_id": call_id,
                "name": name,
                "arguments": args,
            }
        )
        # Wait for Scratch to respond with function_result.
        return await fut
    finally:
        # Drop the entry even if emitting fails or the wait is cancelled, so
        # abandoned calls do not pile up in the session.
        s.pending_calls.pop(call_id, None)


async def gemini_chat_turn(
    *,
    session_id: str,
    user_text: str,
    emit_event,  # async fn(dict) -> None (send to ws client)
    model: str = "gemini-2.0-flash",
) -> str:
    """
    Sends one user turn to Gemini Flash (text), supports tool calling by bouncing tool calls to the WS client.

    Args:
        session_id: Key of the session whose history and functions are used.
        user_text: The user's message for this turn.
        emit_event: Async callable taking one dict; used to send
            ``function_called`` events to the client.
        model: Gemini model name.

    Returns:
        The assistant's final text for this turn, or ``"(No response.)"`` when
        the model produced neither text nor tool calls.

    Raises:
        RuntimeError: If ``GEMINI_API_KEY`` is not configured.
    """
    s = get_session(session_id)
    client = _get_genai_client()

    # Build tool declarations from session functions.
    tool_decls = [
        _scratch_schema_to_gemini_decl(fname, fschema)
        for fname, fschema in s.functions.items()
    ]

    # Keep responses short-ish for Scratch club usage.
    config: Dict[str, Any] = {"temperature": 0.6}
    if tool_decls:
        config["tools"] = [{"function_declarations": tool_decls}]

    # Note: google-genai accepts "contents" as a list of role/content dicts.
    s.history.append({"role": "user", "parts": [{"text": user_text}]})

    # We run a loop because Gemini might call tools then continue.
    while True:
        # Use the SDK's async client: the synchronous call would block the
        # event loop (and every other websocket) for the whole request.
        resp = await client.aio.models.generate_content(
            model=model,
            contents=s.history,
            config=config,
        )
        text_chunks, tool_calls = _split_response_parts(resp)

        if not tool_calls:
            # Plain answer. Fall back when there is no text at all, or only
            # whitespace, so an empty text part is never stored in history.
            assistant_text = "".join(text_chunks).strip() or "(No response.)"
            s.history.append({"role": "model", "parts": [{"text": assistant_text}]})
            return assistant_text

        # Tools were requested: run each one through the WS client.
        response_parts = []
        for tc in tool_calls:
            result = await _request_tool_result(s, emit_event, tc["name"], tc["args"])
            response_parts.append(
                {
                    "function_response": {
                        "name": tc["name"],
                        "response": {"result": result},
                    }
                }
            )

        # A function response has to directly follow the model turn that made
        # the calls, with one response part per call. Both turns are recorded
        # together only after every call finished, so a failed call never
        # leaves a dangling function_call turn in the history.
        model_parts = [{"text": chunk} for chunk in text_chunks]
        model_parts.extend(
            {"function_call": {"name": tc["name"], "args": tc["args"]}}
            for tc in tool_calls
        )
        s.history.append({"role": "model", "parts": model_parts})
        s.history.append({"role": "tool", "parts": response_parts})

        # Loop continues to let Gemini produce final text after tools.
193
+
194
+
195
def deliver_function_result(session_id: str, call_id: str, result: Any) -> bool:
    """Resolve the pending tool call *call_id* of a session with *result*.

    Args:
        session_id: Session the tool call belongs to.
        call_id: Identifier that was sent in the ``function_called`` event.
        result: Value to hand back to the waiting chat turn.

    Returns:
        ``True`` if a pending call with that id existed (it is removed),
        ``False`` if the session or the call id is unknown.
    """
    # Look the session up directly: going through get_session() would create
    # and keep an empty session for every unknown id a client sends.
    s = SESSIONS.get(session_id)
    if s is None:
        return False

    aw = s.pending_calls.pop(call_id, None)
    if aw is None:
        return False

    # The future may already be done (e.g. cancelled); setting a result on it
    # again would raise InvalidStateError.
    if not aw.fut.done():
        aw.fut.set_result(result)
    return True