import os
import re
from typing import Any, Dict, Optional
from contextlib import asynccontextmanager

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForCausalLM
import torch

# =========================
# Configuration
# =========================
MODEL_ID = os.getenv("GEMMA_MODEL_ID", "google/functiongemma-270m-it")

processor = None
model = None
# =========================
# Function Call Parser
# =========================
ESC = "<escape>"
START = "<start_function_call>"
END = "<end_function_call>"


def _split_commas(s: str):
    # Split the argument list on commas, ignoring commas that fall inside
    # <escape>...<escape> spans (escaped string values may contain commas).
    parts, buf, esc = [], [], False
    i = 0
    while i < len(s):
        if s.startswith(ESC, i):
            esc = not esc
            buf.append(ESC)
            i += len(ESC)
            continue
        if s[i] == "," and not esc:
            parts.append("".join(buf).strip())
            buf = []
        else:
            buf.append(s[i])
        i += 1
    if buf:
        parts.append("".join(buf).strip())
    return parts


def _parse_value(v: str):
    # Convert a raw argument value into a Python type:
    # escaped values stay strings; otherwise try bool, then float/int.
    v = v.strip()
    if v.startswith(ESC) and v.endswith(ESC):
        return v[len(ESC):-len(ESC)]
    if v.lower() in ("true", "false"):
        return v.lower() == "true"
    try:
        if "." in v:
            return float(v)
        return int(v)
    except ValueError:
        return v
def parse_function_call(text: str):
    # Extract the first call:<name>{...} block between the start/end markers
    # and return it as {"name": ..., "arguments": {...}}, or None if absent.
    if START not in text:
        return None
    m = re.search(rf"{START}(.*?){END}", text, re.DOTALL)
    if not m:
        return None
    body = m.group(1).strip()
    m2 = re.match(r"call:([a-zA-Z0-9_]+)\{(.*)\}$", body, re.DOTALL)
    if not m2:
        return None
    name = m2.group(1)
    args_raw = m2.group(2).strip()
    args = {}
    if args_raw:
        for kv in _split_commas(args_raw):
            if ":" in kv:
                k, v = kv.split(":", 1)
                args[k.strip()] = _parse_value(v)
    return {"name": name, "arguments": args}
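# Quick sanity check for the parser. The sample string below is hand-written
# in the call format the parser expects; it is an illustration, not a captured
# model generation.
_SAMPLE_CALL = (
    "<start_function_call>"
    "call:move{direction:<escape>forward<escape>,speed:0.5}"
    "<end_function_call>"
)
assert parse_function_call(_SAMPLE_CALL) == {
    "name": "move",
    "arguments": {"direction": "forward", "speed": 0.5},
}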
# =========================
# Tools (Robot Actions)
# =========================
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "move",
            "description": "Move forward or backward",
            "parameters": {
                "type": "object",
                "properties": {
                    "direction": {"type": "string"},
                    "speed": {"type": "number"}
                },
                "required": ["direction", "speed"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "turn",
            "description": "Rotate left or right",
            "parameters": {
                "type": "object",
                "properties": {
                    "angle": {"type": "number"}
                },
                "required": ["angle"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "pause",
            "description": "Stop and observe",
            "parameters": {
                "type": "object",
                "properties": {}
            }
        }
    }
]
# =========================
# FastAPI Lifespan
# =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once at startup and keep it in module-level globals.
    global processor, model
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        device_map="auto"
    )
    yield
    # Shutdown handling (not needed here)


app = FastAPI(
    title="FunctionGemma Robot Brain",
    lifespan=lifespan
)
# =========================
# API Schema
# =========================
class DecideRequest(BaseModel):
    observation: Dict[str, Any]
    persona: Optional[str] = "curious"


class DecideResponse(BaseModel):
    action: Optional[Dict[str, Any]]
    raw: str
# =========================
# Endpoints
# =========================
@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_ID}


@app.post("/decide", response_model=DecideResponse)
def decide(req: DecideRequest):
    system = (
        "You are a small exploration robot.\n"
        "You must choose exactly ONE action function.\n"
        "You are curious but avoid real danger.\n"
        f"Persona: {req.persona}"
    )
    user = f"""
Observation:
{req.observation}
Choose the next action.
"""
    messages = [
        {"role": "developer", "content": "You can call functions."},
        {"role": "system", "content": system},
        {"role": "user", "content": user}
    ]
    inputs = processor.apply_chat_template(
        messages,
        tools=TOOLS,
        add_generation_prompt=True,
        return_dict=True,  # return a dict so it can be unpacked into generate()
        return_tensors="pt"
    )
    outputs = model.generate(
        **inputs.to(model.device),
        max_new_tokens=128,
        pad_token_id=processor.eos_token_id
    )
    # Decode only the newly generated tokens, then parse the function call.
    decoded = processor.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )
    action = parse_function_call(decoded)
    return {
        "action": action,
        "raw": decoded
    }
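# Convenience entry point, as a minimal sketch: it assumes uvicorn is
# installed, and the host/port are arbitrary defaults.
if __name__ == "__main__":
    # Once the server is up, the /decide route can be exercised with e.g.:
    #   curl -X POST http://localhost:8000/decide \
    #        -H "Content-Type: application/json" \
    #        -d '{"observation": {"front_distance_cm": 40, "light": "dim"}, "persona": "curious"}'
    # (the observation keys are made up for illustration; the model only sees
    # them as text inside the prompt)
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)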