Spaces:
Sleeping
Sleeping
File size: 8,362 Bytes
9201e22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import os
import io
import json
import traceback
from uuid import uuid4
# --- FastAPI & Web server imports ---
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Security, Depends
from fastapi.security import APIKeyHeader
from fastapi.responses import JSONResponse
# --- ML & Data processing imports ---
import torch
from PIL import Image
import numpy as np
# --- استيراد من ملفات المشروع الخاصة بك ---
try:
from model import build_interfuser_model
from logic import (
transform, InterfuserController, ControllerConfig,
Tracker, WAYPOINT_SCALE_FACTOR
)
except ImportError as e:
print(f"Error importing from project files: {e}")
print("Please ensure model.py and logic.py are in the same directory.")
exit()
# ==============================================================================
# 1. إعدادات الخادم، النموذج، والأمان
# ==============================================================================
app = FastAPI(
title="Interfuser Driving API (Secure & Stateful)",
description="An API for driving commands with session management and API key authentication.",
version="2.0.0"
)
# --- تحميل النموذج (يتم مرة واحدة عند بدء التشغيل) ---
MODEL_NAME = "interfuser_baseline"
WEIGHTS_PATH = os.path.join("weights", f"{MODEL_NAME}.pth")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_CONFIG = {
"rgb_backbone_name": "r50", "embed_dim": 256, "direct_concat": True,
'get': lambda key, default: MODEL_CONFIG.get(key, default)
}
print(f"Loading model '{MODEL_NAME}' on device '{DEVICE}'...")
if not os.path.exists(WEIGHTS_PATH):
raise FileNotFoundError(f"Weights file not found at: {WEIGHTS_PATH}")
model = build_interfuser_model(MODEL_CONFIG)
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()
print("✅ Model loaded successfully!")
# --- إدارة الجلسات والأمان ---
SESSIONS = {} # قاموس لتخزين حالات الجلسات: {session_id: Tracker}
API_KEY_NAME = "X-API-KEY"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
# في تطبيق حقيقي، يجب أن تكون هذه المفاتيح في متغيرات بيئة أو خدمة إدارة أسرار
VALID_API_KEYS = {
"your-super-secret-key-for-flutter-app", # مفتاح لتطبيق فلاتر
"a-different-key-for-testing" # مفتاح آخر للاختبار
}
async def get_api_key(api_key: str = Security(api_key_header)):
"""تبعية للتحقق من أن مفتاح الـ API صالح."""
if api_key in VALID_API_KEYS:
return api_key
else:
raise HTTPException(
status_code=403, detail="Could not validate credentials or missing API Key"
)
# ==============================================================================
# 2. تعريف نقاط نهاية الـ API (Endpoints)
# ==============================================================================
# --- حماية جميع نقاط النهاية باستخدام التبعية ---
# أي طلب لأي نقطة نهاية أدناه يجب أن يجتاز get_api_key أولاً
app.dependency_overrides[get_api_key] = get_api_key
@app.post("/sessions/create", summary="Create a new tracking session")
async def create_session(api_key: str = Depends(get_api_key)):
"""
ينشئ جلسة تتبع جديدة ويعيد معرفًا فريدًا لها.
هذه هي الخطوة الأولى قبل إرسال بيانات الإطارات.
"""
session_id = str(uuid4())
SESSIONS[session_id] = {"tracker": Tracker(), "frame_count": 0}
print(f"New session created: {session_id}")
return JSONResponse(content={"session_id": session_id})
@app.post("/predict/{session_id}", summary="Run a single frame prediction for a session")
async def predict(
session_id: str,
rgb_image: UploadFile = File(..., description="Front-facing RGB camera image."),
measurements_json: UploadFile = File(..., description="JSON file with vehicle measurements."),
api_key: str = Depends(get_api_key)
):
"""
يشغل التنبؤ لإطار واحد ضمن جلسة موجودة.
يستخدم الـ Tracker المستمر الخاص بالجلسة لتتبع الأجسام عبر الزمن.
"""
if session_id not in SESSIONS:
raise HTTPException(status_code=404, detail="Session not found. Please create a new session.")
session_data = SESSIONS[session_id]
tracker = session_data["tracker"]
session_data["frame_count"] += 1
current_frame = session_data["frame_count"]
try:
# --- قراءة ومعالجة المدخلات ---
image_bytes = await rgb_image.read()
measurements_string = await measurements_json.read()
rgb_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
m_dict = json.loads(measurements_string)
# --- تجهيز التنسورات للنموذج ---
front_tensor = transform(rgb_pil).unsqueeze(0).to(DEVICE)
dummy_tensor = torch.zeros_like(front_tensor)
measurements_tensor = torch.tensor([[
m_dict.get(k, 0.0) for k in ['x', 'y', 'theta', 'speed', 'steer', 'throttle', 'brake', 'command', 'is_junction', 'should_brake']
]], dtype=torch.float32).to(DEVICE)
target_point_tensor = torch.tensor([[0.0, 100.0]], dtype=torch.float32).to(DEVICE)
inputs = {
'rgb': front_tensor, 'rgb_left': dummy_tensor, 'rgb_right': dummy_tensor,
'rgb_center': dummy_tensor, 'lidar': dummy_tensor,
'measurements': measurements_tensor, 'target_point': target_point_tensor
}
# --- تشغيل النموذج والتحكم ---
with torch.no_grad():
outputs = model(inputs)
traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
traffic_np = traffic[0].detach().cpu().numpy().reshape(20, 20, -1)
waypoints_np = waypoints[0].detach().cpu().numpy() * WAYPOINT_SCALE_FACTOR
pos = [m_dict.get('x', 0.0), m_dict.get('y', 0.0)]
theta = m_dict.get('theta', 0.0)
# استخدام Tracker المستمر الخاص بالجلسة
updated_traffic = tracker.update_and_predict(traffic_np.copy(), pos, theta, current_frame)
controller = InterfuserController(ControllerConfig())
steer, throttle, brake, _ = controller.run_step(
m_dict.get('speed', 5.0), waypoints_np, is_junction.sigmoid()[0,1].item(),
traffic_light.sigmoid()[0,0].item(), stop_sign.sigmoid()[0,1].item(), updated_traffic
)
# --- بناء وإرجاع الاستجابة ---
control_commands = {"steer": float(steer), "throttle": float(throttle), "brake": bool(brake)}
return JSONResponse(content={"status": "success", "control_commands": control_commands})
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=f"An internal error occurred: {str(e)}")
@app.delete("/sessions/{session_id}", summary="Delete a tracking session")
async def delete_session(session_id: str, api_key: str = Depends(get_api_key)):
"""
يحذف جلسة تتبع لتحرير الموارد على الخادم.
"""
if session_id in SESSIONS:
del SESSIONS[session_id]
print(f"Session deleted: {session_id}")
return JSONResponse(content={"message": "Session deleted successfully."})
raise HTTPException(status_code=404, detail="Session not found.")
# ==============================================================================
# 3. نقطة بداية تشغيل الخادم
# ==============================================================================
if __name__ == "__main__":
print("--- Interfuser API Server ---")
print("API documentation will be available at http://127.0.0.1:8000/docs")
uvicorn.run(app, host="0.0.0.0", port=8000) |