KGNINJA's picture
Update app.py
6820ebe verified
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
app = FastAPI()
# CORS設定
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# モデルとトークナイザーの準備
model_id = "google/functiongemma-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16
)
class SensorInput(BaseModel):
front_distance: float
speed: float
steer: Optional[float] = 0.0 # ステアリング情報(-1=左危険, 0=中立, +1=右危険)
ml_results: Optional[List[str]] = []
@app.post("/decide")
async def decide(data: SensorInput):
"""
ロボットのセンサー入力に基づいてアクションを判断
距離ベースの判定ロジック(Gemmaは補助的に使用)
"""
distance = data.front_distance
# 距離に基づいた決定ロジックと速度
if distance < 30:
# 近い:危険 → ランダムに左右に回避(低速)
action = random.choice(["turn_left", "turn_right"])
reason = "obstacle close - evade"
speed_multiplier = 0.3 # 低速
elif distance < 80:
# 中程度の距離:慎重に判定(中速)
actions = ["move_forward", "turn_left", "turn_right"]
action = random.choice(actions)
reason = "medium distance - decide"
speed_multiplier = 0.6 # 中速
else:
# 遠い:前進可能(高速)
action = "move_forward"
reason = "clear path - advance"
speed_multiplier = 0.9 # 高速
# Gemmaに確認させるプロンプト(参考情報として)
prompt = (
f"Distance: {distance}, Action: {action}. Confirm or suggest better action.\n"
f"Return JSON: {{"
)
try:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = inputs["input_ids"].shape[1]
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=15,
do_sample=False,
pad_token_id=tokenizer.eos_token_id
)
gen_text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
print(f"[{distance}] {reason} -> {action} (Gemma: {gen_text[:30]})")
except Exception as e:
print(f"Gemma error: {e}")
# 確定したアクションをJSON形式で返す
response = f'{{"action": "{action}", "distance": {distance}, "speed": {speed_multiplier}, "reason": "{reason}"}}'
return {"data": [response]}
@app.get("/")
async def root():
"""ヘルスチェック"""
return {"status": "Gemma Robot Navigation API is running"}
@app.get("/health")
async def health():
"""ステータス確認"""
return {"status": "ok", "model": "google/functiongemma-270m-it"}