Spaces:
Sleeping
Sleeping
| import torch | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import random | |
| app = FastAPI() | |
| # CORS設定 | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # モデルとトークナイザーの準備 | |
| model_id = "google/functiongemma-270m-it" | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_id, | |
| device_map="auto", | |
| torch_dtype=torch.bfloat16 | |
| ) | |
| class SensorInput(BaseModel): | |
| front_distance: float | |
| speed: float | |
| steer: Optional[float] = 0.0 # ステアリング情報(-1=左危険, 0=中立, +1=右危険) | |
| ml_results: Optional[List[str]] = [] | |
| async def decide(data: SensorInput): | |
| """ | |
| ロボットのセンサー入力に基づいてアクションを判断 | |
| 距離ベースの判定ロジック(Gemmaは補助的に使用) | |
| """ | |
| distance = data.front_distance | |
| # 距離に基づいた決定ロジックと速度 | |
| if distance < 30: | |
| # 近い:危険 → ランダムに左右に回避(低速) | |
| action = random.choice(["turn_left", "turn_right"]) | |
| reason = "obstacle close - evade" | |
| speed_multiplier = 0.3 # 低速 | |
| elif distance < 80: | |
| # 中程度の距離:慎重に判定(中速) | |
| actions = ["move_forward", "turn_left", "turn_right"] | |
| action = random.choice(actions) | |
| reason = "medium distance - decide" | |
| speed_multiplier = 0.6 # 中速 | |
| else: | |
| # 遠い:前進可能(高速) | |
| action = "move_forward" | |
| reason = "clear path - advance" | |
| speed_multiplier = 0.9 # 高速 | |
| # Gemmaに確認させるプロンプト(参考情報として) | |
| prompt = ( | |
| f"Distance: {distance}, Action: {action}. Confirm or suggest better action.\n" | |
| f"Return JSON: {{" | |
| ) | |
| try: | |
| inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
| input_len = inputs["input_ids"].shape[1] | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=15, | |
| do_sample=False, | |
| pad_token_id=tokenizer.eos_token_id | |
| ) | |
| gen_text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True) | |
| print(f"[{distance}] {reason} -> {action} (Gemma: {gen_text[:30]})") | |
| except Exception as e: | |
| print(f"Gemma error: {e}") | |
| # 確定したアクションをJSON形式で返す | |
| response = f'{{"action": "{action}", "distance": {distance}, "speed": {speed_multiplier}, "reason": "{reason}"}}' | |
| return {"data": [response]} | |
| async def root(): | |
| """ヘルスチェック""" | |
| return {"status": "Gemma Robot Navigation API is running"} | |
| async def health(): | |
| """ステータス確認""" | |
| return {"status": "ok", "model": "google/functiongemma-270m-it"} |