KGNINJA committed on
Commit
8b4b273
·
verified ·
1 Parent(s): c469103

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -150
app.py CHANGED
@@ -1,217 +1,212 @@
1
  import os
2
  import re
3
- from typing import Any, Dict, Tuple
 
4
 
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
7
  from transformers import AutoProcessor, AutoModelForCausalLM
 
8
 
9
- # -------------------------
10
- # 1) Function call parser
11
- # -------------------------
12
- _ESCAPE = "<escape>"
13
- _CALL_START = "<start_function_call>"
14
- _CALL_END = "<end_function_call>"
15
-
16
- def _split_top_level_commas(s: str) -> list[str]:
17
- """
18
- Split "k1:v1,k2:<escape>v,2<escape>,k3:3" by commas, but ignore commas inside <escape> ... <escape>.
19
- """
20
- parts = []
21
- buf = []
 
 
 
 
22
  i = 0
23
- in_escape = False
24
  while i < len(s):
25
- if s.startswith(_ESCAPE, i):
26
- in_escape = not in_escape
27
- buf.append(_ESCAPE)
28
- i += len(_ESCAPE)
29
  continue
30
- ch = s[i]
31
- if ch == "," and not in_escape:
32
  parts.append("".join(buf).strip())
33
  buf = []
34
  else:
35
- buf.append(ch)
36
  i += 1
37
  if buf:
38
  parts.append("".join(buf).strip())
39
- return [p for p in parts if p]
40
-
41
- def _parse_value(raw: str) -> Any:
42
- raw = raw.strip()
43
- # string wrapped with <escape> ... <escape>
44
- if raw.startswith(_ESCAPE) and raw.endswith(_ESCAPE) and len(raw) >= 2 * len(_ESCAPE):
45
- return raw[len(_ESCAPE):-len(_ESCAPE)]
46
- # bool
47
- if raw.lower() in ("true", "false"):
48
- return raw.lower() == "true"
49
- # int / float
50
  try:
51
- if "." in raw:
52
- return float(raw)
53
- return int(raw)
54
  except ValueError:
55
- # fallback: plain string
56
- return raw
57
-
58
- def parse_function_call(text: str) -> Tuple[Dict[str, Any] | None, str]:
59
- """
60
- Returns (call, raw_text).
61
- call = {"name": "...", "arguments": {...}} if a function call exists, else None.
62
- """
63
- if _CALL_START not in text:
64
- return None, text.strip()
65
-
66
- # Grab the first function call block
67
- m = re.search(rf"{re.escape(_CALL_START)}(.*?){re.escape(_CALL_END)}", text, re.DOTALL)
68
  if not m:
69
- return None, text.strip()
70
 
71
- inside = m.group(1).strip() # ex: "call:move_robot{direction:<escape>forward<escape>,meters:1}"
72
- m2 = re.match(r"call:([A-Za-z0-9_\-]+)\{(.*)\}$", inside, re.DOTALL)
73
  if not m2:
74
- return None, text.strip()
75
 
76
  name = m2.group(1)
77
- args_blob = m2.group(2).strip()
78
 
79
- arguments: Dict[str, Any] = {}
80
- if args_blob:
81
- for kv in _split_top_level_commas(args_blob):
82
- if ":" not in kv:
83
- continue
84
- k, v = kv.split(":", 1)
85
- arguments[k.strip()] = _parse_value(v)
86
 
87
- return {"name": name, "arguments": arguments}, text.strip()
88
-
89
-
90
- # -------------------------
91
- # 2) FastAPI + Model
92
- # -------------------------
93
- GEMMA_MODEL_ID = os.getenv("GEMMA_MODEL_ID", "google/functiongemma-270m-it")
94
-
95
- app = FastAPI(title="FunctionGemma FastAPI Minimal")
96
-
97
- processor = None
98
- model = None
99
 
100
- # Tool schemas (simulation actions)
 
 
101
  TOOLS = [
102
  {
103
  "type": "function",
104
  "function": {
105
- "name": "move_robot",
106
- "description": "Move the robot in the simulator.",
107
  "parameters": {
108
  "type": "object",
109
  "properties": {
110
- "direction": {"type": "string", "description": "forward|backward|left|right"},
111
- "meters": {"type": "number", "description": "distance in meters"},
112
  },
113
- "required": ["direction", "meters"],
114
- },
115
- },
116
  },
117
  {
118
  "type": "function",
119
  "function": {
120
- "name": "turn_robot",
121
- "description": "Turn the robot in place in the simulator.",
122
  "parameters": {
123
  "type": "object",
124
  "properties": {
125
- "angle_deg": {"type": "number", "description": "positive=right, negative=left"},
126
  },
127
- "required": ["angle_deg"],
128
- },
129
- },
130
  },
131
  {
132
  "type": "function",
133
  "function": {
134
- "name": "speak",
135
- "description": "Make the robot speak (subtitle in the simulator).",
136
  "parameters": {
137
  "type": "object",
138
- "properties": {
139
- "text": {"type": "string", "description": "utterance"},
140
- },
141
- "required": ["text"],
142
- },
143
- },
144
- },
145
  ]
146
 
147
- class PlanRequest(BaseModel):
148
- prompt: str
149
- # simulator state/observation from browser (optional but recommended)
150
- observation: Dict[str, Any] | None = None
151
-
152
- class PlanResponse(BaseModel):
153
- tool_call: Dict[str, Any] | None = None
154
- raw_output: str
155
- note: str | None = None
156
-
157
- @app.on_event("startup")
158
- def _load():
159
  global processor, model
160
- # If model is gated, HF_TOKEN must be set in Space Secrets.
161
- processor = AutoProcessor.from_pretrained(GEMMA_MODEL_ID, device_map="auto")
162
- model = AutoModelForCausalLM.from_pretrained(GEMMA_MODEL_ID, dtype="auto", device_map="auto")
163
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  @app.get("/health")
165
  def health():
166
- return {"ok": True, "model": GEMMA_MODEL_ID}
167
-
168
- @app.post("/plan", response_model=PlanResponse)
169
- def plan(req: PlanRequest):
170
- """
171
- Browser sends: user prompt + observation (e.g., lidar-like distances, detected labels, etc.)
172
- Server returns: one tool call (move/turn/speak) as FunctionGemma formatted output parsed into JSON.
173
- """
174
- # Essential developer instruction is required to activate function calling behavior. :contentReference[oaicite:4]{index=4}
 
 
 
 
 
 
 
 
 
175
  messages = [
176
- {"role": "developer", "content": "You are a model that can do function calling with the following functions"},
177
- {"role": "user", "content": _build_user_content(req.prompt, req.observation)},
 
178
  ]
179
 
180
  inputs = processor.apply_chat_template(
181
  messages,
182
  tools=TOOLS,
183
  add_generation_prompt=True,
184
- return_dict=True,
185
- return_tensors="pt",
186
  )
187
 
188
- out = model.generate(
189
  **inputs.to(model.device),
190
- pad_token_id=processor.eos_token_id,
191
  max_new_tokens=128,
 
192
  )
193
 
194
- # decode only newly generated tokens
195
- gen = out[0][len(inputs["input_ids"][0]):]
196
- raw = processor.decode(gen, skip_special_tokens=True)
197
-
198
- tool_call, _ = parse_function_call(raw)
199
-
200
- note = None
201
- if tool_call is None:
202
- note = "No structured function call produced. Consider enriching tool descriptions or simplifying the prompt."
203
-
204
- return PlanResponse(tool_call=tool_call, raw_output=raw, note=note)
205
-
206
- def _build_user_content(prompt: str, observation: Dict[str, Any] | None) -> str:
207
- # Keep it single-step: decide ONLY the next action.
208
- # FunctionGemma is strongest in single-turn / parallel calls, not multi-step chaining. :contentReference[oaicite:5]{index=5}
209
- obs_txt = ""
210
- if observation:
211
- obs_txt = f"\n\n[OBSERVATION]\n{observation}\n"
212
- return (
213
- f"{prompt}\n"
214
- f"{obs_txt}\n"
215
- "Decide ONLY the next simulator action. "
216
- "Use exactly one function call from the provided tools."
217
  )
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
+ from typing import Any, Dict, Optional
4
+ from contextlib import asynccontextmanager
5
 
6
  from fastapi import FastAPI
7
  from pydantic import BaseModel
8
  from transformers import AutoProcessor, AutoModelForCausalLM
9
+ import torch
10
 
11
+ # =========================
12
+ # Configuration
13
+ # =========================
14
# Hugging Face model id; override it with the GEMMA_MODEL_ID environment variable.
MODEL_ID = os.getenv("GEMMA_MODEL_ID", "google/functiongemma-270m-it")

# Module-level handles assigned by the FastAPI lifespan handler at startup.
# They stay None until the model has been loaded.
processor = None
model = None
18
+
19
+ # =========================
20
+ # Function Call Parser
21
+ # =========================
22
# Markers used in the model's textual function-call output.
ESC = "<escape>"
START = "<start_function_call>"
END = "<end_function_call>"

# Compiled once at import time. The markers are escaped so they are always
# matched literally, even if they are later changed to contain regex syntax.
_CALL_BLOCK_RE = re.compile(
    rf"{re.escape(START)}(.*?){re.escape(END)}", re.DOTALL
)
_CALL_BODY_RE = re.compile(r"call:([a-zA-Z0-9_]+)\{(.*)\}$", re.DOTALL)


def _split_commas(s: str) -> list:
    """Split ``s`` on commas that are outside ``<escape>...<escape>`` spans.

    Each ESC marker toggles "inside an escaped string"; commas seen while
    inside are kept as part of the current piece. Pieces are whitespace-
    stripped, and the ESC markers themselves are preserved so that
    ``_parse_value`` can recognise escaped strings afterwards.

    Example: ``"a:1,b:<escape>x,y<escape>"`` ->
    ``["a:1", "b:<escape>x,y<escape>"]``.
    """
    parts, buf, in_escape = [], [], False
    i = 0
    while i < len(s):
        if s.startswith(ESC, i):
            # Consume the whole marker at once so its characters are never
            # examined individually.
            in_escape = not in_escape
            buf.append(ESC)
            i += len(ESC)
            continue
        ch = s[i]
        if ch == "," and not in_escape:
            parts.append("".join(buf).strip())
            buf = []
        else:
            buf.append(ch)
        i += 1
    if buf:
        parts.append("".join(buf).strip())
    return parts


def _parse_value(v: str) -> Any:
    """Convert one raw argument value into a Python value.

    Rules, applied in order:
      1. ``<escape>text<escape>`` -> ``"text"``
      2. ``true`` / ``false`` (any case) -> ``bool``
      3. integer literal -> ``int``
      4. finite float literal (including exponent form such as ``1e2``)
         -> ``float``
      5. anything else -> the stripped string, unchanged
    """
    v = v.strip()
    # The length guard ensures there really are two markers; without it a
    # lone "<escape>" would match as both opening and closing delimiter.
    if len(v) >= 2 * len(ESC) and v.startswith(ESC) and v.endswith(ESC):
        return v[len(ESC):-len(ESC)]

    lowered = v.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    try:
        return int(v)
    except ValueError:
        pass
    try:
        number = float(v)
    except ValueError:
        return v
    # "nan" / "inf" parse as floats but are not valid JSON numbers; keep the
    # original text instead so the value stays serialisable.
    if number != number or number in (float("inf"), float("-inf")):
        return v
    return number


def parse_function_call(text: str) -> Optional[Dict[str, Any]]:
    """Extract the first function call from model output.

    Expects the form
    ``<start_function_call>call:NAME{key:value,...}<end_function_call>``.

    Args:
        text: Decoded model output.

    Returns:
        ``{"name": NAME, "arguments": {...}}`` for the first well-formed call
        block, or ``None`` when there is no START marker, no matching END
        marker, or the block body is not ``call:NAME{...}``.
    """
    if START not in text:  # cheap early exit before running the regex
        return None

    block = _CALL_BLOCK_RE.search(text)
    if not block:
        return None

    call = _CALL_BODY_RE.match(block.group(1).strip())
    if not call:
        return None

    name = call.group(1)
    args_raw = call.group(2).strip()

    args = {}
    for kv in _split_commas(args_raw):
        # Pieces without a key/value separator (including empty ones) are
        # ignored.
        if ":" in kv:
            k, v = kv.split(":", 1)
            args[k.strip()] = _parse_value(v)

    return {"name": name, "arguments": args}
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ # =========================
84
+ # Tools (Robot Actions)
85
+ # =========================
86
# Tool schemas passed to the chat template as `tools=`; each entry declares
# one robot action the model may call.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "move",
            "description": "Move forward or backward",
            "parameters": {
                "type": "object",
                "properties": {
                    "direction": {"type": "string"},
                    "speed": {"type": "number"},
                },
                "required": ["direction", "speed"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "turn",
            "description": "Rotate left or right",
            "parameters": {
                "type": "object",
                "properties": {
                    "angle": {"type": "number"},
                },
                "required": ["angle"],
            },
        },
    },
    {
        # Takes no arguments.
        "type": "function",
        "function": {
            "name": "pause",
            "description": "Stop and observe",
            "parameters": {
                "type": "object",
                "properties": {},
            },
        },
    },
]
128
 
129
+ # =========================
130
+ # FastAPI Lifespan
131
+ # =========================
132
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the processor and model once, before the app starts serving.

    Assigns the module-level ``processor`` and ``model`` globals that the
    request handlers read. Code before ``yield`` runs at startup; code after
    it runs at shutdown.
    """
    global processor, model
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        device_map="auto",
    )
    yield
    # Shutdown handling (not needed for now).


app = FastAPI(
    title="FunctionGemma Robot Brain",
    lifespan=lifespan,
)
148
+
149
+ # =========================
150
+ # API Schema
151
+ # =========================
152
class DecideRequest(BaseModel):
    """Request body for ``POST /decide``."""

    # Current observation; it is formatted verbatim into the user prompt.
    observation: Dict[str, Any]
    # Persona label appended to the system prompt as "Persona: <value>".
    persona: Optional[str] = "curious"
155
+
156
class DecideResponse(BaseModel):
    """Response body for ``POST /decide``."""

    # Parsed call as {"name": ..., "arguments": {...}}, or None when the model
    # output contained no well-formed function call. The explicit default
    # keeps the field optional as well as nullable.
    action: Optional[Dict[str, Any]] = None
    # Decoded model output, returned unmodified for debugging.
    raw: str
159
+
160
+ # =========================
161
+ # Endpoints
162
+ # =========================
163
@app.get("/health")
def health():
    """Liveness probe: report a fixed status and the configured model id."""
    return {"status": "ok", "model": MODEL_ID}
166
+
167
@app.post("/decide", response_model=DecideResponse)
def decide(req: DecideRequest):
    """Ask the model for the robot's next action given an observation.

    Builds a developer/system/user conversation, renders it with the tool
    schemas in ``TOOLS``, generates up to 128 new tokens, and parses the
    first function call out of the newly generated text.

    Args:
        req: Observation dict and persona label.

    Returns:
        A dict with ``action`` (the parsed call, or ``None`` if none could be
        parsed) and ``raw`` (the decoded generation).
    """
    system = (
        "You are a small exploration robot.\n"
        "You must choose exactly ONE action function.\n"
        "You are curious but avoid real danger.\n"
        f"Persona: {req.persona}"
    )

    user = (
        "\nObservation:\n"
        f"{req.observation}\n"
        "\nChoose the next action.\n"
    )

    # NOTE(review): the previous revision used the longer developer prompt
    # "You are a model that can do function calling with the following
    # functions" and described it as essential for activating function
    # calling. Confirm against the model card that this shortened prompt and
    # the extra "system" turn are handled as intended by the chat template.
    messages = [
        {"role": "developer", "content": "You can call functions."},
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]

    # return_dict=True is required here: the result is unpacked with ** into
    # generate() and indexed by "input_ids" below, which only works on a
    # mapping of tensors.
    inputs = processor.apply_chat_template(
        messages,
        tools=TOOLS,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Inference only: skip autograd bookkeeping during generation.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            pad_token_id=processor.eos_token_id,
        )

    # generate() returns prompt + continuation; decode only the continuation.
    prompt_len = inputs["input_ids"].shape[-1]
    decoded = processor.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
    )

    action = parse_function_call(decoded)

    return {
        "action": action,
        "raw": decoded,
    }