mohammed-aljafry committed on
Commit
58f0060
·
verified ·
1 Parent(s): 8c9b2cd

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +2016 -193
app.py CHANGED
@@ -1,234 +1,2057 @@
- # ============================================================================
- # app.py - the interactive interface for the Interfuser project
- # ============================================================================
- # This file is only responsible for building and running the user interface with Gradio.
- # It depends on:
- #   - model_utils.py: managing and loading the models.
- #   - simulation_modules.py: output processing, control, and display.
- # ============================================================================
-
- import os
  import torch
  import numpy as np
- import cv2
  import json
- import logging
- import traceback
  from PIL import Image

- # GUI library
- import gradio as gr

- # image-processing libraries
- from torchvision import transforms

- # --- Part 1: import our own modules ---
- try:
-     from model_utils import get_available_models, load_model_by_name, get_current_model
-     from simulation_modules import (
-         DisplayInterface,
-         InterfuserController,
-         ControllerConfig,
-         render_waypoints,
-         render_self_car,
-         render
-     )
- except ImportError as e:
-     print(f"Import error: make sure model_utils.py and simulation_modules.py exist. Error: {e}")
-     exit()

- # --- Part 2: settings and constants ---
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

- SAMPLE_DATA_DIR = "sample_data"
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- # Define the transforms (they must match what was used during training)
- RGB_TRANSFORM = transforms.Compose([
      transforms.Resize((224, 224)),
      transforms.ToTensor(),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])

- LIDAR_TRANSFORM = transforms.Compose([
-     transforms.Resize((224, 224)),
-     transforms.ToTensor()
  ])

- # --- Part 3: global objects for the interface ---
- logging.info("Initializing interface components...")
- GLOBAL_DISPLAY_INTERFACE = DisplayInterface()
- controller_config = ControllerConfig()
- GLOBAL_CONTROLLER = InterfuserController(controller_config)
- logging.info(f"Interface components initialized. Device in use: {DEVICE}")

- # --- Part 4: main processing functions ---

- def process_and_run_inference(
-     model: torch.nn.Module,
-     paths: dict
- ):
-     """
-     Inference engine: processes the data of a single frame from the given paths and runs the model.
-     """
-     try:
-         rgb_path = paths['rgb']
-         if not rgb_path or not os.path.exists(rgb_path):
-             raise FileNotFoundError("The front (RGB) image file is missing.")

-         rgb_image = Image.open(rgb_path).convert("RGB")
-
-         def open_optional_image(path, default_img):
-             return Image.open(path).convert("RGB") if path and os.path.exists(path) else default_img

-         rgb_left_image = open_optional_image(paths.get('left'), rgb_image)
-         rgb_right_image = open_optional_image(paths.get('right'), rgb_image)
-         rgb_center_image = open_optional_image(paths.get('center'), rgb_image)

-         lidar_path = paths.get('lidar')
-         if lidar_path and os.path.exists(lidar_path):
-             lidar_array = np.load(lidar_path)
-             max_val = lidar_array.max()
-             if max_val > 0: lidar_array = (lidar_array / max_val) * 255.0
-             lidar_image = Image.fromarray(lidar_array.astype(np.uint8)).convert('RGB')
          else:
-             lidar_image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

-         target_point_list = json.load(open(paths['target_point']))['value']
-
-         inputs = {
-             'rgb': RGB_TRANSFORM(rgb_image).unsqueeze(0).to(DEVICE),
-             'rgb_left': RGB_TRANSFORM(rgb_left_image).unsqueeze(0).to(DEVICE),
-             'rgb_right': RGB_TRANSFORM(rgb_right_image).unsqueeze(0).to(DEVICE),
-             'rgb_center': RGB_TRANSFORM(rgb_center_image).unsqueeze(0).to(DEVICE),
-             'lidar': LIDAR_TRANSFORM(lidar_image).unsqueeze(0).to(DEVICE),
-             'measurements': torch.tensor([json.load(open(paths['measurements']))['values']], dtype=torch.float32).to(DEVICE),
-             'target_point': torch.tensor([target_point_list], dtype=torch.float32).to(DEVICE)
          }

-         with torch.no_grad():
-             outputs = model(inputs)
-
-         return outputs

-     except Exception as e:
-         logging.error(traceback.format_exc())
-         raise gr.Error(f"An error occurred while processing the data: {e}")


- def run_simulation(scenario_name: str):
      """
-     Main orchestrator: prepares the inputs, calls the inference engine, then renders the results.
      """
-     logging.info(f"Starting the simulation for scenario: '{scenario_name}'")
-
-     model = get_current_model()
-     if model is None:
-         raise gr.Error("No model is loaded. Please pick a model from the list first.")
-
-     scenario_path = os.path.join(SAMPLE_DATA_DIR, scenario_name)
-     paths = {
-         'rgb': os.path.join(scenario_path, 'rgb.png'),
-         'left': os.path.join(scenario_path, 'rgb_left.png'),
-         'right': os.path.join(scenario_path, 'rgb_right.png'),
-         'center': os.path.join(scenario_path, 'rgb_center.png'),
-         'lidar': os.path.join(scenario_path, 'lidar.npy'),
-         'measurements': os.path.join(scenario_path, 'measurements.json'),
-         'target_point': os.path.join(scenario_path, 'target_point.json')
-     }
-
-     traffic, waypoints, is_junction, traffic_light, stop_sign, _ = process_and_run_inference(model, paths)
-
-     pred_wp = waypoints[0].cpu().numpy()
-     pred_traffic_map = traffic[0].sigmoid().cpu().numpy().reshape(20, 20, 7)
-     current_speed = json.load(open(paths['measurements'])).get('values', [0]*7)[3]
-
-     steer, throttle, brake, metadata = GLOBAL_CONTROLLER.run_step(
-         current_speed=current_speed, waypoints=pred_wp,
-         junction=torch.sigmoid(is_junction)[0,1].item(),
-         traffic_light_state=torch.sigmoid(traffic_light)[0,0].item(),
-         stop_sign=torch.sigmoid(stop_sign)[0,1].item(),
-         meta_data={}
      )
-
-     traffic_render, _ = render(pred_traffic_map)
-     waypoints_render = render_waypoints(pred_wp)
-     combined_map = cv2.addWeighted(traffic_render, 1.0, waypoints_render, 1.0, 0.0)
-     final_map = render_self_car(combined_map)
-
-     display_data = {
-         'camera_view': np.array(Image.open(paths['rgb'])),
-         'map_t0': final_map, 'map_t1': np.zeros_like(final_map), 'map_t2': np.zeros_like(final_map),
-         'text_info': {
-             "Controller": metadata,
-             "Brake": f"Brake Activated: {'YES' if brake else 'NO'}"
-         }
-     }
-
-     dashboard = GLOBAL_DISPLAY_INTERFACE.run_interface(display_data)
-     logging.info("Simulation finished successfully.")
-
-     return dashboard, metadata


- # --- Part 5: building the Gradio interface ---
- with gr.Blocks(title="Interfuser Demo", theme=gr.themes.Soft()) as demo:
-     gr.Markdown("# 🚀 Interfuser: the interactive driving interface")
-     gr.Markdown("Pick a model and a scenario, then press 'Run' to compare how the different models perform.")

      try:
-         available_models = get_available_models()
-         available_scenarios = [d for d in os.listdir(SAMPLE_DATA_DIR) if os.path.isdir(os.path.join(SAMPLE_DATA_DIR, d))]
-     except FileNotFoundError as e:
-         gr.Error(f"Setup error: {e}. Did you create the 'model' and 'sample_data' folders?")
-         available_models, available_scenarios = [], []
-
-     with gr.Row():
-         with gr.Column(scale=1, min_width=300):
-             model_selector = gr.Dropdown(
-                 choices=available_models, label="1. Pick a model",
-                 value=available_models[0] if available_models else None,
-                 interactive=True
-             )
-             model_load_status = gr.Textbox(label="Model load status", interactive=False)
-
-             scenario_selector = gr.Dropdown(
-                 choices=available_scenarios, label="2. Pick a driving scenario",
-                 value=available_scenarios[0] if available_scenarios else None)
-
-             run_button = gr.Button("▶️ Run simulation", variant="primary")
-
-             controller_output = gr.Textbox(label="Driving controller data", interactive=False)
-
-         with gr.Column(scale=3):
-             dashboard_output = gr.Image(label="Live dashboard", interactive=False)
-
-     # --- wire up the events ---

-     # 1. when the model selection changes, load it
-     model_selector.change(
-         fn=load_model_by_name,
-         inputs=[model_selector],
-         outputs=[model_load_status]  # show the load status to the user
-     )

-     # 2. when the run button is pressed, run the simulation
-     if available_models and available_scenarios:
-         run_button.click(
-             fn=run_simulation,
-             inputs=[scenario_selector],
-             outputs=[dashboard_output, controller_output]
          )
-     else:
-         gr.Warning("No models or scenarios were found. Please make sure the folders are set up correctly.")

-     # load the default model when the interface starts
-     demo.load(
-         fn=load_model_by_name,
-         inputs=[model_selector],
-         outputs=[model_load_status]
-     )

- # --- Part 6: launching the interface ---
  if __name__ == "__main__":
-     demo.launch(debug=True)

+ import traceback
  import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, PretrainedConfig
+ from transformers.utils.generic import ModelOutput
+ from functools import partial
+ import math
+ import copy
+ from typing import Optional, Tuple, Union, List
+ from torch import Tensor
+ from dataclasses import dataclass
  import numpy as np
+ from timm.models.resnet import resnet50d, resnet101d, resnet26d, resnet18d
+ from torch.utils.data import DataLoader, Dataset
+ from collections import deque, OrderedDict
+ import os
  import json
+ import cv2
+ from pathlib import Path
+ from torchvision import transforms
+ from torch.utils.data import random_split
+ from timm.models.registry import register_model
+ import gradio as gr
+ import zipfile
+ import tempfile
+ import shutil
+ import tarfile
+ import gdown
+ import time
+ from huggingface_hub import hf_hub_download  # a better way to download from the Hub
+ import requests  # <-- this is the line that had to be added
  from PIL import Image

+
+ class HybridEmbed(nn.Module):
+     def __init__(
+         self,
+         backbone,
+         img_size=224,
+         patch_size=1,
+         feature_size=None,
+         in_chans=3,
+         embed_dim=768,
+     ):
+         super().__init__()
+         assert isinstance(backbone, nn.Module)
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.backbone = backbone
+         if feature_size is None:
+             with torch.no_grad():
+                 training = backbone.training
+                 if training:
+                     backbone.eval()
+                 o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
+                 if isinstance(o, (list, tuple)):
+                     o = o[-1]  # last feature if backbone outputs list/tuple of features
+                 feature_size = o.shape[-2:]
+                 feature_dim = o.shape[1]
+                 backbone.train(training)
+         else:
+             feature_size = to_2tuple(feature_size)
+             if hasattr(self.backbone, "feature_info"):
+                 feature_dim = self.backbone.feature_info.channels()[-1]
+             else:
+                 feature_dim = self.backbone.num_features
+
+         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)
+
+     def forward(self, x):
+         x = self.backbone(x)
+         if isinstance(x, (list, tuple)):
+             x = x[-1]  # last feature if backbone outputs list/tuple of features
+         x = self.proj(x)
+         global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
+         return x, global_x
+
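HybridEmbed wraps a CNN backbone as a patch embedding: it projects the backbone's last feature map to the transformer width and also pools a per-image global token. A minimal shape sketch, not part of the commit; it assumes `to_2tuple` is in scope (e.g. timm's helper, which is not imported in the hunk above), and building a pretrained-free resnet26d so nothing is downloaded:

# Hedged usage sketch: resnet26d at stride 32 maps a 224x224 image to a 7x7 grid.
backbone = resnet26d(pretrained=False, features_only=True, out_indices=[4])
embed = HybridEmbed(backbone, img_size=224, embed_dim=256)
tokens, global_token = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)        # torch.Size([1, 256, 7, 7]) - projected feature map
print(global_token.shape)  # torch.Size([1, 256, 1])    - global average token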
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention Is All You Need paper, generalized to work on images.
+     """
+
+     def __init__(
+         self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+     ):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, tensor):
+         x = tensor
+         bs, _, h, w = x.shape
+         not_mask = torch.ones((bs, h, w), device=x.device)
+         y_embed = not_mask.cumsum(1, dtype=torch.float32)
+         x_embed = not_mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack(
+             (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         pos_y = torch.stack(
+             (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+         ).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         return pos
+
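The sine embedding only reads the input's batch size, spatial size, and device, and returns 2 * num_pos_feats channels (x and y halves concatenated); the Interfuser class below builds it with embed_dim // 2 so it can be added elementwise to the visual tokens. A quick illustrative shape check:

# Hedged sketch: with num_pos_feats=128 the output matches a 256-channel map.
pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
pos = pe(torch.randn(2, 256, 7, 7))
print(pos.shape)  # torch.Size([2, 256, 7, 7])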
+ class TransformerEncoder(nn.Module):
+     def __init__(self, encoder_layer, num_layers, norm=None):
+         super().__init__()
+         self.layers = _get_clones(encoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+
+     def forward(
+         self,
+         src,
+         mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         output = src
+
+         for layer in self.layers:
+             output = layer(
+                 output,
+                 src_mask=mask,
+                 src_key_padding_mask=src_key_padding_mask,
+                 pos=pos,
+             )
+
+         if self.norm is not None:
+             output = self.norm(output)
+
+         return output
+
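`_get_clones` is used by both the encoder and decoder wrappers but is not defined in this part of the file; presumably it appears further down. The DETR-style helper these wrappers conventionally rely on is a deepcopy stack, sketched here under that assumption:

# Assumed definition (standard DETR-style helper), not shown in this hunk:
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])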
+ class SpatialSoftmax(nn.Module):
+     def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
+         super().__init__()
+
+         self.data_format = data_format
+         self.height = height
+         self.width = width
+         self.channel = channel
+
+         if temperature:
+             self.temperature = nn.Parameter(torch.ones(1) * temperature)
+         else:
+             self.temperature = 1.0
+
+         pos_x, pos_y = np.meshgrid(
+             np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
+         )
+         pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
+         pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
+         self.register_buffer("pos_x", pos_x)
+         self.register_buffer("pos_y", pos_y)
+
+     def forward(self, feature):
+         # Output:
+         #   (N, C*2) x_0 y_0 ...
+
+         if self.data_format == "NHWC":
+             feature = (
+                 feature.transpose(1, 3)
+                 .transpose(2, 3)
+                 .view(-1, self.height * self.width)
+             )
+         else:
+             feature = feature.view(-1, self.height * self.width)
+
+         weight = F.softmax(feature / self.temperature, dim=-1)
+         expected_x = torch.sum(
+             torch.autograd.Variable(self.pos_x) * weight, dim=1, keepdim=True
+         )
+         expected_y = torch.sum(
+             torch.autograd.Variable(self.pos_y) * weight, dim=1, keepdim=True
+         )
+         expected_xy = torch.cat([expected_x, expected_y], 1)
+         feature_keypoints = expected_xy.view(-1, self.channel, 2)
+         feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
+         feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
+         return feature_keypoints
+
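SpatialSoftmax converts each channel's activation map into an expected (x, y) location via a softmax-weighted average over a [-1, 1] grid, then rescales to roughly metric coordinates (the * 12 factors). A worked shape example, for illustration only:

# Hedged sketch: 10 heatmaps of 100x100 each collapse to one (x, y) keypoint.
ss = SpatialSoftmax(height=100, width=100, channel=10)
keypoints = ss(torch.randn(4, 10, 100, 100))  # NCHW input
print(keypoints.shape)  # torch.Size([4, 10, 2]); x in ~[-12, 12], y in ~[-24, 0]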
+ class MultiPath_Generator(nn.Module):
+     def __init__(self, in_channel, embed_dim, out_channel):
+         super().__init__()
+         self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
+         self.tconv0 = nn.Sequential(
+             nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(256),
+             nn.ReLU(True),
+         )
+         self.tconv1 = nn.Sequential(
+             nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(256),
+             nn.ReLU(True),
+         )
+         self.tconv2 = nn.Sequential(
+             nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(192),
+             nn.ReLU(True),
+         )
+         self.tconv3 = nn.Sequential(
+             nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(64),
+             nn.ReLU(True),
+         )
+         self.tconv4_list = torch.nn.ModuleList(
+             [
+                 nn.Sequential(
+                     nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
+                     nn.Tanh(),
+                 )
+                 for _ in range(6)
+             ]
+         )
+
+         self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")
+
+     def forward(self, x, measurements):
+         mask = measurements[:, :6]
+         mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
+         velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
+         velocity = velocity.repeat(1, 32, 2, 2)
+
+         n, d, c = x.shape
+         x = x.transpose(1, 2)
+         x = x.view(n, -1, 2, 2)
+         x = torch.cat([x, velocity], dim=1)
+         x = self.tconv0(x)
+         x = self.tconv1(x)
+         x = self.tconv2(x)
+         x = self.tconv3(x)
+         x = self.upsample(x)
+         xs = []
+         for i in range(6):
+             xt = self.tconv4_list[i](x)
+             xs.append(xt)
+         xs = torch.stack(xs, dim=1)
+         x = torch.sum(xs * mask, dim=1)
+         x = self.spatial_softmax(x)
+         return x
+
+
+ class LinearWaypointsPredictor(nn.Module):
+     def __init__(self, input_dim, cumsum=True):
+         super().__init__()
+         self.cumsum = cumsum
+         self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
+         self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
+         self.head_relu = nn.ReLU(inplace=True)
+         self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
+
+     def forward(self, x, measurements):
+         # input shape: n 10 embed_dim
+         bs, n, dim = x.shape
+         x = x + self.rank_embed
+         x = x.reshape(-1, dim)
+
+         mask = measurements[:, :6]
+         mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2)
+
+         rs = []
+         for i in range(6):
+             res = self.head_fc1_list[i](x)
+             res = self.head_relu(res)
+             res = self.head_fc2_list[i](res)
+             rs.append(res)
+         rs = torch.stack(rs, 1)
+         x = torch.sum(rs * mask, dim=1)
+
+         x = x.view(bs, n, 2)
+         if self.cumsum:
+             x = torch.cumsum(x, 1)
+         return x
+
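The linear head runs six parallel MLP branches (one per driving command) and uses the first six measurement entries as a one-hot mask to select the active branch; with cumsum=True the per-step offsets accumulate into a trajectory. A hedged usage sketch with a hypothetical command index:

# Hedged sketch; the command one-hot is assumed to occupy measurements[:, :6].
head = LinearWaypointsPredictor(input_dim=256, cumsum=True)
feats = torch.randn(2, 10, 256)
meas = torch.zeros(2, 7)
meas[:, 2] = 1.0  # hypothetical: command branch 2 active
print(head(feats, meas).shape)  # torch.Size([2, 10, 2])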
+ class GRUWaypointsPredictor(nn.Module):
+     def __init__(self, input_dim, waypoints=10):
+         super().__init__()
+         # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
+         self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
+         self.encoder = nn.Linear(2, 64)
+         self.decoder = nn.Linear(64, 2)
+         self.waypoints = waypoints
+
+     def forward(self, x, target_point):
+         bs = x.shape[0]
+         z = self.encoder(target_point).unsqueeze(0)
+         output, _ = self.gru(x, z)
+         output = output.reshape(bs * self.waypoints, -1)
+         output = self.decoder(output).reshape(bs, self.waypoints, 2)
+         output = torch.cumsum(output, 1)
+         return output
+
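The GRU head seeds its hidden state with an encoding of the target point, decodes one offset per query token, and integrates them with cumsum so each waypoint extends the previous one. An illustrative sketch:

# Hedged sketch: the hidden state is initialized from the (x, y) goal point.
gru_head = GRUWaypointsPredictor(input_dim=256, waypoints=10)
feats = torch.randn(2, 10, 256)   # (batch, waypoints, embed_dim)
target = torch.randn(2, 2)        # goal point in the ego frame
print(gru_head(feats, target).shape)  # torch.Size([2, 10, 2])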
+ class GRUWaypointsPredictorWithCommand(nn.Module):
+     def __init__(self, input_dim, waypoints=10):
+         super().__init__()
+         # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64)
+         self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)])
+         self.encoder = nn.Linear(2, 64)
+         self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
+         self.waypoints = waypoints
+
+     def forward(self, x, target_point, measurements):
+         bs, n, dim = x.shape
+         mask = measurements[:, :6, None, None]
+         mask = mask.repeat(1, 1, self.waypoints, 2)
+
+         z = self.encoder(target_point).unsqueeze(0)
+         outputs = []
+         for i in range(6):
+             output, _ = self.grus[i](x, z)
+             output = output.reshape(bs * self.waypoints, -1)
+             output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
+             output = torch.cumsum(output, 1)
+             outputs.append(output)
+         outputs = torch.stack(outputs, 1)
+         output = torch.sum(outputs * mask, dim=1)
+         return output
+
+
+ class TransformerDecoder(nn.Module):
+     def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+         super().__init__()
+         self.layers = _get_clones(decoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+         self.return_intermediate = return_intermediate
+
+     def forward(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         output = tgt
+
+         intermediate = []
+
+         for layer in self.layers:
+             output = layer(
+                 output,
+                 memory,
+                 tgt_mask=tgt_mask,
+                 memory_mask=memory_mask,
+                 tgt_key_padding_mask=tgt_key_padding_mask,
+                 memory_key_padding_mask=memory_key_padding_mask,
+                 pos=pos,
+                 query_pos=query_pos,
+             )
+             if self.return_intermediate:
+                 intermediate.append(self.norm(output))
+
+         if self.norm is not None:
+             output = self.norm(output)
+             if self.return_intermediate:
+                 intermediate.pop()
+                 intermediate.append(output)
+
+         if self.return_intermediate:
+             return torch.stack(intermediate)
+
+         return output.unsqueeze(0)
+
+
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         nhead,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation=nn.ReLU,  # a module class, instantiated below (was nn.ReLU(), which would break when called)
+         normalize_before=False,
+     ):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of the feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.activation = activation()
+         self.normalize_before = normalize_before
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward_post(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         q = k = self.with_pos_embed(src, pos)
+         src2 = self.self_attn(
+             q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+         )[0]
+         src = src + self.dropout1(src2)
+         src = self.norm1(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+         src = src + self.dropout2(src2)
+         src = self.norm2(src)
+         return src
+
+     def forward_pre(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         src2 = self.norm1(src)
+         q = k = self.with_pos_embed(src2, pos)
+         src2 = self.self_attn(
+             q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+         )[0]
+         src = src + self.dropout1(src2)
+         src2 = self.norm2(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+         src = src + self.dropout2(src2)
+         return src
+
+     def forward(
+         self,
+         src,
+         src_mask: Optional[Tensor] = None,
+         src_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+     ):
+         if self.normalize_before:
+             return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+         return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+ class TransformerDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         nhead,
+         dim_feedforward=2048,
+         dropout=0.1,
+         activation=nn.ReLU,  # a module class, instantiated below (was nn.ReLU(), which would break when called)
+         normalize_before=False,
+     ):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of the feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         self.activation = activation()
+         self.normalize_before = normalize_before
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward_post(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         q = k = self.with_pos_embed(tgt, query_pos)
+         tgt2 = self.self_attn(
+             q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+         )[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+         tgt2 = self.multihead_attn(
+             query=self.with_pos_embed(tgt, query_pos),
+             key=self.with_pos_embed(memory, pos),
+             value=memory,
+             attn_mask=memory_mask,
+             key_padding_mask=memory_key_padding_mask,
+         )[0]
+         tgt = tgt + self.dropout2(tgt2)
+         tgt = self.norm2(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout3(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt
+
+     def forward_pre(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         tgt2 = self.norm1(tgt)
+         q = k = self.with_pos_embed(tgt2, query_pos)
+         tgt2 = self.self_attn(
+             q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+         )[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt2 = self.norm2(tgt)
+         tgt2 = self.multihead_attn(
+             query=self.with_pos_embed(tgt2, query_pos),
+             key=self.with_pos_embed(memory, pos),
+             value=memory,
+             attn_mask=memory_mask,
+             key_padding_mask=memory_key_padding_mask,
+         )[0]
+         tgt = tgt + self.dropout2(tgt2)
+         tgt2 = self.norm3(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+         tgt = tgt + self.dropout3(tgt2)
+         return tgt
+
+     def forward(
+         self,
+         tgt,
+         memory,
+         tgt_mask: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+         pos: Optional[Tensor] = None,
+         query_pos: Optional[Tensor] = None,
+     ):
+         if self.normalize_before:
+             return self.forward_pre(
+                 tgt,
+                 memory,
+                 tgt_mask,
+                 memory_mask,
+                 tgt_key_padding_mask,
+                 memory_key_padding_mask,
+                 pos,
+                 query_pos,
+             )
+         return self.forward_post(
+             tgt,
+             memory,
+             tgt_mask,
+             memory_mask,
+             tgt_key_padding_mask,
+             memory_key_padding_mask,
+             pos,
+             query_pos,
+         )
+
+
+ class Interfuser(nn.Module):
+     def __init__(
+         self,
+         img_size=224,
+         multi_view_img_size=112,
+         patch_size=8,
+         in_chans=3,
+         embed_dim=768,
+         enc_depth=6,
+         dec_depth=6,
+         dim_feedforward=2048,
+         normalize_before=False,
+         rgb_backbone_name="r26",
+         lidar_backbone_name="r26",
+         num_heads=8,
+         norm_layer=None,
+         dropout=0.1,
+         end2end=False,
+         direct_concat=True,
+         separate_view_attention=False,
+         separate_all_attention=False,
+         act_layer=None,
+         weight_init="",
+         freeze_num=-1,
+         with_lidar=False,
+         with_right_left_sensors=True,
+         with_center_sensor=False,
+         traffic_pred_head_type="det",
+         waypoints_pred_head="heatmap",
+         reverse_pos=True,
+         use_different_backbone=False,
+         use_view_embed=True,
+         use_mmad_pretrain=None,
+     ):
+         super().__init__()
+         self.traffic_pred_head_type = traffic_pred_head_type
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+         act_layer = act_layer or nn.GELU
+
+         self.reverse_pos = reverse_pos
+         self.waypoints_pred_head = waypoints_pred_head
+         self.with_lidar = with_lidar
+         self.with_right_left_sensors = with_right_left_sensors
+         self.with_center_sensor = with_center_sensor
+
+         self.direct_concat = direct_concat
+         self.separate_view_attention = separate_view_attention
+         self.separate_all_attention = separate_all_attention
+         self.end2end = end2end
+         self.use_view_embed = use_view_embed
+
+         if self.direct_concat:
+             in_chans = in_chans * 4
+             self.with_center_sensor = False
+             self.with_right_left_sensors = False
+
+         if self.separate_view_attention:
+             self.attn_mask = build_attn_mask("seperate_view")
+         elif self.separate_all_attention:
+             self.attn_mask = build_attn_mask("seperate_all")
+         else:
+             self.attn_mask = None
+
+         if use_different_backbone:
+             if rgb_backbone_name == "r50":
+                 self.rgb_backbone = resnet50d(
+                     pretrained=True,
+                     in_chans=in_chans,
+                     features_only=True,
+                     out_indices=[4],
+                 )
+             elif rgb_backbone_name == "r26":
+                 self.rgb_backbone = resnet26d(
+                     pretrained=True,
+                     in_chans=in_chans,
+                     features_only=True,
+                     out_indices=[4],
+                 )
+             elif rgb_backbone_name == "r18":
+                 self.rgb_backbone = resnet18d(
+                     pretrained=True,
+                     in_chans=in_chans,
+                     features_only=True,
+                     out_indices=[4],
+                 )
+             if lidar_backbone_name == "r50":
+                 self.lidar_backbone = resnet50d(
+                     pretrained=False,
+                     in_chans=in_chans,
+                     features_only=True,
+                     out_indices=[4],
+                 )
+             elif lidar_backbone_name == "r26":
+                 self.lidar_backbone = resnet26d(
+                     pretrained=False,
+                     in_chans=in_chans,
+                     features_only=True,
+                     out_indices=[4],
+                 )
+             elif lidar_backbone_name == "r18":
+                 self.lidar_backbone = resnet18d(
+                     pretrained=False, in_chans=3, features_only=True, out_indices=[4]
+                 )
+             rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
+             lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
+
+             if use_mmad_pretrain:
+                 params = torch.load(use_mmad_pretrain)["state_dict"]
+                 updated_params = OrderedDict()
+                 for key in params:
+                     if "backbone" in key:
+                         updated_params[key.replace("backbone.", "")] = params[key]
+                 self.rgb_backbone.load_state_dict(updated_params)
+
+             self.rgb_patch_embed = rgb_embed_layer(
+                 img_size=img_size,
+                 patch_size=patch_size,
+                 in_chans=in_chans,
+                 embed_dim=embed_dim,
+             )
+             self.lidar_patch_embed = lidar_embed_layer(
+                 img_size=img_size,
+                 patch_size=patch_size,
+                 in_chans=3,
+                 embed_dim=embed_dim,
+             )
+         else:
+             if rgb_backbone_name == "r50":
+                 self.rgb_backbone = resnet50d(
+                     pretrained=True, in_chans=3, features_only=True, out_indices=[4]
+                 )
+             elif rgb_backbone_name == "r101":
+                 self.rgb_backbone = resnet101d(
+                     pretrained=True, in_chans=3, features_only=True, out_indices=[4]
+                 )
+             elif rgb_backbone_name == "r26":
+                 self.rgb_backbone = resnet26d(
+                     pretrained=True, in_chans=3, features_only=True, out_indices=[4]
+                 )
+             elif rgb_backbone_name == "r18":
+                 self.rgb_backbone = resnet18d(
+                     pretrained=True, in_chans=3, features_only=True, out_indices=[4]
+                 )
+             embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
+
+             self.rgb_patch_embed = embed_layer(
+                 img_size=img_size,
+                 patch_size=patch_size,
+                 in_chans=in_chans,
+                 embed_dim=embed_dim,
+             )
+             self.lidar_patch_embed = embed_layer(
+                 img_size=img_size,
+                 patch_size=patch_size,
+                 in_chans=in_chans,
+                 embed_dim=embed_dim,
+             )
+
+         self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
+         self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
+
+         if self.end2end:
+             self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
+             self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
+         elif self.waypoints_pred_head == "heatmap":
+             self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
+             self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
+         else:
+             self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
+             self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
+
+         if self.end2end:
+             self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
+         elif self.waypoints_pred_head == "heatmap":
+             self.waypoints_generator = MultiPath_Generator(
+                 embed_dim + 32, embed_dim, 10
+             )
+         elif self.waypoints_pred_head == "gru":
+             self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
+         elif self.waypoints_pred_head == "gru-command":
+             self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
+         elif self.waypoints_pred_head == "linear":
+             self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
+         elif self.waypoints_pred_head == "linear-sum":
+             self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
+
+         self.junction_pred_head = nn.Linear(embed_dim, 2)
+         self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
+         self.stop_sign_head = nn.Linear(embed_dim, 2)
+
+         if self.traffic_pred_head_type == "det":
+             self.traffic_pred_head = nn.Sequential(
+                 *[
+                     nn.Linear(embed_dim + 32, 64),
+                     nn.ReLU(),
+                     nn.Linear(64, 7),
+                     nn.Sigmoid(),
+                 ]
+             )
+         elif self.traffic_pred_head_type == "seg":
+             self.traffic_pred_head = nn.Sequential(
+                 *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
+             )
+
+         self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
+
+         encoder_layer = TransformerEncoderLayer(
+             embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
+         )
+         self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
+
+         decoder_layer = TransformerDecoderLayer(
+             embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
+         )
+         decoder_norm = nn.LayerNorm(embed_dim)
+         self.decoder = TransformerDecoder(
+             decoder_layer, dec_depth, decoder_norm, return_intermediate=False
+         )
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.uniform_(self.global_embed)
+         nn.init.uniform_(self.view_embed)
+         nn.init.uniform_(self.query_embed)
+         nn.init.uniform_(self.query_pos_embed)
+
+     def forward_features(
+         self,
+         front_image,
+         left_image,
+         right_image,
+         front_center_image,
+         lidar,
+         measurements,
+     ):
+         features = []
+
+         # Front view processing
+         front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
+         if self.use_view_embed:
+             front_image_token = (
+                 front_image_token
+                 + self.view_embed[:, :, 0:1, :]
+                 + self.position_encoding(front_image_token)
+             )
+         else:
+             front_image_token = front_image_token + self.position_encoding(
+                 front_image_token
+             )
+         front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
+         front_image_token_global = (
+             front_image_token_global
+             + self.view_embed[:, :, 0, :]
+             + self.global_embed[:, :, 0:1]
+         )
+         front_image_token_global = front_image_token_global.permute(2, 0, 1)
+         features.extend([front_image_token, front_image_token_global])
+
+         if self.with_right_left_sensors:
+             # Left view processing
+             left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
+             if self.use_view_embed:
+                 left_image_token = (
+                     left_image_token
+                     + self.view_embed[:, :, 1:2, :]
+                     + self.position_encoding(left_image_token)
+                 )
+             else:
+                 left_image_token = left_image_token + self.position_encoding(
+                     left_image_token
+                 )
+             left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
+             left_image_token_global = (
+                 left_image_token_global
+                 + self.view_embed[:, :, 1, :]
+                 + self.global_embed[:, :, 1:2]
+             )
+             left_image_token_global = left_image_token_global.permute(2, 0, 1)
+
+             # Right view processing
+             right_image_token, right_image_token_global = self.rgb_patch_embed(
+                 right_image
+             )
+             if self.use_view_embed:
+                 right_image_token = (
+                     right_image_token
+                     + self.view_embed[:, :, 2:3, :]
+                     + self.position_encoding(right_image_token)
+                 )
+             else:
+                 right_image_token = right_image_token + self.position_encoding(
+                     right_image_token
+                 )
+             right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
+             right_image_token_global = (
+                 right_image_token_global
+                 + self.view_embed[:, :, 2, :]
+                 + self.global_embed[:, :, 2:3]
+             )
+             right_image_token_global = right_image_token_global.permute(2, 0, 1)
+
+             features.extend(
+                 [
+                     left_image_token,
+                     left_image_token_global,
+                     right_image_token,
+                     right_image_token_global,
+                 ]
+             )
+
+         if self.with_center_sensor:
+             # Front center view processing
+             (
+                 front_center_image_token,
+                 front_center_image_token_global,
+             ) = self.rgb_patch_embed(front_center_image)
+             if self.use_view_embed:
+                 front_center_image_token = (
+                     front_center_image_token
+                     + self.view_embed[:, :, 3:4, :]
+                     + self.position_encoding(front_center_image_token)
+                 )
+             else:
+                 front_center_image_token = (
+                     front_center_image_token
+                     + self.position_encoding(front_center_image_token)
+                 )
+
+             front_center_image_token = front_center_image_token.flatten(2).permute(
+                 2, 0, 1
+             )
+             front_center_image_token_global = (
+                 front_center_image_token_global
+                 + self.view_embed[:, :, 3, :]
+                 + self.global_embed[:, :, 3:4]
+             )
+             front_center_image_token_global = front_center_image_token_global.permute(
+                 2, 0, 1
+             )
+             features.extend([front_center_image_token, front_center_image_token_global])
+
+         if self.with_lidar:
+             lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
+             if self.use_view_embed:
+                 lidar_token = (
+                     lidar_token
+                     + self.view_embed[:, :, 4:5, :]
+                     + self.position_encoding(lidar_token)
+                 )
+             else:
+                 lidar_token = lidar_token + self.position_encoding(lidar_token)
+             lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
+             lidar_token_global = (
+                 lidar_token_global
+                 + self.view_embed[:, :, 4, :]
+                 + self.global_embed[:, :, 4:5]
+             )
+             lidar_token_global = lidar_token_global.permute(2, 0, 1)
+             features.extend([lidar_token, lidar_token_global])
+
+         features = torch.cat(features, 0)
+         return features
+
+     def forward(self, x):
+         front_image = x["rgb"]
+         left_image = x["rgb_left"]
+         right_image = x["rgb_right"]
+         front_center_image = x["rgb_center"]
+         measurements = x["measurements"]
+         target_point = x["target_point"]
+         lidar = x["lidar"]
+
+         if self.direct_concat:
+             img_size = front_image.shape[-1]
+             left_image = torch.nn.functional.interpolate(
+                 left_image, size=(img_size, img_size)
+             )
+             right_image = torch.nn.functional.interpolate(
+                 right_image, size=(img_size, img_size)
+             )
+             front_center_image = torch.nn.functional.interpolate(
+                 front_center_image, size=(img_size, img_size)
+             )
+             front_image = torch.cat(
+                 [front_image, left_image, right_image, front_center_image], dim=1
+             )
+         features = self.forward_features(
+             front_image,
+             left_image,
+             right_image,
+             front_center_image,
+             lidar,
+             measurements,
+         )
+
+         bs = front_image.shape[0]
+
+         if self.end2end:
+             tgt = self.query_pos_embed.repeat(bs, 1, 1)
+         else:
+             tgt = self.position_encoding(
+                 torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
+             )
+             tgt = tgt.flatten(2)
+             tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
+         tgt = tgt.permute(2, 0, 1)
+
+         memory = self.encoder(features, mask=self.attn_mask)
+         hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
+
+         hs = hs.permute(1, 0, 2)  # batch_size, N, C
+         if self.end2end:
+             waypoints = self.waypoints_generator(hs, target_point)
+             return waypoints
+
+         if self.waypoints_pred_head != "heatmap":
+             traffic_feature = hs[:, :400]
+             is_junction_feature = hs[:, 400]
+             traffic_light_state_feature = hs[:, 400]
+             stop_sign_feature = hs[:, 400]
+             waypoints_feature = hs[:, 401:411]
+         else:
+             traffic_feature = hs[:, :400]
+             is_junction_feature = hs[:, 400]
+             traffic_light_state_feature = hs[:, 400]
+             stop_sign_feature = hs[:, 400]
+             waypoints_feature = hs[:, 401:405]
+
+         if self.waypoints_pred_head == "heatmap":
+             waypoints = self.waypoints_generator(waypoints_feature, measurements)
+         elif self.waypoints_pred_head == "gru":
+             waypoints = self.waypoints_generator(waypoints_feature, target_point)
+         elif self.waypoints_pred_head == "gru-command":
+             waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
+         elif self.waypoints_pred_head == "linear":
+             waypoints = self.waypoints_generator(waypoints_feature, measurements)
+         elif self.waypoints_pred_head == "linear-sum":
+             waypoints = self.waypoints_generator(waypoints_feature, measurements)
+
+         is_junction = self.junction_pred_head(is_junction_feature)
+         traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
+         stop_sign = self.stop_sign_head(stop_sign_feature)
+
+         velocity = measurements[:, 6:7].unsqueeze(-1)
+         velocity = velocity.repeat(1, 400, 32)
+         traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
+         traffic = self.traffic_pred_head(traffic_feature_with_vel)
+         return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
+
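In the non-end2end path the decoder emits 400 traffic-map queries followed by shared state queries and waypoint queries, so the forward returns the traffic grid, waypoints, three two-way logits, and the raw traffic feature. A hedged shape sketch; the hyperparameters are illustrative and not the checkpoint's, direct_concat=False keeps the per-view backbones on 3-channel inputs, pretrained ImageNet weights are fetched for the RGB backbone, and _get_clones must be defined (it appears elsewhere in this file):

# Hedged sketch of a forward pass; shapes only.
model = Interfuser(
    embed_dim=256, rgb_backbone_name="r26", lidar_backbone_name="r26",
    waypoints_pred_head="gru", use_different_backbone=True, direct_concat=False,
)
model.eval()  # avoid dropout noise in the shape check
dummy_img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    traffic, waypoints, is_junction, light, stop, feat = model({
        "rgb": dummy_img, "rgb_left": dummy_img, "rgb_right": dummy_img,
        "rgb_center": dummy_img, "lidar": dummy_img,
        "measurements": torch.randn(1, 7), "target_point": torch.randn(1, 2),
    })
# traffic: (1, 400, 7) grid detections; waypoints: (1, 10, 2);
# is_junction/light/stop: (1, 2) logits; feat: (1, 400, 256)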
+ # ================== 3. Tracking classes ==================
+ class TrackedObject:
+     def __init__(self):
+         self.last_step = 0
+         self.last_pos = [0, 0]
+         # Using a deque is more efficient and caps the memory footprint.
+         self.historical_pos = deque(maxlen=10)
+         self.historical_steps = deque(maxlen=10)
+         self.historical_features = deque(maxlen=10)
+
+     # This is the missing method that had to be added.
+     def update(self, step, object_info):
+         """
+         Update the object's state with the new data from the current frame.
+         """
+         self.last_step = step
+         self.last_pos = object_info[:2]
+
+         # Append the new data to the history.
+         self.historical_pos.append(self.last_pos)
+         self.historical_steps.append(step)
+
+         # Only append extra features when they are present.
+         if len(object_info) > 2:
+             self.historical_features.append(object_info[2])
+
+ class Tracker:
+     def __init__(self, frequency=10):
+         self.tracks = []
+         self.alive_ids = []
+         self.frequency = frequency
+
+     def update_and_predict(self, det_data, pos, theta, frame_num):
+         det_data_weighted = det_data * reweight_array
+         detected_objects = find_peak_box(det_data_weighted)
+         objects_info = []
+         R = np.array([[np.cos(-theta), -np.sin(-theta)], [np.sin(-theta), np.cos(-theta)]])
+
+         for obj in detected_objects:
+             i, j = obj['coords']
+             obj_data = obj['raw_data']
+
+             center_y, center_x = convert_grid_to_xy(i, j)
+             center_x += obj_data[1]
+             center_y += obj_data[2]
+
+             loc = R.T.dot(np.array([center_x, center_y]))
+             objects_info.append([loc[0] + pos[0], loc[1] + pos[1], obj_data[1:]])  # [x, y, features...]
+
+         updates_ids = self._update(objects_info, frame_num)
+         speed_results, heading_results = self._predict(updates_ids)
+
+         # Write the smoothed values back into the grid cells where the peaks were
+         # found (updates_ids are track indices, not grid coordinates).
+         grid_coords = [obj['coords'] for obj in detected_objects]
+         for k, (i, j) in enumerate(grid_coords):
+             if heading_results[k] is not None:
+                 factor = MERGE_PERCENT * 0.1
+                 det_data[i, j, 3] = heading_results[k] * factor + det_data[i, j, 3] * (1 - factor)
+             if speed_results[k] is not None:
+                 factor = MERGE_PERCENT * 0.1
+                 det_data[i, j, 6] = speed_results[k] * factor + det_data[i, j, 6] * (1 - factor)
+         return det_data
+
+     def _update(self, objects_info, step):
+         latest_ids = []
+         if len(self.tracks) == 0:
+             for object_info in objects_info:
+                 to = TrackedObject()
+                 to.update(step, object_info)
+                 self.tracks.append(to)
+                 latest_ids.append(len(self.tracks) - 1)
+         else:
+             matched_ids = set()
+             for idx, object_info in enumerate(objects_info):
+                 min_id, min_error = -1, float('inf')
+                 pos_x, pos_y = object_info[:2]
+                 for _id in self.alive_ids:
+                     if _id in matched_ids:
+                         continue
+                     track_pos = self.tracks[_id].last_pos
+                     distance = np.sqrt((track_pos[0] - pos_x)**2 + (track_pos[1] - pos_y)**2)
+                     if distance < 2.0 and distance < min_error:
+                         min_error = distance
+                         min_id = _id
+                 if min_id != -1:
+                     self.tracks[min_id].update(step, objects_info[idx])
+                     latest_ids.append(min_id)
+                     matched_ids.add(min_id)
+                 else:
+                     to = TrackedObject()
+                     to.update(step, object_info)
+                     self.tracks.append(to)
+                     latest_ids.append(len(self.tracks) - 1)
+         self.alive_ids = [i for i, track in enumerate(self.tracks) if track.last_step > step - 6]
+         return latest_ids
+
+     def _match(self, objects_info):
+         results = []
+         matched_ids = set()
+         for object_info in objects_info:
+             min_id, min_error = -1, float('inf')
+             pos_x, pos_y = object_info[:2]
+             for _id in self.alive_ids:
+                 if _id in matched_ids:
+                     continue
+                 track_pos = self.tracks[_id].last_pos
+                 distance = np.sqrt((track_pos[0] - pos_x)**2 + (track_pos[1] - pos_y)**2)
+                 if distance < min_error:
+                     min_error = distance
+                     min_id = _id
+             results.append(min_id)
+             if min_id != -1:
+                 matched_ids.add(min_id)
+         return results
+
+     def _predict(self, updates_ids):
+         speed_results, heading_results = [], []
+         for each_id in updates_ids:
+             to = self.tracks[each_id]
+             avg_speed, avg_heading = [], []
+             for feature in to.historical_features:
+                 avg_speed.append(feature[2])
+                 avg_heading.append(feature[:2])
+             if len(avg_speed) < 2:
+                 speed_results.append(None)
+                 heading_results.append(None)
+                 continue
+             avg_speed = np.mean(avg_speed)
+             avg_heading = np.mean(np.stack(avg_heading), axis=0)
+             yaw_angle = get_yaw_angle(avg_heading)
+             heading_results.append((4 - yaw_angle / np.pi) % 2)
+             speed_results.append(avg_speed)
+         return speed_results, heading_results
+
+
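Each TrackedObject keeps a bounded ten-frame history, and Tracker._update greedily matches new detections to live tracks by nearest neighbor within 2.0 m, spawning a new track otherwise. A small self-contained sketch of the history bookkeeping (the helpers used by update_and_predict, such as find_peak_box and reweight_array, live elsewhere in this file):

# Hedged sketch: object_info is [x, y, features...] as built in update_and_predict.
to = TrackedObject()
to.update(step=0, object_info=[1.0, 2.0, np.array([0.5, 0.5, 3.0])])
to.update(step=1, object_info=[1.2, 2.1, np.array([0.6, 0.4, 3.2])])
print(to.last_pos, len(to.historical_pos))  # [1.2, 2.1] 2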
+ # ================== 0. PID controller definition ==================
+ class PIDController:
+     def __init__(self, K_P=1.0, K_I=0.0, K_D=0.0, n=20):
+         self._K_P = K_P
+         self._K_I = K_I
+         self._K_D = K_D
+         self._window = deque([0 for _ in range(n)], maxlen=n)
+         self._max = 0.0
+         self._min = 0.0
+
+     def step(self, error):
+         self._window.append(error)
+         self._max = max(self._max, abs(error))
+         self._min = -abs(self._max)
+
+         if len(self._window) >= 2:
+             integral = np.mean(self._window)
+             derivative = self._window[-1] - self._window[-2]
+         else:
+             integral = 0.0
+             derivative = 0.0
+
+         return self._K_P * error + self._K_I * integral + self._K_D * derivative
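The integral term is the mean of a fixed-size error window and the derivative is the last first difference, which bounds integral wind-up by construction. A minimal sketch:

# Hedged sketch: a longitudinal PID stepping on the speed error.
pid = PIDController(K_P=0.5, K_I=0.05, K_D=0.1, n=20)
for current in (0.0, 1.0, 2.5):
    print(round(pid.step(3.0 - current), 3))  # control output per step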
1209
+ # ================== 4. فئة المتحكم ==================
1210
+ class InterfuserController:
1211
+ def __init__(self, config):
1212
+ self.turn_controller = PIDController(
1213
+ K_P=config.turn_KP,
1214
+ K_I=config.turn_KI,
1215
+ K_D=config.turn_KD,
1216
+ n=config.turn_n,
1217
+ )
1218
+ self.speed_controller = PIDController(
1219
+ K_P=config.speed_KP,
1220
+ K_I=config.speed_KI,
1221
+ K_D=config.speed_KD,
1222
+ n=config.speed_n,
1223
+ )
1224
+ self.config = config
1225
+ self.collision_buffer = np.array(config.collision_buffer)
1226
+ self.detect_threshold = config.detect_threshold
1227
+ self.stop_steps = 0
1228
+ self.forced_forward_steps = 0
1229
+ self.red_light_steps = 0
1230
+ self.block_red_light = 0
1231
+ self.in_stop_sign_effect = False
1232
+ self.block_stop_sign_distance = 0
1233
+ self.stop_sign_timer = 0
1234
+ self.stop_sign_trigger_times = 0
1235
+
+    def run_step(
+        self, speed, waypoints, junction, traffic_light_state, stop_sign, meta_data
+    ):
+        # --- Update the stop-related state ---
+        if speed < 0.2:
+            self.stop_steps += 1
+        else:
+            self.stop_steps = max(0, self.stop_steps - 10)
+
+        if speed < 0.06 and self.in_stop_sign_effect:
+            self.in_stop_sign_effect = False
+
+        if junction < 0.3:
+            self.stop_sign_trigger_times = 0
+
+        if traffic_light_state > 0.7:
+            self.red_light_steps += 1
+        else:
+            self.red_light_steps = 0
+
+        # After waiting at a red light for too long, ignore it briefly
+        if self.red_light_steps > 1000:
+            self.block_red_light = 80
+            self.red_light_steps = 0
+
+        if self.block_red_light > 0:
+            self.block_red_light -= 1
+            traffic_light_state = 0.01
+
+        # stop_sign appears inversely encoded here (meta_info_2 reports 1 - stop_sign)
+        if stop_sign < 0.6 and self.block_stop_sign_distance < 0.1:
+            self.in_stop_sign_effect = True
+            self.block_stop_sign_distance = 2.0
+            self.stop_sign_trigger_times = 3
+
+        self.block_stop_sign_distance = max(
+            0, self.block_stop_sign_distance - 0.05 * speed
+        )
+
+        if self.block_stop_sign_distance < 0.1:
+            if self.stop_sign_trigger_times > 0:
+                self.block_stop_sign_distance = 2.0
+                self.stop_sign_trigger_times -= 1
+                self.in_stop_sign_effect = True
+
+        # --- Compute the steering angle ---
+        aim = (waypoints[1] + waypoints[0]) / 2.0
+        angle = np.degrees(np.pi / 2 - np.arctan2(aim[1], aim[0])) / 90
+        if speed < 0.01:
+            angle = 0
+        steer = self.turn_controller.step(angle)
+        steer = np.clip(steer, -1.0, 1.0)
+
+        brake = False
+        throttle = 0.0
+        desired_speed = 0.0
+
+        downsampled_waypoints = downsample_waypoints(waypoints)
+
+        # Maximum safe distance along the path at several time horizons
+        d_0, d_05, d_075, d_1, d_15, d_2 = (
+            get_max_safe_distance(
+                meta_data,
+                downsampled_waypoints,
+                t=t,
+                collision_buffer=self.collision_buffer,
+                threshold=self.detect_threshold,
+            )
+            for t in (0, 0.5, 0.75, 1, 1.5, 2)
+        )
+
+        d_05 = min(d_0, d_05, d_075)
+        d_1 = min(d_05, d_075, d_15, d_2)
+
+        safe_dis = min(d_05, d_1)
+        d_0 = max(0, d_0 - 2.0)
+        d_05 = max(0, d_05 - 2.0)
+        d_1 = max(0, d_1 - 2.0)
+
+        # --- Brake only for a red light or a stop sign ---
+        if traffic_light_state > 0.5:
+            brake = True
+            desired_speed = 0.0
+        elif stop_sign > 0.6 and traffic_light_state <= 0.5:
+            if self.stop_sign_timer < 20:
+                brake = True
+                desired_speed = 0.0
+                self.stop_sign_timer += 1
+            else:
+                brake = False
+                desired_speed = max(0, min(self.config.max_speed, speed + 0.2))
+        else:
+            brake = False
+            desired_speed = max(0, min(self.config.max_speed, speed + 0.2))
+
+        delta = np.clip(desired_speed - speed, 0.0, self.config.clip_delta)
+        throttle = self.speed_controller.step(delta)
+        throttle = np.clip(throttle, 0.0, self.config.max_throttle)
+
+        # --- Brake when speed exceeds brake_ratio (1.1x) of the target speed ---
+        if speed > desired_speed * self.config.brake_ratio:
+            brake = True
+
+        # --- Assemble diagnostic info ---
+        meta_info_1 = f"speed: {speed:.2f}, target_speed: {desired_speed:.2f}"
+        meta_info_2 = f"on_road_prob: {junction:.2f}, red_light_prob: {traffic_light_state:.2f}, stop_sign_prob: {1 - stop_sign:.2f}"
+        meta_info_3 = f"stop_steps: {self.stop_steps}, block_stop_sign_distance: {self.block_stop_sign_distance:.1f}"
+
+        # --- Recovery behavior after being stopped for a long time ---
+        if self.stop_steps > 1200:
+            self.forced_forward_steps = 12
+            self.stop_steps = 0
+        if self.forced_forward_steps > 0:
+            throttle = 0.8
+            brake = False
+            self.forced_forward_steps -= 1
+        if self.in_stop_sign_effect:
+            throttle = 0
+            brake = True
+
+        return steer, throttle, brake, (meta_info_1, meta_info_2, meta_info_3, safe_dis)
+
+
+class ControllerConfig:
+    turn_KP, turn_KI, turn_KD, turn_n = 1.0, 0.1, 0.1, 20
+    speed_KP, speed_KI, speed_KD, speed_n = 0.5, 0.05, 0.1, 20
+    max_speed, max_throttle, clip_delta = 6.0, 0.75, 0.25
+    collision_buffer, detect_threshold = [0.0, 0.0], 0.04
+    brake_speed, brake_ratio = 0.4, 1.1
+
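+# Minimal usage sketch (assumes `waypoints_np` and `traffic_grid` come from a
+# model forward pass, as in run_single_frame below; values are placeholders):
+#
+#     controller = InterfuserController(ControllerConfig())
+#     steer, throttle, brake, meta = controller.run_step(
+#         speed=3.0, waypoints=waypoints_np, junction=0.1,
+#         traffic_light_state=0.05, stop_sign=0.9, meta_data=traffic_grid,
+#     )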
1396
+ # ================== 5. واجهة العرض ==================
1397
+ class DisplayInterface:
1398
+ def __init__(self, width=1200, height=600):
1399
+ self._width = width
1400
+ self._height = height
1401
+
1402
+ def run_interface(self, data):
1403
+ dashboard = np.zeros((self._height, self._width, 3), dtype=np.uint8)
1404
+ font = cv2.FONT_HERSHEY_SIMPLEX
1405
+ dashboard[:, :800] = cv2.resize(data.get('camera_view'), (800, 600))
1406
+ dashboard[:400, 800:1200] = cv2.resize(data['map_t0'], (400, 400))
1407
+ dashboard[400:600, 800:1000] = cv2.resize(data['map_t1'], (200, 200))
1408
+ dashboard[400:600, 1000:1200] = cv2.resize(data['map_t2'], (200, 200))
1409
+
1410
+ # خطوط فصل
1411
+ cv2.line(dashboard, (800, 0), (800, 600), (255, 255, 255), 2)
1412
+ cv2.line(dashboard, (800, 400), (1200, 400), (255, 255, 255), 2)
1413
+ cv2.line(dashboard, (1000, 400), (1000, 600), (255, 255, 255), 2)
1414
+
1415
+ y_pos = 40
1416
+ for key, text in data['text_info'].items():
1417
+ cv2.putText(dashboard, text, (820, y_pos), font, 0.6, (255, 255, 255), 1)
1418
+ y_pos += 30
1419
+
1420
+ y_pos += 10
1421
+ for t, counts in data['object_counts'].items():
1422
+ count_str = f"{t}: C={counts['car']} B={counts['bike']} P={counts['pedestrian']}"
1423
+ cv2.putText(dashboard, count_str, (820, y_pos), font, 0.5, (255, 255, 255), 1)
1424
+ y_pos += 20
1425
+
1426
+ cv2.putText(dashboard, "t0", (1160, 30), font, 0.8, (0, 255, 255), 2)
1427
+ cv2.putText(dashboard, "t1", (960, 430), font, 0.8, (0, 255, 255), 2)
1428
+ cv2.putText(dashboard, "t2", (1160, 430), font, 0.8, (0, 255, 255), 2)
1429
+
1430
+ return dashboard
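+
+# Shape of the `data` dict consumed by run_interface (sketch; keys as used above,
+# sizes are what the panels are resized to):
+#
+#     data = {
+#         'camera_view': <HxWx3 uint8 frame>,
+#         'map_t0': <BEV map>, 'map_t1': <BEV map>, 'map_t2': <BEV map>,
+#         'text_info': {'Control': 'S:0.00 T:0.50 B:0', ...},
+#         'object_counts': {'t0': {'car': 1, 'bike': 0, 'pedestrian': 0, 'unknown': 0}, ...},
+#     }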
+
+# --- Define the input transforms ---
+transform = transforms.Compose([
+    # The first step is now resizing directly
     transforms.Resize((224, 224)),
     transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])
 
+lidar_transform = transforms.Compose([
+    # The first step is now resizing directly
+    transforms.Resize((112, 112)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.5], std=[0.5]),
 ])
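+
+# Usage sketch: both pipelines expect a PIL.Image, so numpy frames must be
+# wrapped with Image.fromarray first, e.g.:
+#
+#     rgb_tensor = transform(Image.fromarray(front_image))        # -> (3, 224, 224)
+#     bev_tensor = lidar_transform(Image.fromarray(lidar_image))  # -> (3, 112, 112)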
+
+class LMDriveDataset(Dataset):
+    def __init__(self, data_dir, transform=None, lidar_transform=None):
+        self.data_dir = Path(data_dir)
+        self.transform = transform
+        self.lidar_transform = lidar_transform
+        self.samples = []
+
+        measurement_dir = self.data_dir / "measurements"
+        image_dir = self.data_dir / "rgb_full"
+
+        measurement_files = sorted([f for f in os.listdir(measurement_dir) if f.endswith(".json")])
+        image_files = sorted([f for f in os.listdir(image_dir) if f.endswith(".jpg")])
+
+        num_samples = min(len(measurement_files), len(image_files))
+
+        for i in range(num_samples):
+            frame_id = i
+            measurement_path = str(measurement_dir / f"{frame_id:04d}.json")
+            image_path = str(image_dir / f"{frame_id:04d}.jpg")
+
+            if not os.path.exists(measurement_path) or not os.path.exists(image_path):
+                continue
+
+            with open(measurement_path, "r") as f:
+                measurements_data = json.load(f)
+
+            self.samples.append({
+                "image_path": image_path,
+                "measurement_path": measurement_path,
+                "frame_id": frame_id,
+                "measurements": measurements_data,
+            })
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        sample = self.samples[idx]
+
+        # Read the full stacked image (2400x800: four 600x800 views)
+        full_image = cv2.imread(sample["image_path"])
+        if full_image is None:
+            raise ValueError(f"Failed to load image: {sample['image_path']}")
+        full_image = cv2.cvtColor(full_image, cv2.COLOR_BGR2RGB)
+
+        # Split the image into its parts (each 600x800)
+        front_image = full_image[:600, :800]        # first part
+        left_image = full_image[600:1200, :800]     # second part
+        right_image = full_image[1200:1800, :800]   # third part
+        center_image = full_image[1800:2400, :800]  # fourth part
+
+        # Apply the transform to each view (torchvision transforms expect PIL images)
+        front_image_tensor = self.transform(Image.fromarray(front_image))
+        left_image_tensor = self.transform(Image.fromarray(left_image))
+        right_image_tensor = self.transform(Image.fromarray(right_image))
+        center_image_tensor = self.transform(Image.fromarray(center_image))
+
+        # Load the LiDAR image
+        lidar_path = str(self.data_dir / "lidar" / f"{sample['frame_id']:04d}.png")
+        lidar = cv2.imread(lidar_path)
+
+        if lidar is None:
+            lidar = np.zeros((112, 112, 3), dtype=np.uint8)  # empty placeholder
+        else:
+            if len(lidar.shape) == 2:
+                lidar = cv2.cvtColor(lidar, cv2.COLOR_GRAY2BGR)
+            lidar = cv2.cvtColor(lidar, cv2.COLOR_BGR2RGB)
+
+        lidar_tensor = self.lidar_transform(Image.fromarray(lidar))
+
+        # Extract the measurements
+        measurements_data = sample["measurements"]
+
+        x = measurements_data.get("x", 0.0)
+        y = measurements_data.get("y", 0.0)
+        theta = measurements_data.get("theta", 0.0)
+        speed = measurements_data.get("speed", 0.0)
+        steer = measurements_data.get("steer", 0.0)
+        throttle = measurements_data.get("throttle", 0.0)
+        brake = int(measurements_data.get("brake", False))
+        command = measurements_data.get("command", 0)
+        is_junction = int(measurements_data.get("is_junction", False))
+        should_brake = int(measurements_data.get("should_brake", 0))
+        x_command = measurements_data.get("x_command", 0.0)
+        y_command = measurements_data.get("y_command", 0.0)
+
+        target_point = torch.tensor([x_command, y_command], dtype=torch.float32)
+
+        measurements = torch.tensor(
+            [x, y, theta, speed, steer, throttle, brake, command, is_junction, should_brake],
+            dtype=torch.float32,
+        )
+
+        return {
+            "rgb": front_image_tensor,
+            "rgb_left": left_image_tensor,
+            "rgb_right": right_image_tensor,
+            "rgb_center": center_image_tensor,
+            "lidar": lidar_tensor,
+            "measurements": measurements,
+            "target_point": target_point,
+        }
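+
+# Usage sketch (hypothetical local path; this app itself runs single-frame
+# inference and never instantiates the dataset):
+#
+#     from torch.utils.data import DataLoader
+#     dataset = LMDriveDataset("data/routes_town01", transform, lidar_transform)
+#     loader = DataLoader(dataset, batch_size=4, shuffle=False)
+#     batch = next(iter(loader))  # batch["rgb"].shape == (4, 3, 224, 224)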
 
+SAVE_VIDEO = True
+FPS = 10
+WAYPOINT_SCALE_FACTOR = 5.0
+T1_FUTURE_TIME = 1.0
+T2_FUTURE_TIME = 2.0
+TRACKER_FREQUENCY = 10
+MERGE_PERCENT = 0.4
+PIXELS_PER_METER = 8
+MAX_DISTANCE = 32
+IMG_SIZE = MAX_DISTANCE * PIXELS_PER_METER * 2
+EGO_CAR_X = IMG_SIZE // 2
+EGO_CAR_Y = IMG_SIZE - (4.0 * PIXELS_PER_METER)
+reweight_array = np.ones((20, 20, 7))
+last_valid_waypoints = None
+last_valid_theta = 0.0
+
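+# With the values above, the BEV canvas is IMG_SIZE = 32 * 8 * 2 = 512 px,
+# covering 64 m at 8 px/m; the ego car is drawn at (EGO_CAR_X, EGO_CAR_Y) =
+# (256, 480.0), i.e. horizontally centered and 4 m (32 px) above the bottom edge.
+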
+def to_2tuple(x):
+    if isinstance(x, tuple):
+        return x
+    return (x, x)
+
+
+def _get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+def build_attn_mask(mask_type):
+    mask = torch.ones((151, 151), dtype=torch.bool).cuda()
+    if mask_type == "seperate_all":
+        mask[:50, :50] = False
+        mask[50:67, 50:67] = False
+        mask[67:84, 67:84] = False
+        mask[84:101, 84:101] = False
+        mask[101:151, 101:151] = False
+    elif mask_type == "seperate_view":
+        mask[:50, :50] = False
+        mask[50:67, 50:67] = False
+        mask[67:84, 67:84] = False
+        mask[84:101, 84:101] = False
+        mask[101:151, :] = False
+        mask[:, 101:151] = False
+    return mask
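+
+# Note on the masks: the 151 tokens are split into blocks of 50 / 17 / 17 / 17 / 50
+# (what each block corresponds to is fixed by the Interfuser token ordering
+# defined elsewhere in this file). Following the PyTorch attn_mask convention,
+# True entries block attention, so "seperate_all" restricts every block to
+# attending within itself only.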
+
+
+def get_yaw_angle(forward_vector):
+    forward_vector = forward_vector / np.linalg.norm(forward_vector)
+    yaw = math.atan2(forward_vector[1], forward_vector[0])
+    return yaw
+
+
+@register_model
+def interfuser_baseline(**kwargs):
+    model = Interfuser(
+        enc_depth=6,
+        dec_depth=6,
+        embed_dim=256,
+        rgb_backbone_name="r50",
+        lidar_backbone_name="r18",
+        waypoints_pred_head="gru",
+        use_different_backbone=True,
+    )
+    # model.save_pretrained("/content/t")
+    return model
+
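+# Sketch (illustrative): the baseline can be instantiated directly to inspect
+# its size; the app itself loads trained weights into it further below.
+#
+#     m = interfuser_baseline()
+#     n_params = sum(p.numel() for p in m.parameters())
+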
+def ensure_rgb(image):
+    """Convert a grayscale image to 3 channels."""
+    if len(image.shape) == 2 or image.shape[2] == 1:
+        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+    return image
+
+
+def process_camera_image(tensor_image):
+    """Undo the ImageNet normalization and convert a camera tensor to a BGR array."""
+    image_np = tensor_image.permute(1, 2, 0).cpu().numpy()
+    image_np = (image_np * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])
+    image_np = np.clip(image_np, 0, 1)
+    return (image_np * 255).astype(np.uint8)[:, :, ::-1]  # RGB -> BGR
+
+def convert_grid_to_xy(i, j):
+    """Convert a 20x20 grid cell (i, j) to metric x, y coordinates."""
+    return (j - 9.5) * 1.6, (19.5 - i) * 1.6
+
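+# Worked example: each cell of the 20x20 grid spans 1.6 m, so cell (0, 0) maps
+# to x = (0 - 9.5) * 1.6 = -15.2 m, y = (19.5 - 0) * 1.6 = 31.2 m (far front-left),
+# while the bottom row (i = 19) lies just ahead of the ego vehicle at y = 0.8 m.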
+
+def add_rect(img, loc, ori, box, value, color):
     """
+    Add a rotated rectangle to the BEV map.
     """
+    center_x = int(loc[0] * PIXELS_PER_METER + MAX_DISTANCE * PIXELS_PER_METER)
+    center_y = int(loc[1] * PIXELS_PER_METER + MAX_DISTANCE * PIXELS_PER_METER)
+
+    size_px = (
+        int(box[0] * PIXELS_PER_METER),
+        int(box[1] * PIXELS_PER_METER)
     )
+
+    angle_deg = -np.degrees(math.atan2(ori[1], ori[0]))
+
+    box_points = cv2.boxPoints(((center_x, center_y), size_px, angle_deg))
+    box_points = np.int32(box_points)
+
+    # Scale the color by the detection confidence
+    adjusted_color = [int(x * value) for x in color]
+    cv2.fillConvexPoly(img, box_points, adjusted_color)
+    return img
+
+
+def find_peak_box(data):
+    """
+    Detect local confidence peaks in the 20x20 grid and classify them.
+    """
+    det_data = np.zeros((22, 22, 7))
+    det_data[1:21, 1:21] = data
+    detected_objects = []
+
+    for i in range(1, 21):
+        for j in range(1, 21):
+            if det_data[i, j, 0] > 0.6 and (
+                det_data[i, j, 0] > det_data[i, j - 1, 0]
+                and det_data[i, j, 0] > det_data[i, j + 1, 0]
+                and det_data[i, j, 0] > det_data[i - 1, j, 0]
+                and det_data[i, j, 0] > det_data[i + 1, j, 0]
+            ):
+                length = det_data[i, j, 4]
+                width = det_data[i, j, 5]
+                confidence = det_data[i, j, 0]
+
+                # Crude size-based classification
+                if length > 4.0:
+                    obj_type = 'car'
+                elif length / width > 1.5:
+                    obj_type = 'bike'
+                else:
+                    obj_type = 'pedestrian'
+
+                detected_objects.append({
+                    'coords': (i - 1, j - 1),
+                    'type': obj_type,
+                    'confidence': confidence,
+                    'raw_data': det_data[i, j],
+                })
+
+    return detected_objects
+
+
+def render(det_data, t=0):
+    """
+    Draw the detected objects on the BEV map, extrapolated t seconds ahead.
+    """
+    CLASS_COLORS = {'car': (0, 0, 255), 'bike': (0, 255, 0), 'pedestrian': (255, 0, 0), 'unknown': (128, 128, 128)}
+    det_weighted = det_data * reweight_array
+    detected_objects = find_peak_box(det_weighted)
+    counts = {cls: 0 for cls in CLASS_COLORS}
+    for obj in detected_objects:
+        counts[obj['type']] = counts.get(obj['type'], 0) + 1
+    img = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
+
+    for obj in detected_objects:
+        i, j = obj['coords']
+        obj_data = obj['raw_data']
+        speed = obj_data[6]
+        center_x, center_y = convert_grid_to_xy(i, j)
+        theta = obj_data[3] * np.pi
+        ori = np.array([math.cos(theta), math.sin(theta)])
+        # Extrapolate the object position t seconds into the future
+        loc_x = center_x + obj_data[1] + t * speed * ori[0]
+        loc_y = center_y + obj_data[2] - t * speed * ori[1]
+        box = np.array([obj_data[4], obj_data[5]])
+        if obj['type'] == 'pedestrian':
+            box *= 1.5
+        add_rect(
+            img,
+            loc=np.array([loc_x, loc_y]),
+            ori=ori,
+            box=box,
+            value=obj['confidence'],
+            color=CLASS_COLORS[obj['type']]
+        )
+    return img, counts
+
+
+def render_self_car(loc, ori, box, pixels_per_meter=PIXELS_PER_METER):
+    """
+    Draw the ego vehicle on the BEV map.
+    Args:
+        loc: vehicle position [x, y] in world coordinates.
+        ori: vehicle orientation [cos(theta), sin(theta)].
+        box: vehicle dimensions [length, width].
+        pixels_per_meter: pixels per meter.
+    Returns:
+        self_car_map: the ego-vehicle map (RGB, 3 channels).
+    """
+    img = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
+    center_x = int(loc[0] * pixels_per_meter + MAX_DISTANCE * pixels_per_meter)
+    center_y = int(loc[1] * pixels_per_meter + MAX_DISTANCE * pixels_per_meter)
+    size_px = (
+        int(box[0] * pixels_per_meter),
+        int(box[1] * pixels_per_meter)
+    )
+    angle_deg = -np.degrees(math.atan2(ori[1], ori[0]))
+    box_points = cv2.boxPoints(((center_x, center_y), size_px, angle_deg))
+    box_points = np.int32(box_points)
+    ego_color = (0, 255, 255)  # yellow
+    cv2.fillConvexPoly(img, box_points, ego_color)
+    return img  # return the full map, not a crop of it
+
+def render_waypoints(waypoints, pixels_per_meter=PIXELS_PER_METER):
+    """Draw the predicted waypoints, falling back to the last valid set."""
+    global last_valid_waypoints
+    img = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
+    current_waypoints = waypoints
+    if waypoints is not None and len(waypoints) > 2:
+        last_valid_waypoints = waypoints
+    else:
+        current_waypoints = last_valid_waypoints
+    if current_waypoints is None:
+        return img
+    origin_x, origin_y = EGO_CAR_X, EGO_CAR_Y
+    for i, point in enumerate(current_waypoints):
+        px = int(origin_x + point[1] * pixels_per_meter)
+        py = int(origin_y - point[0] * pixels_per_meter)
+        # Last waypoint in red, the rest in green
+        color = (0, 0, 255) if i == len(current_waypoints) - 1 else (0, 255, 0)
+        cv2.circle(img, (px, py), 4, color, -1)
+    return img
+
+
+def collision_detections(map1, map2, threshold=0.04):
+    """
+    Check the overlap between the environment map and the ego-vehicle footprint.
+    Returns True when the overlap ratio is below the threshold (i.e. safe).
+    """
+    # Convert map2 to grayscale if it has 3 (RGB) channels
+    if len(map2.shape) == 3 and map2.shape[2] == 3:
+        map2 = cv2.cvtColor(map2, cv2.COLOR_BGR2GRAY)
+
+    # map1 and map2 must have the same dimensions
+    assert map1.shape == map2.shape
+
+    overlap_map = (map1 > 0.01) & (map2 > 0.01)
+    ratio = float(np.sum(overlap_map)) / np.sum(map2 > 0)
+    return ratio < threshold
+
+
+def get_max_safe_distance(meta_data, downsampled_waypoints, t, collision_buffer, threshold):
+    """
+    Compute the maximum collision-free distance along the planned path.
+    """
+    surround_map = meta_data.reshape(20, 20, 7)[..., 0]
+    if np.sum(surround_map) < 1:
+        return np.linalg.norm(downsampled_waypoints[-3])
+    hero_bounding_box = np.array([2.45, 1.0]) + collision_buffer
+    safe_distance = 0.0
+    for i in range(len(downsampled_waypoints) - 2):
+        aim = (downsampled_waypoints[i + 1] + downsampled_waypoints[i + 2]) / 2.0
+        loc = downsampled_waypoints[i]
+        ori = aim - loc
+        self_car_map = render_self_car(loc=loc, ori=ori, box=hero_bounding_box, pixels_per_meter=PIXELS_PER_METER)
+        # Downscale the footprint map to the 20x20 grid and convert to grayscale
+        self_car_map_resized = cv2.resize(self_car_map, (20, 20))
+        self_car_map_gray = cv2.cvtColor(self_car_map_resized, cv2.COLOR_BGR2GRAY)
+        if not collision_detections(surround_map, self_car_map_gray, threshold):
+            break
+        safe_distance = max(safe_distance, np.linalg.norm(loc))
+    return safe_distance
+
+
+def downsample_waypoints(waypoints, precision=0.2):
+    """
+    Resample the waypoint list so consecutive points are spaced by `precision`.
+    """
+    downsampled_waypoints = []
+    last_waypoint = np.array([0.0, 0.0])
+    for now_waypoint in waypoints:
+        dis = np.linalg.norm(now_waypoint - last_waypoint)
+        if dis > precision:
+            interval = int(dis / precision)
+            move_vector = (now_waypoint - last_waypoint) / (interval + 1)
+            for j in range(interval):
+                downsampled_waypoints.append(last_waypoint + move_vector * (j + 1))
+            downsampled_waypoints.append(now_waypoint)
+            last_waypoint = now_waypoint
+    return downsampled_waypoints
+
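+# Worked example: two waypoints 1.0 m apart with precision=0.2 give
+# interval = 5, so five interpolated points are inserted at 1/6 m spacing
+# before the endpoint itself; points closer than 0.2 m are skipped entirely.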
 
+# ==============================================================================
+# Gradio Application Logic
+# ==============================================================================
+
+# --- Load the model (done once, globally) ---
+print("Loading the Interfuser model...")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = interfuser_baseline()  # registered above via @register_model
+# The checkpoint must sit next to app.py in the Hugging Face Space
+model_path = "interfuser_best_model.pth"
+if not os.path.exists(model_path):
+    raise FileNotFoundError(f"Model file not found at {model_path}. Please upload it to the Space.")
+
+state_dict = torch.load(model_path, map_location=device, weights_only=True)
+model.load_state_dict(state_dict)
+model.to(device)
+model.eval()
+print("Model loaded successfully.")
+
+def run_single_frame(
+    rgb_image_path,
+    rgb_left_image_path,
+    rgb_right_image_path,
+    rgb_center_image_path,
+    lidar_image_path,
+    measurements_path,
+    target_point_list,
+):
+    """
+    Process a single frame of data, build the full visual dashboard,
+    and return both the image and the structured results.
+    """
+    try:
+        # ==========================================================
+        # 1. Read and preprocess the inputs from the uploaded files
+        # ==========================================================
+        if not rgb_image_path:
+            raise gr.Error("Please provide the front (RGB) image file.")
+        # --- a. Read the images ---
+        # Fix: the inputs are Gradio file objects, so read them via .name
+        rgb_image_pil = Image.open(rgb_image_path.name)
+
+        # Fall back to the front image when the other views are missing;
+        # check the file object itself, then use .name
+        rgb_left_pil = Image.open(rgb_left_image_path.name) if rgb_left_image_path else rgb_image_pil
+        rgb_right_pil = Image.open(rgb_right_image_path.name) if rgb_right_image_path else rgb_image_pil
+        rgb_center_pil = Image.open(rgb_center_image_path.name) if rgb_center_image_path else rgb_image_pil
+
+        # --- b. Read and preprocess the LiDAR data ---
+        if lidar_image_path:
+            # The LiDAR input is a .npy array, also read via .name
+            lidar_array = np.load(lidar_image_path.name)
+            if lidar_array.max() > 0:
+                lidar_array = (lidar_array / lidar_array.max()) * 255.0
+            lidar_pil = Image.fromarray(lidar_array.astype(np.uint8))
+            lidar_image_pil = lidar_pil.convert('RGB') if lidar_pil.mode != 'RGB' else lidar_pil
+        else:
+            lidar_image_pil = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
+
+        # --- c. Convert the images to tensors ---
+        rgb_tensor = transform(rgb_image_pil).unsqueeze(0).to(device)
+        rgb_left_tensor = transform(rgb_left_pil).unsqueeze(0).to(device)
+        rgb_right_tensor = transform(rgb_right_pil).unsqueeze(0).to(device)
+        rgb_center_tensor = transform(rgb_center_pil).unsqueeze(0).to(device)
+        lidar_tensor = lidar_transform(lidar_image_pil).unsqueeze(0).to(device)
+
+        # --- d. Read the numeric measurements (JSON, via .name) ---
+        with open(measurements_path.name, 'r') as f:
+            measurements_dict = json.load(f)
+
+        measurements_values = [
+            measurements_dict.get('x', 0.0), measurements_dict.get('y', 0.0),
+            measurements_dict.get('theta', 0.0), measurements_dict.get('speed', 5.0),
+            measurements_dict.get('steer', 0.0), measurements_dict.get('throttle', 0.0),
+            measurements_dict.get('brake', 0.0), measurements_dict.get('command', 2.0),
+            measurements_dict.get('is_junction', 0.0), measurements_dict.get('should_brake', 0.0)
+        ]
+        measurements_tensor = torch.tensor([measurements_values], dtype=torch.float32).to(device)
+        target_point_tensor = torch.tensor([target_point_list], dtype=torch.float32).to(device)
+
+        inputs = {
+            'rgb': rgb_tensor, 'rgb_left': rgb_left_tensor, 'rgb_right': rgb_right_tensor,
+            'rgb_center': rgb_center_tensor, 'lidar': lidar_tensor,
+            'measurements': measurements_tensor, 'target_point': target_point_tensor
+        }
+
+        # ==========================================================
+        # 2. Run the model and the post-processing stack
+        # ==========================================================
+        with torch.no_grad():
+            outputs = model(inputs)
+        traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
+
+        measurements_np = measurements_tensor[0].cpu().numpy()
+        pos, theta, speed = measurements_np[:2], measurements_np[2], measurements_np[3]
+
+        traffic_np = traffic[0].detach().cpu().numpy().reshape(20, 20, -1)
+        waypoints_np = waypoints[0].detach().cpu().numpy() * WAYPOINT_SCALE_FACTOR
+
+        tracker = Tracker()
+        updated_traffic = tracker.update_and_predict(traffic_np.copy(), pos, theta, frame_num=0)
+
+        controller = InterfuserController(ControllerConfig())
+        steer, throttle, brake, metadata_tuple = controller.run_step(
+            speed=speed, waypoints=waypoints_np, junction=is_junction.sigmoid()[0, 1].item(),
+            traffic_light_state=traffic_light.sigmoid()[0, 0].item(),
+            stop_sign=stop_sign.sigmoid()[0, 1].item(), meta_data=updated_traffic
+        )
+
+        # ==========================================================
+        # 3. Build the visual dashboard
+        # ==========================================================
+        map_t0, counts_t0 = render(updated_traffic, t=0)
+        map_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
+        map_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
+
+        wp_map = render_waypoints(waypoints_np)
+        self_car_map = render_self_car(loc=np.array([0, 0]), ori=[math.cos(0), math.sin(0)], box=[4.0, 2.0])
+
+        map_t0 = cv2.add(cv2.add(map_t0, wp_map), self_car_map)
+        map_t0 = cv2.resize(map_t0, (400, 400))
+        map_t1 = cv2.add(ensure_rgb(map_t1), ensure_rgb(self_car_map)); map_t1 = cv2.resize(map_t1, (200, 200))
+        map_t2 = cv2.add(ensure_rgb(map_t2), ensure_rgb(self_car_map)); map_t2 = cv2.resize(map_t2, (200, 200))
+
+        display = DisplayInterface()
+        light_state = "Red" if traffic_light.sigmoid()[0, 0].item() > 0.5 else "Green"
+        stop_sign_state = "Yes" if stop_sign.sigmoid()[0, 1].item() > 0.5 else "No"
+
+        interface_data = {
+            'camera_view': np.array(rgb_image_pil),
+            'map_t0': map_t0, 'map_t1': map_t1, 'map_t2': map_t2,
+            'text_info': {
+                'Frame': 'API Frame', 'Control': f"S:{steer:.2f} T:{throttle:.2f} B:{int(brake)}",
+                'Light': f"L: {light_state}", 'Stop': f"St: {stop_sign_state}"
+            },
+            'object_counts': {'t0': counts_t0, 't1': counts_t1, 't2': counts_t2}
+        }
+
+        dashboard_image = display.run_interface(interface_data)
+
+        # ==========================================================
+        # 4. Assemble and return the final outputs
+        # ==========================================================
+        result_dict = {
+            "predicted_waypoints": waypoints_np.tolist(),
+            "control_commands": {"steer": steer, "throttle": throttle, "brake": bool(brake)},
+            "perception": {
+                "traffic_light_status": light_state,
+                "stop_sign_detected": (stop_sign_state == "Yes"),
+                "is_at_junction_prob": round(is_junction.sigmoid()[0, 1].item(), 3),
+            },
+            "metadata": {
+                "speed_info": metadata_tuple[0],
+                "perception_info": metadata_tuple[1],
+                "stop_info": metadata_tuple[2],
+                "safe_distance": metadata_tuple[3],
+            },
+        }
+
+        # Convert the numpy dashboard to a PIL Image before returning it
+        return Image.fromarray(dashboard_image), result_dict
+
+    except Exception as e:
+        print(traceback.format_exc())
+        raise gr.Error(f"Error processing single frame: {e}")
+
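+# Sketch of calling this endpoint remotely (hypothetical Space id and file
+# names; the exact client API depends on the installed gradio_client version):
+#
+#     from gradio_client import Client, handle_file
+#     client = Client("<user>/<space-name>")
+#     image, result = client.predict(
+#         handle_file("front.jpg"), None, None, None, None,
+#         handle_file("0000.json"), [0.0, 100.0],
+#         api_name="/run_single_frame",
+#     )
+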
+with gr.Blocks() as demo:
+    gr.Markdown("# 🚗 Autonomous Driving Simulation with Interfuser")
+
+    with gr.Tabs():
+        with gr.TabItem("API Endpoint (Single Frame)", id=1):
+            gr.Markdown("### Test the model with direct input (Single Frame Inference)")
+            gr.Markdown("This interface is aimed at developers. Upload the required files to run the model on a single frame.")
+
+            with gr.Row():
+                with gr.Column(scale=1):
+                    gr.Markdown("#### Image Files")
+                    api_rgb_image_path = gr.File(label="RGB (Front) File (.jpg, .png)")
+                    api_rgb_left_image_path = gr.File(label="RGB (Left) File (Optional)")
+                    api_rgb_right_image_path = gr.File(label="RGB (Right) File (Optional)")
+                    api_rgb_center_image_path = gr.File(label="RGB (Center) File (Optional)")
+                    api_lidar_image_path = gr.File(label="LiDAR File (.npy, Optional)")
+
+                with gr.Column(scale=2):
+                    gr.Markdown("#### Data Files and Values")
+                    api_measurements_path = gr.File(label="Measurements File (.json)")
+                    api_target_point_list = gr.JSON(label="Target Point (List [x, y])", value=[0.0, 100.0])
+                    gr.Markdown("---")
+                    api_run_button = gr.Button("🚀 Run Single Frame", variant="primary")
+
+            gr.Markdown("---")
+            gr.Markdown("#### Outputs")
+
+            # Output components for the dashboard image and the JSON results
+            with gr.Row():
+                # The dashboard is displayed here
+                api_output_image = gr.Image(label="Dashboard Result", type="pil")
+                # The structured model results are displayed here
+                api_output_json = gr.JSON(label="Model Results (JSON)")
+
+            api_run_button.click(
+                fn=run_single_frame,
+                inputs=[
+                    api_rgb_image_path,
+                    api_rgb_left_image_path,
+                    api_rgb_right_image_path,
+                    api_rgb_center_image_path,
+                    api_lidar_image_path,
+                    api_measurements_path,
+                    api_target_point_list
+                ],
+                # Wire the function's two return values to the two output components
+                outputs=[api_output_image, api_output_json],
+                api_name="run_single_frame"
+            )
 
+# ==============================================================================
+# 7. Launch the application
+# ==============================================================================
 if __name__ == "__main__":
+    demo.queue().launch(debug=True)