KGNINJA's picture
Update app.py
8b4b273 verified
import math
import os
import re
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
# =========================
# Configuration
# =========================
# Model repository id; the GEMMA_MODEL_ID environment variable overrides the default.
MODEL_ID = os.environ.get("GEMMA_MODEL_ID", "google/functiongemma-270m-it")
# Populated by the FastAPI lifespan handler at startup; None until then.
processor = model = None
# =========================
# Function Call Parser
# =========================
# Literal markers recognised by the parser below: ESC brackets a string
# argument value on both sides; START/END enclose one whole function call.
ESC, START, END = "<escape>", "<start_function_call>", "<end_function_call>"
def _split_commas(s: str):
    """Split an argument string on commas that are not inside ESC spans.

    Each ESC marker flips an "escaped" flag; commas seen while the flag is set
    are kept as ordinary characters. The ESC markers themselves are kept in the
    output pieces so callers can still recognise escaped values. Every piece is
    whitespace-stripped. The final segment is emitted only if it is non-empty
    before stripping, so ``"a,"`` -> ``["a"]`` and ``""`` -> ``[]``.
    Only ESC spans protect commas; braces or brackets do not.
    """
    pieces = []
    current = []
    in_escape = False
    pos = 0
    size = len(s)
    marker_len = len(ESC)
    while pos < size:
        if s.startswith(ESC, pos):
            # Enter/leave an escaped span; keep the marker in the piece.
            in_escape = not in_escape
            current.append(ESC)
            pos += marker_len
        elif s[pos] == "," and not in_escape:
            pieces.append("".join(current).strip())
            current = []
            pos += 1
        else:
            current.append(s[pos])
            pos += 1
    if current:
        pieces.append("".join(current).strip())
    return pieces
def _parse_value(v: str):
    """Convert one raw argument value from a function-call body to a Python value.

    After stripping surrounding whitespace:

    - ``<escape>TEXT<escape>`` -> ``TEXT`` (inner text left untouched).
    - ``true`` / ``false`` in any letter case -> ``bool``.
    - anything ``int()`` accepts -> ``int``.
    - otherwise anything ``float()`` accepts whose value is finite -> ``float``.
      This includes exponent forms such as ``1e3``.
    - anything else (including ``nan`` / ``inf``) -> the stripped text.
    """
    v = v.strip()
    if v.startswith(ESC) and v.endswith(ESC):
        return v[len(ESC):-len(ESC)]
    lowered = v.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return int(v)
    except ValueError:
        pass
    # Trying float() unconditionally (rather than only when "." is present)
    # also handles exponent literals such as "1e3".
    try:
        number = float(v)
    except ValueError:
        return v
    # NaN/inf cannot be serialised by the JSON response, so keep them as text.
    return number if math.isfinite(number) else v
def parse_function_call(text: str):
    """Extract the first function call from model output.

    Looks for ``<start_function_call>call:NAME{k1:v1,k2:v2}<end_function_call>``
    anywhere in *text* and uses the first such span. The argument list is
    split with ``_split_commas``, and each value is converted with
    ``_parse_value``. Pieces that contain no ``:`` are ignored.

    Returns:
        ``{"name": NAME, "arguments": {...}}``, or ``None`` in either of these
        cases:

        - *text* has no complete START...END span.
        - The span's body is not of the form ``call:NAME{...}``.
    """
    # re.escape so the marker constants are always matched literally.
    span = re.search(re.escape(START) + r"(.*?)" + re.escape(END), text, re.DOTALL)
    if span is None:
        return None
    call = re.match(r"call:([a-zA-Z0-9_]+)\{(.*)\}$", span.group(1).strip(), re.DOTALL)
    if call is None:
        return None
    name = call.group(1)
    arguments = {}
    # _split_commas("") returns [], so an empty argument list needs no guard.
    for piece in _split_commas(call.group(2).strip()):
        key, sep, raw_value = piece.partition(":")
        if sep:
            arguments[key.strip()] = _parse_value(raw_value)
    return {"name": name, "arguments": arguments}
# =========================
# Tools (Robot Actions)
# =========================
def _function_tool(name, description, properties, required=None):
    """Build one ``{"type": "function", ...}`` tool-schema entry for TOOLS.

    Key order matches the hand-written schema, because the chat template
    renders these dicts into the prompt. ``required`` is left out of the
    parameters object entirely when not given.
    """
    parameters = {"type": "object", "properties": properties}
    if required is not None:
        parameters["required"] = required
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }


# Action schemas offered to the model on every /decide call.
# NOTE(review): no allowed values are declared for `direction`, and no unit or
# sign convention for `angle`/`speed`; the model only sees the short descriptions.
TOOLS = [
    _function_tool(
        "move",
        "Move forward or backward",
        {"direction": {"type": "string"}, "speed": {"type": "number"}},
        required=["direction", "speed"],
    ),
    _function_tool(
        "turn",
        "Rotate left or right",
        {"angle": {"type": "number"}},
        required=["angle"],
    ),
    _function_tool("pause", "Stop and observe", {}),
]
# =========================
# FastAPI Lifespan
# =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the processor and model once at application startup.

    Sets the module-level ``processor`` and ``model`` globals read by the
    request handlers. The model is loaded in float32 with
    ``device_map="auto"``.
    """
    global processor, model
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    load_options = {"torch_dtype": torch.float32, "device_map": "auto"}
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_options)
    yield
    # Shutdown: nothing to clean up for now.
# The ASGI application; `lifespan` loads the model before requests are served.
app = FastAPI(lifespan=lifespan, title="FunctionGemma Robot Brain")
# =========================
# API Schema
# =========================
class DecideRequest(BaseModel):
    """Request body for ``POST /decide``."""

    # Arbitrary key/value observation from the client; it is interpolated into
    # the user prompt as-is (no schema is enforced beyond being a dict).
    observation: Dict[str, Any]
    # Free-text persona placed in the system prompt; the type also allows null.
    persona: Optional[str] = "curious"
class DecideResponse(BaseModel):
    """Response body for ``POST /decide``."""

    # Parsed call as {"name": str, "arguments": dict}, or None when the model
    # output contained no parseable function call.
    action: Optional[Dict[str, Any]]
    # Decoded model completion text that `action` was parsed from.
    raw: str
# =========================
# Endpoints
# =========================
@app.get("/health")
def health():
    """Report that the service responds, plus the configured model id.

    This does not check whether the model has actually been loaded.
    """
    return dict(status="ok", model=MODEL_ID)
@app.post("/decide", response_model=DecideResponse)
def decide(req: DecideRequest):
    """Ask the model to choose the robot's next action for one observation.

    Builds a developer/system/user chat prompt from *req*, attaches the tool
    schemas in ``TOOLS``, and generates up to 128 new tokens. It then parses
    the first function call found in the completion.

    Returns:
        ``{"action": ..., "raw": ...}``:

        - ``action`` is the parsed call (``{"name", "arguments"}``), or
          ``None`` if the completion has no parseable call.
        - ``raw`` is the decoded completion text.
    """
    # An explicit JSON null would otherwise render as "Persona: None".
    persona = req.persona if req.persona is not None else "curious"
    system = (
        "You are a small exploration robot.\n"
        "You must choose exactly ONE action function.\n"
        "You are curious but avoid real danger.\n"
        f"Persona: {persona}"
    )
    user = (
        "\n"
        "Observation:\n"
        f"{req.observation}\n"
        "Choose the next action.\n"
    )
    messages = [
        {"role": "developer", "content": "You can call functions."},
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    # tokenize=True / return_dict=True guarantee a BatchEncoding with input_ids
    # (and attention_mask). The defaults of these flags differ between
    # transformers versions and between tokenizer and processor classes.
    # Without them the call can return a bare tensor or a string, which breaks
    # both `**inputs` and `inputs["input_ids"]` below.
    inputs = processor.apply_chat_template(
        messages,
        tools=TOOLS,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        pad_token_id=processor.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    decoded = processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return {"action": parse_function_call(decoded), "raw": decoded}