File size: 8,362 Bytes
9201e22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
import io
import json
import traceback
from uuid import uuid4

# --- FastAPI & Web server imports ---
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Security, Depends
from fastapi.security import APIKeyHeader
from fastapi.responses import JSONResponse

# --- ML & Data processing imports ---
import torch
from PIL import Image
import numpy as np

# --- استيراد من ملفات المشروع الخاصة بك ---
try:
    from model import build_interfuser_model
    from logic import (
        transform, InterfuserController, ControllerConfig,
        Tracker, WAYPOINT_SCALE_FACTOR
    )
except ImportError as e:
    print(f"Error importing from project files: {e}")
    print("Please ensure model.py and logic.py are in the same directory.")
    exit()

# ==============================================================================
#           1. إعدادات الخادم، النموذج، والأمان
# ==============================================================================
app = FastAPI(
    title="Interfuser Driving API (Secure & Stateful)",
    description="An API for driving commands with session management and API key authentication.",
    version="2.0.0"
)

# --- تحميل النموذج (يتم مرة واحدة عند بدء التشغيل) ---
MODEL_NAME = "interfuser_baseline"
WEIGHTS_PATH = os.path.join("weights", f"{MODEL_NAME}.pth")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_CONFIG = {
    "rgb_backbone_name": "r50", "embed_dim": 256, "direct_concat": True,
    'get': lambda key, default: MODEL_CONFIG.get(key, default)
}

print(f"Loading model '{MODEL_NAME}' on device '{DEVICE}'...")
if not os.path.exists(WEIGHTS_PATH):
    raise FileNotFoundError(f"Weights file not found at: {WEIGHTS_PATH}")

model = build_interfuser_model(MODEL_CONFIG)
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()
print("✅ Model loaded successfully!")

# --- إدارة الجلسات والأمان ---
SESSIONS = {}  # قاموس لتخزين حالات الجلسات: {session_id: Tracker}
API_KEY_NAME = "X-API-KEY"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

# في تطبيق حقيقي، يجب أن تكون هذه المفاتيح في متغيرات بيئة أو خدمة إدارة أسرار
VALID_API_KEYS = {
    "your-super-secret-key-for-flutter-app", # مفتاح لتطبيق فلاتر
    "a-different-key-for-testing"          # مفتاح آخر للاختبار
}

async def get_api_key(api_key: str = Security(api_key_header)):
    """تبعية للتحقق من أن مفتاح الـ API صالح."""
    if api_key in VALID_API_KEYS:
        return api_key
    else:
        raise HTTPException(
            status_code=403, detail="Could not validate credentials or missing API Key"
        )


# ==============================================================================
#           2. تعريف نقاط نهاية الـ API (Endpoints)
# ==============================================================================

# --- حماية جميع نقاط النهاية باستخدام التبعية ---
# أي طلب لأي نقطة نهاية أدناه يجب أن يجتاز get_api_key أولاً
app.dependency_overrides[get_api_key] = get_api_key


@app.post("/sessions/create", summary="Create a new tracking session")
async def create_session(api_key: str = Depends(get_api_key)):
    """

    ينشئ جلسة تتبع جديدة ويعيد معرفًا فريدًا لها.

    هذه هي الخطوة الأولى قبل إرسال بيانات الإطارات.

    """
    session_id = str(uuid4())
    SESSIONS[session_id] = {"tracker": Tracker(), "frame_count": 0}
    print(f"New session created: {session_id}")
    return JSONResponse(content={"session_id": session_id})


@app.post("/predict/{session_id}", summary="Run a single frame prediction for a session")
async def predict(

    session_id: str,

    rgb_image: UploadFile = File(..., description="Front-facing RGB camera image."),

    measurements_json: UploadFile = File(..., description="JSON file with vehicle measurements."),

    api_key: str = Depends(get_api_key)

):
    """

    يشغل التنبؤ لإطار واحد ضمن جلسة موجودة.

    يستخدم الـ Tracker المستمر الخاص بالجلسة لتتبع الأجسام عبر الزمن.

    """
    if session_id not in SESSIONS:
        raise HTTPException(status_code=404, detail="Session not found. Please create a new session.")

    session_data = SESSIONS[session_id]
    tracker = session_data["tracker"]
    session_data["frame_count"] += 1
    current_frame = session_data["frame_count"]

    try:
        # --- قراءة ومعالجة المدخلات ---
        image_bytes = await rgb_image.read()
        measurements_string = await measurements_json.read()
        rgb_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        m_dict = json.loads(measurements_string)

        # --- تجهيز التنسورات للنموذج ---
        front_tensor = transform(rgb_pil).unsqueeze(0).to(DEVICE)
        dummy_tensor = torch.zeros_like(front_tensor)
        measurements_tensor = torch.tensor([[
            m_dict.get(k, 0.0) for k in ['x', 'y', 'theta', 'speed', 'steer', 'throttle', 'brake', 'command', 'is_junction', 'should_brake']
        ]], dtype=torch.float32).to(DEVICE)
        target_point_tensor = torch.tensor([[0.0, 100.0]], dtype=torch.float32).to(DEVICE)

        inputs = {
            'rgb': front_tensor, 'rgb_left': dummy_tensor, 'rgb_right': dummy_tensor,
            'rgb_center': dummy_tensor, 'lidar': dummy_tensor,
            'measurements': measurements_tensor, 'target_point': target_point_tensor
        }

        # --- تشغيل النموذج والتحكم ---
        with torch.no_grad():
            outputs = model(inputs)
            traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs

        traffic_np = traffic[0].detach().cpu().numpy().reshape(20, 20, -1)
        waypoints_np = waypoints[0].detach().cpu().numpy() * WAYPOINT_SCALE_FACTOR
        
        pos = [m_dict.get('x', 0.0), m_dict.get('y', 0.0)]
        theta = m_dict.get('theta', 0.0)

        # استخدام Tracker المستمر الخاص بالجلسة
        updated_traffic = tracker.update_and_predict(traffic_np.copy(), pos, theta, current_frame)
        
        controller = InterfuserController(ControllerConfig())
        steer, throttle, brake, _ = controller.run_step(
            m_dict.get('speed', 5.0), waypoints_np, is_junction.sigmoid()[0,1].item(),
            traffic_light.sigmoid()[0,0].item(), stop_sign.sigmoid()[0,1].item(), updated_traffic
        )

        # --- بناء وإرجاع الاستجابة ---
        control_commands = {"steer": float(steer), "throttle": float(throttle), "brake": bool(brake)}
        return JSONResponse(content={"status": "success", "control_commands": control_commands})

    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"An internal error occurred: {str(e)}")


@app.delete("/sessions/{session_id}", summary="Delete a tracking session")
async def delete_session(session_id: str, api_key: str = Depends(get_api_key)):
    """

    يحذف جلسة تتبع لتحرير الموارد على الخادم.

    """
    if session_id in SESSIONS:
        del SESSIONS[session_id]
        print(f"Session deleted: {session_id}")
        return JSONResponse(content={"message": "Session deleted successfully."})
    raise HTTPException(status_code=404, detail="Session not found.")


# ==============================================================================
#           3. نقطة بداية تشغيل الخادم
# ==============================================================================
if __name__ == "__main__":
    print("--- Interfuser API Server ---")
    print("API documentation will be available at http://127.0.0.1:8000/docs")
    uvicorn.run(app, host="0.0.0.0", port=8000)