import torch from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import List, Optional from transformers import AutoTokenizer, AutoModelForCausalLM import random app = FastAPI() # CORS設定 app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # モデルとトークナイザーの準備 model_id = "google/functiongemma-270m-it" tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained( model_id, device_map="auto", torch_dtype=torch.bfloat16 ) class SensorInput(BaseModel): front_distance: float speed: float steer: Optional[float] = 0.0 # ステアリング情報(-1=左危険, 0=中立, +1=右危険) ml_results: Optional[List[str]] = [] @app.post("/decide") async def decide(data: SensorInput): """ ロボットのセンサー入力に基づいてアクションを判断 距離ベースの判定ロジック(Gemmaは補助的に使用) """ distance = data.front_distance # 距離に基づいた決定ロジックと速度 if distance < 30: # 近い:危険 → ランダムに左右に回避(低速) action = random.choice(["turn_left", "turn_right"]) reason = "obstacle close - evade" speed_multiplier = 0.3 # 低速 elif distance < 80: # 中程度の距離:慎重に判定(中速) actions = ["move_forward", "turn_left", "turn_right"] action = random.choice(actions) reason = "medium distance - decide" speed_multiplier = 0.6 # 中速 else: # 遠い:前進可能(高速) action = "move_forward" reason = "clear path - advance" speed_multiplier = 0.9 # 高速 # Gemmaに確認させるプロンプト(参考情報として) prompt = ( f"Distance: {distance}, Action: {action}. Confirm or suggest better action.\n" f"Return JSON: {{" ) try: inputs = tokenizer(prompt, return_tensors="pt").to(model.device) input_len = inputs["input_ids"].shape[1] with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=15, do_sample=False, pad_token_id=tokenizer.eos_token_id ) gen_text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True) print(f"[{distance}] {reason} -> {action} (Gemma: {gen_text[:30]})") except Exception as e: print(f"Gemma error: {e}") # 確定したアクションをJSON形式で返す response = f'{{"action": "{action}", "distance": {distance}, "speed": {speed_multiplier}, "reason": "{reason}"}}' return {"data": [response]} @app.get("/") async def root(): """ヘルスチェック""" return {"status": "Gemma Robot Navigation API is running"} @app.get("/health") async def health(): """ステータス確認""" return {"status": "ok", "model": "google/functiongemma-270m-it"}