# NOTE: removed Hugging Face Spaces page residue ("Spaces / Sleeping") that was
# captured with the source and is not part of the program.
# app.py
import os
import json
import traceback
import torch
import gradio as gr
import numpy as np
from PIL import Image
import cv2
import math
# --- Imports from the project modules we created ---
from model import interfuser_baseline
from logic import (
    transform, lidar_transform, InterfuserController, ControllerConfig,
    Tracker, DisplayInterface, render, render_waypoints, render_self_car,
    ensure_rgb, WAYPOINT_SCALE_FACTOR, T1_FUTURE_TIME, T2_FUTURE_TIME
)
# ==============================================================================
# 1. Model loading (runs once at module import)
# ==============================================================================
print("Loading the Interfuser model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = interfuser_baseline()
model_path = "model/interfuser_best_model.pth"
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model file not found at {model_path}. Please upload it.")
# Prefer weights_only=True for safety when loading checkpoint files from
# untrusted sources; fall back to a plain load for torch versions that either
# lack the keyword or cannot unpickle this checkpoint under weights_only.
try:
    state_dic = torch.load(model_path, map_location=device, weights_only=True)
except Exception as load_err:  # narrowed from bare `except:` so Ctrl-C / SystemExit still propagate
    print(f"weights_only load failed ({load_err!r}); retrying without weights_only.")
    state_dic = torch.load(model_path, map_location=device)
model.load_state_dict(state_dic)
model.to(device)
model.eval()
print("Model loaded successfully.")
| # ============================================================================== | |
| # 2. دالة التشغيل الرئيسية لـ Gradio | |
| # ============================================================================== | |
def run_single_frame(
    rgb_image_path: str,
    rgb_left_image_path: str,
    rgb_right_image_path: str,
    rgb_center_image_path: str,
    lidar_image_path: str,
    measurements_path: str,
    target_point_list: list
):
    """
    Run the Interfuser model on a single frame of sensor data.

    Builds the full visual dashboard and returns both the rendered image and
    the structured prediction results.

    Args:
        rgb_image_path: Front RGB camera file (required; gradio File object).
        rgb_left_image_path: Optional left camera file; front image reused if absent.
        rgb_right_image_path: Optional right camera file; front image reused if absent.
        rgb_center_image_path: Optional center camera file; front image reused if absent.
        lidar_image_path: Optional .npy LiDAR array; a black 112x112 placeholder
            is used when absent.
        measurements_path: Required JSON file read for `command` and `speed` keys.
        target_point_list: Target point as an [x, y] list.

    Returns:
        Tuple of (PIL.Image dashboard, dict with waypoints / controls / perception).

    Raises:
        gr.Error: when a required input is missing or any processing step fails.
    """
    try:
        # ==========================================================
        # 1. Read and preprocess the inputs from the uploaded files
        # ==========================================================
        if not rgb_image_path:
            raise gr.Error("الرجاء توفير مسار الصورة الأمامية (RGB).")
        if not measurements_path:
            # Fail early with a clear message instead of the opaque
            # AttributeError that `measurements_path.name` would raise below.
            raise gr.Error("Measurements JSON file (.json) is required.")
        rgb_image_pil = Image.open(rgb_image_path.name).convert("RGB")
        # Any missing side/center camera falls back to the front image.
        rgb_left_pil = Image.open(rgb_left_image_path.name).convert("RGB") if rgb_left_image_path else rgb_image_pil
        rgb_right_pil = Image.open(rgb_right_image_path.name).convert("RGB") if rgb_right_image_path else rgb_image_pil
        rgb_center_pil = Image.open(rgb_center_image_path.name).convert("RGB") if rgb_center_image_path else rgb_image_pil
        if lidar_image_path:
            lidar_array = np.load(lidar_image_path.name)
            # Normalize to 0-255 before the 8-bit conversion (skip if all zeros).
            if lidar_array.max() > 0:
                lidar_array = (lidar_array / lidar_array.max()) * 255.0
            lidar_pil = Image.fromarray(lidar_array.astype(np.uint8))
            lidar_image_pil = lidar_pil.convert('RGB')
        else:
            # No LiDAR provided: feed an all-black placeholder image.
            lidar_image_pil = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
        rgb_tensor = transform(rgb_image_pil).unsqueeze(0).to(device)
        rgb_left_tensor = transform(rgb_left_pil).unsqueeze(0).to(device)
        rgb_right_tensor = transform(rgb_right_pil).unsqueeze(0).to(device)
        rgb_center_tensor = transform(rgb_center_pil).unsqueeze(0).to(device)
        lidar_tensor = lidar_transform(lidar_image_pil).unsqueeze(0).to(device)
        with open(measurements_path.name, 'r') as f:
            measurements_dict = json.load(f)
        # Measurement vector layout: six command slots then the current speed.
        measurements_values = [
            measurements_dict.get('command', 2.0), measurements_dict.get('command', 2.0),
            measurements_dict.get('command', 2.0), measurements_dict.get('command', 2.0),
            measurements_dict.get('command', 2.0), measurements_dict.get('command', 2.0),
            measurements_dict.get('speed', 5.0)
        ]
        measurements_tensor = torch.tensor([measurements_values], dtype=torch.float32).to(device)
        target_point_tensor = torch.tensor([target_point_list], dtype=torch.float32).to(device)
        inputs = {
            'rgb': rgb_tensor, 'rgb_left': rgb_left_tensor, 'rgb_right': rgb_right_tensor,
            'rgb_center': rgb_center_tensor, 'lidar': lidar_tensor,
            'measurements': measurements_tensor, 'target_point': target_point_tensor
        }
        # ==========================================================
        # 2. Run the model and post-process its outputs
        # ==========================================================
        with torch.no_grad():
            outputs = model(inputs)
        traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
        measurements_np = measurements_tensor[0].cpu().numpy()
        # Single-frame endpoint: ego pose is fixed at the origin.
        pos, theta, speed = [0, 0], 0, measurements_np[6]
        traffic_np = traffic[0].detach().cpu().numpy().reshape(20, 20, -1)
        waypoints_np = waypoints[0].detach().cpu().numpy() * WAYPOINT_SCALE_FACTOR
        # Fresh tracker/controller per request keeps the endpoint stateless.
        tracker = Tracker()
        updated_traffic = tracker.update_and_predict(traffic_np.copy(), pos, theta, frame_num=0)
        controller = InterfuserController(ControllerConfig())
        steer, throttle, brake, metadata_tuple = controller.run_step(
            speed=speed, waypoints=waypoints_np, junction=is_junction.sigmoid()[0, 1].item(),
            traffic_light_state=traffic_light.sigmoid()[0, 0].item(),
            stop_sign=stop_sign.sigmoid()[0, 1].item(), meta_data=updated_traffic
        )
        # ==========================================================
        # 3. Render the visual dashboard
        # ==========================================================
        map_t0, counts_t0 = render(updated_traffic, t=0)
        map_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
        map_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
        wp_map = render_waypoints(waypoints_np)
        self_car_map = render_self_car(loc=np.array([0, 0]), ori=[math.cos(0), math.sin(0)], box=[4.0, 2.0])
        map_t0 = cv2.add(cv2.add(map_t0, wp_map), self_car_map)
        map_t0 = cv2.resize(map_t0, (400, 400))
        map_t1 = cv2.add(ensure_rgb(map_t1), ensure_rgb(self_car_map)); map_t1 = cv2.resize(map_t1, (200, 200))
        map_t2 = cv2.add(ensure_rgb(map_t2), ensure_rgb(self_car_map)); map_t2 = cv2.resize(map_t2, (200, 200))
        display = DisplayInterface()
        light_state = "Red" if traffic_light.sigmoid()[0, 0].item() > 0.5 else "Green"
        stop_sign_state = "Yes" if stop_sign.sigmoid()[0, 1].item() > 0.5 else "No"
        interface_data = {
            'camera_view': np.array(rgb_image_pil),
            'map_t0': map_t0, 'map_t1': map_t1, 'map_t2': map_t2,
            'text_info': {
                'Frame': 'API Frame', 'Control': f"S:{steer:.2f} T:{throttle:.2f} B:{int(brake)}",
                'Light': f"L: {light_state}", 'Stop': f"St: {stop_sign_state}"
            },
            'object_counts': {'t0': counts_t0, 't1': counts_t1, 't2': counts_t2}
        }
        dashboard_image = display.run_interface(interface_data)
        # ==========================================================
        # 4. Assemble and return the final outputs
        # ==========================================================
        result_dict = {
            "predicted_waypoints": waypoints_np.tolist(),
            "control_commands": {"steer": steer, "throttle": throttle, "brake": bool(brake)},
            "perception": {"traffic_light_status": light_state, "stop_sign_detected": (stop_sign_state == "Yes"), "is_at_junction_prob": round(is_junction.sigmoid()[0, 1].item(), 3)},
            "metadata": {"speed_info": metadata_tuple[0], "perception_info": metadata_tuple[1], "stop_info": metadata_tuple[2], "safe_distance": metadata_tuple[3]}
        }
        return Image.fromarray(dashboard_image), result_dict
    except Exception as e:
        # Log the full traceback server-side; surface a concise error to the UI.
        print(traceback.format_exc())
        raise gr.Error(f"Error processing single frame: {e}")
# ==============================================================================
# 4. Gradio interface definition
# ==============================================================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚗 محاكاة القيادة الذاتية باستخدام Interfuser")
    with gr.Tabs():
        with gr.TabItem("نقطة نهاية API (إطار واحد)", id=1):
            gr.Markdown("### اختبار النموذج بإدخال مباشر (Single Frame Inference)")
            gr.Markdown("هذه الواجهة مخصصة للمطورين. قم برفع الملفات المطلوبة لتشغيل النموذج على إطار واحد.")
            with gr.Row():
                # Left column: all file/data inputs plus the run button.
                with gr.Column(scale=1):
                    gr.Markdown("#### ملفات الصور والبيانات")
                    api_rgb_image_path = gr.File(label="RGB (Front) File (.jpg, .png)")
                    api_rgb_left_image_path = gr.File(label="RGB (Left) File (Optional)")
                    api_rgb_right_image_path = gr.File(label="RGB (Right) File (Optional)")
                    api_rgb_center_image_path = gr.File(label="RGB (Center) File (Optional)")
                    api_lidar_image_path = gr.File(label="LiDAR File (.npy, Optional)")
                    api_measurements_path = gr.File(label="Measurements File (.json)")
                    api_target_point_list = gr.JSON(label="Target Point (List [x, y])", value=[0.0, 100.0])
                    api_run_button = gr.Button("🚀 تشغيل إطار واحد", variant="primary")
                # Right column: the rendered dashboard and the JSON results.
                with gr.Column(scale=2):
                    gr.Markdown("#### المخرجات")
                    api_output_image = gr.Image(label="Dashboard Result", type="pil", interactive=False)
                    api_output_json = gr.JSON(label="نتائج النموذج (JSON)")
            # Inputs are listed in the exact order run_single_frame expects them.
            api_inputs = [
                api_rgb_image_path,
                api_rgb_left_image_path,
                api_rgb_right_image_path,
                api_rgb_center_image_path,
                api_lidar_image_path,
                api_measurements_path,
                api_target_point_list,
            ]
            api_run_button.click(
                fn=run_single_frame,
                inputs=api_inputs,
                outputs=[api_output_image, api_output_json],
                api_name="run_single_frame"
            )
# ==============================================================================
# 5. Application entry point
# ==============================================================================
if __name__ == "__main__":
    demo.queue().launch(debug=True)