# Basee_model / app.py
# Source: Hugging Face Space upload (mohammed-aljafry), commit 2bc171c.
# (Page residue from "raw / history / blame / 11.2 kB" converted to comments
# so the file remains valid Python.)
# app.py
import os
import json
import traceback
import torch
import gradio as gr
import numpy as np
from PIL import Image
import cv2
import math
# --- Imports from the files we created (model.py, logic.py) ---
from model import interfuser_baseline
from logic import (
transform, lidar_transform, InterfuserController, ControllerConfig,
Tracker, DisplayInterface, render, render_waypoints, render_self_car,
ensure_rgb, WAYPOINT_SCALE_FACTOR, T1_FUTURE_TIME, T2_FUTURE_TIME
)
# ==============================================================================
# 1. Load the model (runs once at startup)
# ==============================================================================
print("Loading the Interfuser model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = interfuser_baseline()
model_path = "model/interfuser_best_model.pth"
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model file not found at {model_path}. Please upload it.")
# Prefer weights_only=True for safety when loading checkpoints from untrusted
# sources. Older torch versions reject the keyword (TypeError), and some
# checkpoints contain non-tensor objects that weights_only loading refuses;
# in either case fall back to a plain (unsafe) load of this known local file.
try:
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
except Exception:
    state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
print("Model loaded successfully.")
# ==============================================================================
# 2. Main inference function for Gradio
# ==============================================================================
def run_single_frame(
    rgb_image_path,
    rgb_left_image_path,
    rgb_right_image_path,
    rgb_center_image_path,
    lidar_image_path,
    measurements_path,
    target_point_list: list
):
    """Run one inference step of the Interfuser model and build a dashboard.

    All ``*_path`` arguments are gradio ``gr.File`` upload objects (the
    uploaded file's path is read from their ``.name`` attribute);
    ``target_point_list`` is an ``[x, y]`` list from the ``gr.JSON`` widget.

    Returns:
        tuple: (``PIL.Image`` with the rendered dashboard, ``dict`` with the
        predicted waypoints, control commands, perception summary and
        controller metadata).

    Raises:
        gr.Error: if a required input is missing or processing fails.
    """
    try:
        # ==========================================================
        # 1. Read and preprocess the inputs from the uploaded files
        # ==========================================================
        if not rgb_image_path:
            raise gr.Error("الرجاء توفير مسار الصورة الأمامية (RGB).")
        rgb_image_pil = Image.open(rgb_image_path.name).convert("RGB")
        # Optional side/center cameras fall back to the front view.
        rgb_left_pil = Image.open(rgb_left_image_path.name).convert("RGB") if rgb_left_image_path else rgb_image_pil
        rgb_right_pil = Image.open(rgb_right_image_path.name).convert("RGB") if rgb_right_image_path else rgb_image_pil
        rgb_center_pil = Image.open(rgb_center_image_path.name).convert("RGB") if rgb_center_image_path else rgb_image_pil

        if lidar_image_path:
            lidar_array = np.load(lidar_image_path.name)
            # Scale to 0-255 before converting to an 8-bit grayscale image.
            if lidar_array.max() > 0:
                lidar_array = (lidar_array / lidar_array.max()) * 255.0
            lidar_image_pil = Image.fromarray(lidar_array.astype(np.uint8)).convert('RGB')
        else:
            # No LiDAR supplied: feed an all-black placeholder image.
            lidar_image_pil = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))

        rgb_tensor = transform(rgb_image_pil).unsqueeze(0).to(device)
        rgb_left_tensor = transform(rgb_left_pil).unsqueeze(0).to(device)
        rgb_right_tensor = transform(rgb_right_pil).unsqueeze(0).to(device)
        rgb_center_tensor = transform(rgb_center_pil).unsqueeze(0).to(device)
        lidar_tensor = lidar_transform(lidar_image_pil).unsqueeze(0).to(device)

        # Fail fast with a clear message instead of an AttributeError on
        # None.name when no measurements file was uploaded.
        if not measurements_path:
            raise gr.Error("Measurements file (.json) is required.")
        with open(measurements_path.name, 'r') as f:
            measurements_dict = json.load(f)
        # The measurements vector is the driving command repeated six times,
        # followed by the current speed.
        command = measurements_dict.get('command', 2.0)
        current_speed = measurements_dict.get('speed', 5.0)
        measurements_values = [command] * 6 + [current_speed]
        measurements_tensor = torch.tensor([measurements_values], dtype=torch.float32).to(device)
        target_point_tensor = torch.tensor([target_point_list], dtype=torch.float32).to(device)
        inputs = {
            'rgb': rgb_tensor, 'rgb_left': rgb_left_tensor, 'rgb_right': rgb_right_tensor,
            'rgb_center': rgb_center_tensor, 'lidar': lidar_tensor,
            'measurements': measurements_tensor, 'target_point': target_point_tensor
        }

        # ==========================================================
        # 2. Run the model and the post-processing stack
        # ==========================================================
        with torch.no_grad():
            outputs = model(inputs)
        traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs

        measurements_np = measurements_tensor[0].cpu().numpy()
        pos, theta, speed = [0, 0], 0, measurements_np[6]
        traffic_np = traffic[0].detach().cpu().numpy().reshape(20, 20, -1)
        waypoints_np = waypoints[0].detach().cpu().numpy() * WAYPOINT_SCALE_FACTOR

        # Evaluate each detection head's probability once and reuse it below.
        junction_prob = is_junction.sigmoid()[0, 1].item()
        light_prob = traffic_light.sigmoid()[0, 0].item()
        stop_prob = stop_sign.sigmoid()[0, 1].item()

        # Fresh tracker/controller per request: each call is an independent frame.
        tracker = Tracker()
        updated_traffic = tracker.update_and_predict(traffic_np.copy(), pos, theta, frame_num=0)
        controller = InterfuserController(ControllerConfig())
        steer, throttle, brake, metadata_tuple = controller.run_step(
            speed=speed, waypoints=waypoints_np, junction=junction_prob,
            traffic_light_state=light_prob,
            stop_sign=stop_prob, meta_data=updated_traffic
        )

        # ==========================================================
        # 3. Build the visual dashboard
        # ==========================================================
        map_t0, counts_t0 = render(updated_traffic, t=0)
        map_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
        map_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
        wp_map = render_waypoints(waypoints_np)
        self_car_map = render_self_car(loc=np.array([0, 0]), ori=[math.cos(0), math.sin(0)], box=[4.0, 2.0])

        # Composite waypoints and the ego car onto the t=0 map; the future
        # maps only get the ego car overlay.
        map_t0 = cv2.add(cv2.add(map_t0, wp_map), self_car_map)
        map_t0 = cv2.resize(map_t0, (400, 400))
        map_t1 = cv2.add(ensure_rgb(map_t1), ensure_rgb(self_car_map)); map_t1 = cv2.resize(map_t1, (200, 200))
        map_t2 = cv2.add(ensure_rgb(map_t2), ensure_rgb(self_car_map)); map_t2 = cv2.resize(map_t2, (200, 200))

        display = DisplayInterface()
        light_state = "Red" if light_prob > 0.5 else "Green"
        stop_sign_state = "Yes" if stop_prob > 0.5 else "No"
        interface_data = {
            'camera_view': np.array(rgb_image_pil),
            'map_t0': map_t0, 'map_t1': map_t1, 'map_t2': map_t2,
            'text_info': {
                'Frame': 'API Frame', 'Control': f"S:{steer:.2f} T:{throttle:.2f} B:{int(brake)}",
                'Light': f"L: {light_state}", 'Stop': f"St: {stop_sign_state}"
            },
            'object_counts': {'t0': counts_t0, 't1': counts_t1, 't2': counts_t2}
        }
        dashboard_image = display.run_interface(interface_data)

        # ==========================================================
        # 4. Assemble and return the final outputs
        # ==========================================================
        result_dict = {
            "predicted_waypoints": waypoints_np.tolist(),
            "control_commands": {"steer": steer, "throttle": throttle, "brake": bool(brake)},
            "perception": {
                "traffic_light_status": light_state,
                "stop_sign_detected": (stop_sign_state == "Yes"),
                "is_at_junction_prob": round(junction_prob, 3)
            },
            "metadata": {
                "speed_info": metadata_tuple[0], "perception_info": metadata_tuple[1],
                "stop_info": metadata_tuple[2], "safe_distance": metadata_tuple[3]
            }
        }
        return Image.fromarray(dashboard_image), result_dict
    except gr.Error:
        # Re-raise user-facing validation errors without double-wrapping
        # them in the generic "Error processing single frame" message.
        raise
    except Exception as e:
        print(traceback.format_exc())
        raise gr.Error(f"Error processing single frame: {e}")
# ==============================================================================
# 4. Gradio interface definition
# ==============================================================================
# Build the Gradio UI: a single tab exposing single-frame inference for
# developers (file uploads in, dashboard image + structured JSON out).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚗 محاكاة القيادة الذاتية باستخدام Interfuser")
    with gr.Tabs():
        with gr.TabItem("نقطة نهاية API (إطار واحد)", id=1):
            gr.Markdown("### اختبار النموذج بإدخال مباشر (Single Frame Inference)")
            gr.Markdown("هذه الواجهة مخصصة للمطورين. قم برفع الملفات المطلوبة لتشغيل النموذج على إطار واحد.")
            with gr.Row():
                with gr.Column(scale=1):
                    # Input widgets: four camera files (side/center optional),
                    # optional LiDAR .npy, measurements JSON and target point.
                    gr.Markdown("#### ملفات الصور والبيانات")
                    api_rgb_image_path = gr.File(label="RGB (Front) File (.jpg, .png)")
                    api_rgb_left_image_path = gr.File(label="RGB (Left) File (Optional)")
                    api_rgb_right_image_path = gr.File(label="RGB (Right) File (Optional)")
                    api_rgb_center_image_path = gr.File(label="RGB (Center) File (Optional)")
                    api_lidar_image_path = gr.File(label="LiDAR File (.npy, Optional)")
                    api_measurements_path = gr.File(label="Measurements File (.json)")
                    api_target_point_list = gr.JSON(label="Target Point (List [x, y])", value=[0.0, 100.0])
                    api_run_button = gr.Button("🚀 تشغيل إطار واحد", variant="primary")
                with gr.Column(scale=2):
                    # Output widgets: the rendered dashboard image plus the
                    # structured model results as JSON.
                    gr.Markdown("#### المخرجات")
                    api_output_image = gr.Image(label="Dashboard Result", type="pil", interactive=False)
                    api_output_json = gr.JSON(label="نتائج النموذج (JSON)")
            # Wire the button to the inference function; this is also exposed
            # as a programmatic API endpoint named "run_single_frame".
            api_run_button.click(
                fn=run_single_frame,
                inputs=[
                    api_rgb_image_path,
                    api_rgb_left_image_path,
                    api_rgb_right_image_path,
                    api_rgb_center_image_path,
                    api_lidar_image_path,
                    api_measurements_path,
                    api_target_point_list
                ],
                outputs=[api_output_image, api_output_json],
                api_name="run_single_frame"
            )
# ==============================================================================
# 5. Run the application
# ==============================================================================
if __name__ == "__main__":
    # Enable the request queue (required for handling concurrent users) and
    # launch with debug output when run as a script.
    demo.queue().launch(debug=True)