mohammed-aljafry committed on
Commit
5a6c071
·
verified ·
1 Parent(s): 6143323

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +157 -83
app.py CHANGED
@@ -10,8 +10,10 @@ from PIL import Image
10
  import cv2
11
  import math
12
 
13
- # --- استيراد من الملفات التي أنشأناها ---
14
- from model import interfuser_baseline
 
 
15
  from logic import (
16
  transform, lidar_transform, InterfuserController, ControllerConfig,
17
  Tracker, DisplayInterface, render, render_waypoints, render_self_car,
@@ -19,97 +21,158 @@ from logic import (
19
  )
20
 
21
  # ==============================================================================
22
- # 1. تحميل النموذج (يتم مرة واحدة)
23
  # ==============================================================================
24
- print("Loading the Interfuser model...")
25
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
- model = interfuser_baseline()
27
- model_path = "model/interfuser_best_model.pth"
28
- if not os.path.exists(model_path):
29
- raise FileNotFoundError(f"Model file not found at {model_path}. Please upload it.")
30
-
31
- # استخدام weights_only=True لزيادة الأمان عند تحميل الملفات من مصادر غير موثوقة
32
- try:
33
- state_dic = torch.load(model_path, map_location=device, weights_only=True)
34
- except:
35
- state_dic = torch.load(model_path, map_location=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- model.load_state_dict(state_dic)
38
- model.to(device)
39
- model.eval()
40
- print("Model loaded successfully.")
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # ==============================================================================
44
- # 2. دالة التشغيل الرئيسية لـ Gradio
45
  # ==============================================================================
46
  def run_single_frame(
47
- rgb_image_path: str,
48
- rgb_left_image_path: str,
49
- rgb_right_image_path: str,
50
- rgb_center_image_path: str,
51
- lidar_image_path: str,
52
- measurements_path: str,
53
- target_point_list: list
54
  ):
55
- """
56
- تعالج إطارًا واحدًا من البيانات، وتُنشئ لوحة تحكم مرئية كاملة،
57
- وتُرجع كلاً من الصورة والبيانات المهيكلة.
58
- """
 
59
  try:
60
- # ==========================================================
61
- # 1. قراءة ومعالجة المدخلات من المسارات
62
- # ==========================================================
63
  if not rgb_image_path:
64
  raise gr.Error("الرجاء توفير مسار الصورة الأمامية (RGB).")
65
-
66
  rgb_image_pil = Image.open(rgb_image_path.name).convert("RGB")
67
  rgb_left_pil = Image.open(rgb_left_image_path.name).convert("RGB") if rgb_left_image_path else rgb_image_pil
68
  rgb_right_pil = Image.open(rgb_right_image_path.name).convert("RGB") if rgb_right_image_path else rgb_image_pil
69
  rgb_center_pil = Image.open(rgb_center_image_path.name).convert("RGB") if rgb_center_image_path else rgb_image_pil
70
 
 
 
 
 
 
 
71
  if lidar_image_path:
72
  lidar_array = np.load(lidar_image_path.name)
73
  if lidar_array.max() > 0:
74
  lidar_array = (lidar_array / lidar_array.max()) * 255.0
75
- lidar_pil = Image.fromarray(lidar_array.astype(np.uint8))
76
- lidar_image_pil = lidar_pil.convert('RGB')
77
  else:
78
- lidar_image_pil = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
79
-
80
- rgb_tensor = transform(rgb_image_pil).unsqueeze(0).to(device)
81
- rgb_left_tensor = transform(rgb_left_pil).unsqueeze(0).to(device)
82
- rgb_right_tensor = transform(rgb_right_pil).unsqueeze(0).to(device)
83
- rgb_center_tensor = transform(rgb_center_pil).unsqueeze(0).to(device)
84
- lidar_tensor = lidar_transform(lidar_image_pil).unsqueeze(0).to(device)
85
 
86
  with open(measurements_path.name, 'r') as f:
87
- measurements_dict = json.load(f)
88
-
89
- measurements_values = [
90
- measurements_dict.get('command', 2.0), measurements_dict.get('command', 2.0),
91
- measurements_dict.get('command', 2.0), measurements_dict.get('command', 2.0),
92
- measurements_dict.get('command', 2.0), measurements_dict.get('command', 2.0),
93
- measurements_dict.get('speed', 5.0)
94
- ]
95
- measurements_tensor = torch.tensor([measurements_values], dtype=torch.float32).to(device)
96
- target_point_tensor = torch.tensor([target_point_list], dtype=torch.float32).to(device)
97
 
 
 
 
 
 
 
 
 
 
 
 
98
  inputs = {
99
- 'rgb': rgb_tensor, 'rgb_left': rgb_left_tensor, 'rgb_right': rgb_right_tensor,
100
- 'rgb_center': rgb_center_tensor, 'lidar': lidar_tensor,
101
- 'measurements': measurements_tensor, 'target_point': target_point_tensor
 
 
 
 
102
  }
103
 
104
- # ==========================================================
105
- # 2. تشغيل النموذج والمعالجات اللاحقة
106
- # ==========================================================
107
  with torch.no_grad():
108
- outputs = model(inputs)
109
  traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
110
 
111
- measurements_np = measurements_tensor[0].cpu().numpy()
112
- pos, theta, speed = [0,0], 0, measurements_np[6]
 
113
 
114
  traffic_np = traffic[0].detach().cpu().numpy().reshape(20, 20, -1)
115
  waypoints_np = waypoints[0].detach().cpu().numpy() * WAYPOINT_SCALE_FACTOR
@@ -118,15 +181,13 @@ def run_single_frame(
118
  updated_traffic = tracker.update_and_predict(traffic_np.copy(), pos, theta, frame_num=0)
119
 
120
  controller = InterfuserController(ControllerConfig())
121
- steer, throttle, brake, metadata_tuple = controller.run_step(
122
  speed=speed, waypoints=waypoints_np, junction=is_junction.sigmoid()[0, 1].item(),
123
  traffic_light_state=traffic_light.sigmoid()[0, 0].item(),
124
  stop_sign=stop_sign.sigmoid()[0, 1].item(), meta_data=updated_traffic
125
  )
126
 
127
- # ==========================================================
128
- # 3. إنشاء التصور المرئي (Dashboard)
129
- # ==========================================================
130
  map_t0, counts_t0 = render(updated_traffic, t=0)
131
  map_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
132
  map_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
@@ -144,8 +205,7 @@ def run_single_frame(
144
  stop_sign_state = "Yes" if stop_sign.sigmoid()[0,1].item() > 0.5 else "No"
145
 
146
  interface_data = {
147
- 'camera_view': np.array(rgb_image_pil),
148
- 'map_t0': map_t0, 'map_t1': map_t1, 'map_t2': map_t2,
149
  'text_info': {
150
  'Frame': 'API Frame', 'Control': f"S:{steer:.2f} T:{throttle:.2f} B:{int(brake)}",
151
  'Light': f"L: {light_state}", 'Stop': f"St: {stop_sign_state}"
@@ -155,37 +215,52 @@ def run_single_frame(
155
 
156
  dashboard_image = display.run_interface(interface_data)
157
 
158
- # ==========================================================
159
- # 4. تجهيز وإرجاع المخرجات النهائية
160
- # ==========================================================
161
  result_dict = {
162
  "predicted_waypoints": waypoints_np.tolist(),
163
  "control_commands": {"steer": steer, "throttle": throttle, "brake": bool(brake)},
164
  "perception": {"traffic_light_status": light_state, "stop_sign_detected": (stop_sign_state == "Yes"), "is_at_junction_prob": round(is_junction.sigmoid()[0,1].item(), 3)},
165
- "metadata": {"speed_info": metadata_tuple[0], "perception_info": metadata_tuple[1], "stop_info": metadata_tuple[2], "safe_distance": metadata_tuple[3]}
166
  }
167
 
168
  return Image.fromarray(dashboard_image), result_dict
169
 
170
  except Exception as e:
171
  print(traceback.format_exc())
172
- raise gr.Error(f"Error processing single frame: {e}")
173
-
174
 
175
  # ==============================================================================
176
  # 4. تعريف واجهة Gradio
177
  # ==============================================================================
 
 
 
 
178
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
179
  gr.Markdown("# 🚗 محاكاة القيادة الذاتية باستخدام Interfuser")
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  with gr.Tabs():
182
  with gr.TabItem("نقطة نهاية API (إطار واحد)", id=1):
183
  gr.Markdown("### اختبار النموذج بإدخال مباشر (Single Frame Inference)")
184
- gr.Markdown("هذه الواجهة مخصصة للمطورين. قم برفع الملفات المطلوبة لتشغيل النموذج على إطار واحد.")
185
 
186
  with gr.Row():
187
  with gr.Column(scale=1):
188
- gr.Markdown("#### ملفات الصور والبيانات")
189
  api_rgb_image_path = gr.File(label="RGB (Front) File (.jpg, .png)")
190
  api_rgb_left_image_path = gr.File(label="RGB (Left) File (Optional)")
191
  api_rgb_right_image_path = gr.File(label="RGB (Right) File (Optional)")
@@ -203,13 +278,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
203
  api_run_button.click(
204
  fn=run_single_frame,
205
  inputs=[
206
- api_rgb_image_path,
207
- api_rgb_left_image_path,
208
- api_rgb_right_image_path,
209
- api_rgb_center_image_path,
210
- api_lidar_image_path,
211
- api_measurements_path,
212
- api_target_point_list
213
  ],
214
  outputs=[api_output_image, api_output_json],
215
  api_name="run_single_frame"
@@ -219,4 +290,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
219
  # 5. تشغيل التطبيق
220
  # ==============================================================================
221
  if __name__ == "__main__":
 
 
 
222
  demo.queue().launch(debug=True)
 
10
  import cv2
11
  import math
12
 
13
+ # --- استيراد من الملفات المنظمة في مشروعك ---
14
+ # نفترض أن بنية النموذج موجودة في model/architecture.py
15
+ from model import build_interfuser_model
16
+ # نفترض أن بقية المنطق موجود في logic.py
17
  from logic import (
18
  transform, lidar_transform, InterfuserController, ControllerConfig,
19
  Tracker, DisplayInterface, render, render_waypoints, render_self_car,
 
21
  )
22
 
23
# ==============================================================================
# 1. Model settings and weight paths
# ==============================================================================
# Folder scanned for *.pth weight files (one selectable model per file).
WEIGHTS_DIR = "model"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global variable holding the model currently loaded by load_model().
current_model = None

# Per-model build configurations.
# Each key must match the weights file name without the .pth extension.
# Models without an entry here fall back to the builder's default settings.
MODELS_SPECIFIC_CONFIGS = {
    "interfuser_baseline": {
        "rgb_backbone_name": "r50",
        "embed_dim": 256,
        "direct_concat": True,  # this model expects the camera images concatenated
    },
    "interfuser_lightweight": {
        "rgb_backbone_name": "r26",
        "embed_dim": 128,
        "enc_depth": 4,
        "dec_depth": 4,
        "direct_concat": True,  # this model expects the camera images concatenated
    }
    # Add configurations for any other models you have here, e.g.
    # "my_other_model": { "direct_concat": False, ... }
}
51
+
52
def find_available_models():
    """Scan the weights folder and return the names of the available models.

    Returns:
        list[str]: sorted model names, i.e. the ``.pth`` file names in
        ``WEIGHTS_DIR`` with their extension stripped. Empty when the
        weights folder does not exist.
    """
    if not os.path.isdir(WEIGHTS_DIR):
        # Keep the app startable without weights; the UI will warn the user.
        print(f"تحذير: مجلد الأوزان '{WEIGHTS_DIR}' غير موجود.")
        return []

    # os.path.splitext strips only the trailing extension, unlike the previous
    # str.replace(".pth", "") which would also mangle ".pth" occurring mid-name.
    # sorted() makes the dropdown order deterministic across platforms
    # (os.listdir order is filesystem-dependent).
    return sorted(
        os.path.splitext(filename)[0]
        for filename in os.listdir(WEIGHTS_DIR)
        if filename.endswith(".pth")
    )
62
+
63
+ # ==============================================================================
64
+ # 2. دالة تحميل النموذج الديناميكية
65
+ # ==============================================================================
66
def load_model(model_name: str):
    """Build the selected model, load its weights, and publish it globally.

    The freshly built network is stored in the module-level ``current_model``
    so that ``run_single_frame`` can use it. A missing weights file or a
    state-dict mismatch is reported via ``gr.Warning`` and the model falls
    back to its randomly initialised parameters instead of crashing the app.

    Args:
        model_name: weights file name without the ``.pth`` extension, as
            produced by ``find_available_models``.

    Returns:
        A status string (Arabic, shown in the UI) describing the outcome.
    """
    global current_model

    if not model_name:
        return "الرجاء اختيار نموذج من القائمة."

    checkpoint_path = os.path.join(WEIGHTS_DIR, f"{model_name}.pth")
    print(f"Attempting to load model: '{model_name}' from '{checkpoint_path}'")

    # Build the network with its per-model overrides; an empty dict means
    # the builder's defaults are used.
    build_config = MODELS_SPECIFIC_CONFIGS.get(model_name, {})
    network = build_interfuser_model(build_config)

    if os.path.exists(checkpoint_path):
        try:
            # weights_only=True prevents arbitrary pickled code from running.
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
            network.load_state_dict(checkpoint)
            print(f"تم تحميل أوزان النموذج '{model_name}' بنجاح.")
        except Exception as e:
            gr.Warning(f"فشل تحميل الأوزان للنموذج '{model_name}': {e}. تأكد من تطابق الإعدادات في 'MODELS_SPECIFIC_CONFIGS' مع الملف المحفوظ. سيتم استخدام أوزان عشوائية.")
    else:
        gr.Warning(f"ملف الأوزان '{checkpoint_path}' غير موجود. سيتم استخدام النموذج بأوزان عشوائية.")

    # Move to the inference device and freeze batch-norm/dropout behaviour.
    network.to(device).eval()

    current_model = network  # publish for run_single_frame
    return f"تم تحميل نموذج: {model_name}"
101
 
102
  # ==============================================================================
103
+ # 3. دالة التشغيل الرئيسية لـ Gradio
104
  # ==============================================================================
105
  def run_single_frame(
106
+ rgb_image_path,
107
+ rgb_left_image_path,
108
+ rgb_right_image_path,
109
+ rgb_center_image_path,
110
+ lidar_image_path,
111
+ measurements_path,
112
+ target_point_list
113
  ):
114
+ global current_model
115
+
116
+ if current_model is None:
117
+ raise gr.Error("الرجاء اختيار وتحميل نموذج أولاً من القائمة المنسدلة.")
118
+
119
  try:
120
+ # --- 1. قراءة ومعالجة المدخلات ---
 
 
121
  if not rgb_image_path:
122
  raise gr.Error("الرجاء توفير مسار الصورة الأمامية (RGB).")
123
+
124
  rgb_image_pil = Image.open(rgb_image_path.name).convert("RGB")
125
  rgb_left_pil = Image.open(rgb_left_image_path.name).convert("RGB") if rgb_left_image_path else rgb_image_pil
126
  rgb_right_pil = Image.open(rgb_right_image_path.name).convert("RGB") if rgb_right_image_path else rgb_image_pil
127
  rgb_center_pil = Image.open(rgb_center_image_path.name).convert("RGB") if rgb_center_image_path else rgb_image_pil
128
 
129
+ # تطبيق التحويلات لتحويل الصور إلى تنسورات
130
+ front_tensor = transform(rgb_image_pil).unsqueeze(0).to(device)
131
+ left_tensor = transform(rgb_left_pil).unsqueeze(0).to(device)
132
+ right_tensor = transform(rgb_right_pil).unsqueeze(0).to(device)
133
+ center_tensor = transform(rgb_center_pil).unsqueeze(0).to(device)
134
+
135
  if lidar_image_path:
136
  lidar_array = np.load(lidar_image_path.name)
137
  if lidar_array.max() > 0:
138
  lidar_array = (lidar_array / lidar_array.max()) * 255.0
139
+ lidar_pil = Image.fromarray(lidar_array.astype(np.uint8)).convert('RGB')
 
140
  else:
141
+ lidar_pil = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
142
+ lidar_tensor = lidar_transform(lidar_pil).unsqueeze(0).to(device)
 
 
 
 
 
143
 
144
  with open(measurements_path.name, 'r') as f:
145
+ m_dict = json.load(f)
 
 
 
 
 
 
 
 
 
146
 
147
+ # إنشاء تنسور القياسات الصحيح (10 عناصر)
148
+ measurements_tensor = torch.tensor([[
149
+ m_dict.get('x', 0.0), m_dict.get('y', 0.0), m_dict.get('theta', 0.0),
150
+ m_dict.get('speed', 5.0), m_dict.get('steer', 0.0), m_dict.get('throttle', 0.0),
151
+ float(m_dict.get('brake', 0.0)), m_dict.get('command', 2.0),
152
+ float(m_dict.get('is_junction', 0.0)), float(m_dict.get('should_brake', 0.0))
153
+ ]], dtype=torch.float32).to(device)
154
+
155
+ target_point_tensor = torch.tensor([target_point_list], dtype=torch.float32).to(device)
156
+
157
+ # تجميع المدخلات للنموذج
158
  inputs = {
159
+ 'rgb': front_tensor, # للنماذج التي لا تدمج
160
+ 'rgb_left': left_tensor,
161
+ 'rgb_right': right_tensor,
162
+ 'rgb_center': center_tensor,
163
+ 'lidar': lidar_tensor,
164
+ 'measurements': measurements_tensor,
165
+ 'target_point': target_point_tensor
166
  }
167
 
168
+ # --- 2. تشغيل النموذج ---
 
 
169
  with torch.no_grad():
170
+ outputs = current_model(inputs)
171
  traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
172
 
173
+ # --- 3. المعالجة اللاحقة والتصوّر ---
174
+ speed = m_dict.get('speed', 5.0)
175
+ pos, theta = [m_dict.get('x', 0.0), m_dict.get('y', 0.0)], m_dict.get('theta', 0.0)
176
 
177
  traffic_np = traffic[0].detach().cpu().numpy().reshape(20, 20, -1)
178
  waypoints_np = waypoints[0].detach().cpu().numpy() * WAYPOINT_SCALE_FACTOR
 
181
  updated_traffic = tracker.update_and_predict(traffic_np.copy(), pos, theta, frame_num=0)
182
 
183
  controller = InterfuserController(ControllerConfig())
184
+ steer, throttle, brake, metadata = controller.run_step(
185
  speed=speed, waypoints=waypoints_np, junction=is_junction.sigmoid()[0, 1].item(),
186
  traffic_light_state=traffic_light.sigmoid()[0, 0].item(),
187
  stop_sign=stop_sign.sigmoid()[0, 1].item(), meta_data=updated_traffic
188
  )
189
 
190
+ # إنشاء لوحة التحكم المرئية
 
 
191
  map_t0, counts_t0 = render(updated_traffic, t=0)
192
  map_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
193
  map_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
 
205
  stop_sign_state = "Yes" if stop_sign.sigmoid()[0,1].item() > 0.5 else "No"
206
 
207
  interface_data = {
208
+ 'camera_view': np.array(rgb_image_pil), 'map_t0': map_t0, 'map_t1': map_t1, 'map_t2': map_t2,
 
209
  'text_info': {
210
  'Frame': 'API Frame', 'Control': f"S:{steer:.2f} T:{throttle:.2f} B:{int(brake)}",
211
  'Light': f"L: {light_state}", 'Stop': f"St: {stop_sign_state}"
 
215
 
216
  dashboard_image = display.run_interface(interface_data)
217
 
218
+ # --- 4. تجهيز المخرجات ---
 
 
219
  result_dict = {
220
  "predicted_waypoints": waypoints_np.tolist(),
221
  "control_commands": {"steer": steer, "throttle": throttle, "brake": bool(brake)},
222
  "perception": {"traffic_light_status": light_state, "stop_sign_detected": (stop_sign_state == "Yes"), "is_at_junction_prob": round(is_junction.sigmoid()[0,1].item(), 3)},
223
+ "metadata": {"speed_info": metadata[0], "perception_info": metadata[1], "stop_info": metadata[2], "safe_distance": metadata[3]}
224
  }
225
 
226
  return Image.fromarray(dashboard_image), result_dict
227
 
228
  except Exception as e:
229
  print(traceback.format_exc())
230
+ raise gr.Error(f"حدث خطأ أثناء معالجة الإطار: {e}")
 
231
 
232
  # ==============================================================================
233
  # 4. تعريف واجهة Gradio
234
  # ==============================================================================
235
+
236
+ # البحث عن النماذج المتاحة عند بدء تشغيل الواجهة
237
+ available_models = find_available_models()
238
+
239
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
240
  gr.Markdown("# 🚗 محاكاة القيادة الذاتية باستخدام Interfuser")
241
 
242
+ with gr.Row():
243
+ model_selector = gr.Dropdown(
244
+ label="اختر النموذج من مجلد 'model/weights'",
245
+ choices=available_models,
246
+ value=available_models[0] if available_models else "لم يتم العثور على نماذج"
247
+ )
248
+ status_textbox = gr.Textbox(label="حالة تحميل النموذج", interactive=False)
249
+
250
+ # التحميل الأولي والتحميل عند التغيير
251
+ if available_models:
252
+ demo.load(fn=load_model, inputs=model_selector, outputs=status_textbox)
253
+ model_selector.change(fn=load_model, inputs=model_selector, outputs=status_textbox)
254
+
255
+ gr.Markdown("---")
256
+
257
  with gr.Tabs():
258
  with gr.TabItem("نقطة نهاية API (إطار واحد)", id=1):
259
  gr.Markdown("### اختبار النموذج بإدخال مباشر (Single Frame Inference)")
 
260
 
261
  with gr.Row():
262
  with gr.Column(scale=1):
263
+ gr.Markdown("#### المدخلات")
264
  api_rgb_image_path = gr.File(label="RGB (Front) File (.jpg, .png)")
265
  api_rgb_left_image_path = gr.File(label="RGB (Left) File (Optional)")
266
  api_rgb_right_image_path = gr.File(label="RGB (Right) File (Optional)")
 
278
  api_run_button.click(
279
  fn=run_single_frame,
280
  inputs=[
281
+ api_rgb_image_path, api_rgb_left_image_path, api_rgb_right_image_path,
282
+ api_rgb_center_image_path, api_lidar_image_path,
283
+ api_measurements_path, api_target_point_list
 
 
 
 
284
  ],
285
  outputs=[api_output_image, api_output_json],
286
  api_name="run_single_frame"
 
290
  # 5. تشغيل التطبيق
291
  # ==============================================================================
292
if __name__ == "__main__":
    if not available_models:
        # Interpolate WEIGHTS_DIR instead of the old hard-coded
        # 'model/weights' literal: the constant is "model", so the previous
        # message pointed users at a folder the app never reads.
        print(f"تحذير: لم يتم العثور على أي ملفات نماذج (.pth) في مجلد '{WEIGHTS_DIR}'.")
        print("سيتم تشغيل الواجهة ولكن لن تتمكن من تحميل أي نموذج.")
    demo.queue().launch(debug=True)