mohammed-aljafry committed on
Commit
d962918
·
verified ·
1 Parent(s): b598ef9

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +65 -74
app.py CHANGED
@@ -22,7 +22,7 @@ from logic import (
22
  # 1. إعدادات ومسارات النماذج
23
  # ==============================================================================
24
  WEIGHTS_DIR = "model"
25
- EXAMPLES_DIR = "examples" # مجلد جديد للأمثلة
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
 
28
  MODELS_SPECIFIC_CONFIGS = {
@@ -35,24 +35,18 @@ def find_available_models():
35
  return [f.replace(".pth", "") for f in os.listdir(WEIGHTS_DIR) if f.endswith(".pth")]
36
 
37
  # ==============================================================================
38
- # 2. الدوال الأساسية (load_model, run_single_frame)
39
  # ==============================================================================
 
40
  def load_model(model_name: str):
41
- """
42
- (لا تغيير في هذه الدالة)
43
- تبني وتحمل النموذج المختار وتُرجعه ككائن.
44
- """
45
  if not model_name or "لم يتم" in model_name:
46
  return None, "الرجاء اختيار نموذج صالح."
47
-
48
  weights_path = os.path.join(WEIGHTS_DIR, f"{model_name}.pth")
49
  print(f"Building model: '{model_name}'")
50
-
51
  model_config = MODELS_SPECIFIC_CONFIGS.get(model_name, {})
52
  model = build_interfuser_model(model_config)
53
-
54
  if not os.path.exists(weights_path):
55
- gr.Warning(f"ملف الأوزان '{weights_path}' غير موجود. النموذج سيعمل بأوزان عشوائية.")
56
  else:
57
  try:
58
  state_dic = torch.load(weights_path, map_location=device, weights_only=True)
@@ -60,69 +54,73 @@ def load_model(model_name: str):
60
  print(f"تم تحميل أوزان النموذج '{model_name}' بنجاح.")
61
  except Exception as e:
62
  gr.Warning(f"فشل تحميل الأوزان للنموذج '{model_name}': {e}.")
63
-
64
  model.to(device)
65
  model.eval()
66
-
67
  return model, f"تم تحميل نموذج: {model_name}"
68
 
69
 
70
  def run_single_frame(
71
- model_from_state, # المدخل من gr.State
72
- rgb_image_path,
73
- rgb_left_image_path,
74
- rgb_right_image_path,
75
- rgb_center_image_path,
76
- lidar_image_path,
77
- measurements_path,
78
- target_point_list
79
  ):
80
  """
81
- (تم تعديل هذه الدالة)
82
- تعالج إطارًا واحدًا، وتقوم بتحميل النموذج الافتراضي إذا لزم الأمر لجلسات الـ API.
83
  """
84
- # --- تعديل للتعامل مع جلسات الـ API ---
85
- # إذا كانت هذه جلسة API جديدة (model_state فارغ)، قم بتحميل النموذج الافتراضي
86
  if model_from_state is None:
87
  print("API session detected or model not loaded. Loading default model...")
88
  available_models = find_available_models()
89
- if not available_models:
90
- raise gr.Error("لا توجد نماذج متاحة للتحميل في مجلد 'model/weights'.")
91
-
92
- default_model_name = available_models[0]
93
- model_to_use, _ = load_model(default_model_name)
94
  else:
95
- # إذا كان النموذج محملًا بالفعل (من جلسة متصفح)، استخدمه مباشرة
96
  model_to_use = model_from_state
97
 
98
  if model_to_use is None:
99
- raise gr.Error("فشل تحميل النموذج. تحقق من السجلات (Logs) في Hugging Face Space.")
100
- # --- نهاية التعديل ---
101
 
102
  try:
103
- # --- 1. قراءة ومعالجة المدخلات ---
104
  if not (rgb_image_path and measurements_path):
105
  raise gr.Error("الرجاء توفير الصورة الأمامية وملف القياسات على الأقل.")
106
-
107
- rgb_image_pil = Image.open(rgb_image_path).convert("RGB")
108
- rgb_left_pil = Image.open(rgb_left_image_path).convert("RGB") if rgb_left_image_path else rgb_image_pil
109
- rgb_right_pil = Image.open(rgb_right_image_path).convert("RGB") if rgb_right_image_path else rgb_image_pil
110
- rgb_center_pil = Image.open(rgb_center_image_path).convert("RGB") if rgb_center_image_path else rgb_image_pil
111
 
112
- front_tensor = transform(rgb_image_pil).unsqueeze(0).to(device)
113
- left_tensor = transform(rgb_left_pil).unsqueeze(0).to(device)
114
- right_tensor = transform(rgb_right_pil).unsqueeze(0).to(device)
115
- center_tensor = transform(rgb_center_pil).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  if lidar_image_path:
118
- lidar_array = np.load(lidar_image_path)
119
- if lidar_array.max() > 0: lidar_array = (lidar_array / lidar_array.max()) * 255.0
120
- lidar_pil = Image.fromarray(lidar_array.astype(np.uint8)).convert('RGB')
 
 
 
121
  else:
122
  lidar_pil = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
123
- lidar_tensor = lidar_transform(lidar_pil).unsqueeze(0).to(device)
124
 
125
- with open(measurements_path, 'r') as f: m_dict = json.load(f)
 
 
 
 
 
 
 
 
 
 
126
 
127
  measurements_tensor = torch.tensor([[
128
  m_dict.get('x',0.0), m_dict.get('y',0.0), m_dict.get('theta',0.0), m_dict.get('speed',5.0),
@@ -132,24 +130,21 @@ def run_single_frame(
132
 
133
  target_point_tensor = torch.tensor([target_point_list], dtype=torch.float32).to(device)
134
 
135
- inputs = {
136
- 'rgb': front_tensor, 'rgb_left': left_tensor, 'rgb_right': right_tensor,
137
- 'rgb_center': center_tensor, 'lidar': lidar_tensor,
138
- 'measurements': measurements_tensor, 'target_point': target_point_tensor
139
- }
140
 
141
- # --- 2. تشغيل النموذج ---
142
  with torch.no_grad():
143
- outputs = model_to_use(inputs) # <-- استخدام model_to_use
144
  traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
145
 
146
- # --- 3. المعالجة اللاحقة والتصوّر ---
147
  speed, pos, theta = m_dict.get('speed',5.0), [m_dict.get('x',0.0), m_dict.get('y',0.0)], m_dict.get('theta',0.0)
148
  traffic_np, waypoints_np = traffic[0].detach().cpu().numpy().reshape(20,20,-1), waypoints[0].detach().cpu().numpy() * WAYPOINT_SCALE_FACTOR
149
  tracker, controller = Tracker(), InterfuserController(ControllerConfig())
150
  updated_traffic = tracker.update_and_predict(traffic_np.copy(), pos, theta, 0)
151
  steer, throttle, brake, metadata = controller.run_step(speed, waypoints_np, is_junction.sigmoid()[0,1].item(), traffic_light.sigmoid()[0,0].item(), stop_sign.sigmoid()[0,1].item(), updated_traffic)
152
 
 
153
  map_t0, counts_t0 = render(updated_traffic, t=0)
154
  map_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
155
  map_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
@@ -166,20 +161,21 @@ def run_single_frame(
166
  'object_counts': {'t0': counts_t0,'t1': counts_t1,'t2': counts_t2}}
167
  dashboard_image = display.run_interface(interface_data)
168
 
169
- # --- 4. تجهيز المخرجات ---
170
- result_dict = {"predicted_waypoints": waypoints_np.tolist(), "control_commands": {"steer": steer,"throttle": throttle,"brake": bool(brake)},
171
- "perception": {"traffic_light_status": light_state,"stop_sign_detected": (stop_sign_state == "Yes"),"is_at_junction_prob": round(is_junction.sigmoid()[0,1].item(), 3)},
172
- "metadata": {"speed_info": metadata[0],"perception_info": metadata[1],"stop_info": metadata[2],"safe_distance": metadata[3]}}
173
-
174
- return Image.fromarray(dashboard_image), result_dict
175
  except Exception as e:
176
  print(traceback.format_exc())
177
- raise gr.Error(f"حدث خطأ أثناء معالجة الإطار: {e}")
178
 
179
 
180
  # ==============================================================================
181
- # 4. تعريف واجهة Gradio المحسّنة (مع الإصلاح)
182
  # ==============================================================================
 
183
  available_models = find_available_models()
184
 
185
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo:
@@ -191,7 +187,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
191
  with gr.Row():
192
  # -- العمود الأيسر: الإعدادات والمدخلات --
193
  with gr.Column(scale=1):
194
- # --- الخطوة 1: اختيار النموذج ---
195
  with gr.Group():
196
  gr.Markdown("## ⚙️ الخطوة 1: اختر النموذج")
197
  with gr.Row():
@@ -202,7 +197,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
202
  )
203
  status_textbox = gr.Textbox(label="حالة النموذج", interactive=False)
204
 
205
- # --- الخطوة 2: رفع ملفات السيناريو ---
206
  with gr.Group():
207
  gr.Markdown("## 🗂️ الخطوة 2: ارفع ملفات السيناريو")
208
 
@@ -221,16 +215,14 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
221
 
222
  api_run_button = gr.Button("🚀 شغل المحاكاة", variant="primary", scale=2)
223
 
224
- # --- أمثلة جاهزة ---
225
  with gr.Group():
226
  gr.Markdown("### ✨ أمثلة جاهزة")
227
- gr.Markdown("انقر على مثال لتعبئة الحقول تلقائياً (يتطلب وجود مجلد `examples` بنفس بنية البيانات).")
228
  gr.Examples(
229
  examples=[
230
- [os.path.join(EXAMPLES_DIR, "sample1", "rgb.jpg"), os.path.join(EXAMPLES_DIR, "sample1", "measurements.json")],
231
- [os.path.join(EXAMPLES_DIR, "sample2", "rgb.jpg"), os.path.join(EXAMPLES_DIR, "sample2", "measurements.json")]
232
  ],
233
- # يجب أن تتطابق المدخلات مع الحقول المطلوبة في الأمثلة
234
  inputs=[api_rgb_image_path, api_measurements_path],
235
  label="اختر سيناريو اختبار"
236
  )
@@ -240,8 +232,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
240
  with gr.Group():
241
  gr.Markdown("## 📊 الخطوة 3: شاهد النتائج")
242
  api_output_image = gr.Image(label="لوحة التحكم المرئية (Dashboard)", type="pil", interactive=False)
243
- with gr.Accordion("عرض نتائج JSON التفصيلية", open=False):
244
- api_output_json = gr.JSON(label="النتائج المهيكلة (JSON)")
245
 
246
  # --- ربط منطق الواجهة ---
247
  if available_models:
@@ -253,12 +244,12 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), cs
253
  fn=run_single_frame,
254
  inputs=[model_state, api_rgb_image_path, api_rgb_left_image_path, api_rgb_right_image_path,
255
  api_rgb_center_image_path, api_lidar_image_path, api_measurements_path, api_target_point_list],
256
- outputs=[api_output_image, api_output_json],
257
  api_name="run_single_frame"
258
  )
259
 
260
  # ==============================================================================
261
- # 5. تشغيل التطبيق
262
  # ==============================================================================
263
  if __name__ == "__main__":
264
  if not available_models:
 
22
  # 1. إعدادات ومسارات النماذج
23
  # ==============================================================================
24
  WEIGHTS_DIR = "model"
25
+ EXAMPLES_DIR = "examples"
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
 
28
  MODELS_SPECIFIC_CONFIGS = {
 
35
  return [f.replace(".pth", "") for f in os.listdir(WEIGHTS_DIR) if f.endswith(".pth")]
36
 
37
  # ==============================================================================
38
+ # 2. الدوال الأساسية
39
  # ==============================================================================
40
+
41
  def load_model(model_name: str):
 
 
 
 
42
  if not model_name or "لم يتم" in model_name:
43
  return None, "الرجاء اختيار نموذج صالح."
 
44
  weights_path = os.path.join(WEIGHTS_DIR, f"{model_name}.pth")
45
  print(f"Building model: '{model_name}'")
 
46
  model_config = MODELS_SPECIFIC_CONFIGS.get(model_name, {})
47
  model = build_interfuser_model(model_config)
 
48
  if not os.path.exists(weights_path):
49
+ gr.Warning(f"ملف الأوزان '{weights_path}' غير موجود.")
50
  else:
51
  try:
52
  state_dic = torch.load(weights_path, map_location=device, weights_only=True)
 
54
  print(f"تم تحميل أوزان النموذج '{model_name}' بنجاح.")
55
  except Exception as e:
56
  gr.Warning(f"فشل تحميل الأوزان للنموذج '{model_name}': {e}.")
 
57
  model.to(device)
58
  model.eval()
 
59
  return model, f"تم تحميل نموذج: {model_name}"
60
 
61
 
62
  def run_single_frame(
63
+ model_from_state, rgb_image_path, rgb_left_image_path, rgb_right_image_path,
64
+ rgb_center_image_path, lidar_image_path, measurements_path, target_point_list
 
 
 
 
 
 
65
  ):
66
  """
67
+ (نسخة أكثر قوة مع معالجة أخطاء مفصلة)
 
68
  """
 
 
69
  if model_from_state is None:
70
  print("API session detected or model not loaded. Loading default model...")
71
  available_models = find_available_models()
72
+ if not available_models: raise gr.Error("لا توجد نماذج متاحة للتحميل.")
73
+ model_to_use, _ = load_model(available_models[0])
 
 
 
74
  else:
 
75
  model_to_use = model_from_state
76
 
77
  if model_to_use is None:
78
+ raise gr.Error("فشل تحميل النموذج. تحقق من السجلات (Logs).")
 
79
 
80
  try:
81
+ # --- 1. التحقق من المدخلات المطلوبة ---
82
  if not (rgb_image_path and measurements_path):
83
  raise gr.Error("الرجاء توفير الصورة الأمامية وملف القياسات على الأقل.")
 
 
 
 
 
84
 
85
+ # --- 2. قراءة ومعالجة المدخلات مع معالجة أخطاء مفصلة ---
86
+ try:
87
+ rgb_image_pil = Image.open(rgb_image_path).convert("RGB")
88
+ except Exception as e:
89
+ raise gr.Error(f"فشل تحميل صورة الكاميرا الأمامية. تأكد من أن الملف صحيح. الخطأ: {e}")
90
+
91
+ def load_optional_image(path, default_image):
92
+ if path:
93
+ try:
94
+ return Image.open(path).convert("RGB")
95
+ except Exception as e:
96
+ raise gr.Error(f"فشل تحميل الصورة الاختيارية '{os.path.basename(path)}'. الخطأ: {e}")
97
+ return default_image
98
+
99
+ rgb_left_pil = load_optional_image(rgb_left_image_path, rgb_image_pil)
100
+ rgb_right_pil = load_optional_image(rgb_right_image_path, rgb_image_pil)
101
+ rgb_center_pil = load_optional_image(rgb_center_image_path, rgb_image_pil)
102
 
103
  if lidar_image_path:
104
+ try:
105
+ lidar_array = np.load(lidar_image_path)
106
+ if lidar_array.max() > 0: lidar_array = (lidar_array / lidar_array.max()) * 255.0
107
+ lidar_pil = Image.fromarray(lidar_array.astype(np.uint8)).convert('RGB')
108
+ except Exception as e:
109
+ raise gr.Error(f"فشل تحميل ملف الليدار (.npy). تأكد من أن الملف صحيح. الخطأ: {e}")
110
  else:
111
  lidar_pil = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
 
112
 
113
+ try:
114
+ with open(measurements_path, 'r') as f: m_dict = json.load(f)
115
+ except Exception as e:
116
+ raise gr.Error(f"فشل تحميل أو قراءة ملف القياسات (.json). تأكد من أنه بصيغة صحيحة. الخطأ: {e}")
117
+
118
+ # --- 3. تحويل البيانات إلى تنسورات ---
119
+ front_tensor = transform(rgb_image_pil).unsqueeze(0).to(device)
120
+ left_tensor = transform(rgb_left_pil).unsqueeze(0).to(device)
121
+ right_tensor = transform(rgb_right_pil).unsqueeze(0).to(device)
122
+ center_tensor = transform(rgb_center_pil).unsqueeze(0).to(device)
123
+ lidar_tensor = lidar_transform(lidar_pil).unsqueeze(0).to(device)
124
 
125
  measurements_tensor = torch.tensor([[
126
  m_dict.get('x',0.0), m_dict.get('y',0.0), m_dict.get('theta',0.0), m_dict.get('speed',5.0),
 
130
 
131
  target_point_tensor = torch.tensor([target_point_list], dtype=torch.float32).to(device)
132
 
133
+ inputs = {'rgb': front_tensor, 'rgb_left': left_tensor, 'rgb_right': right_tensor, 'rgb_center': center_tensor, 'lidar': lidar_tensor, 'measurements': measurements_tensor, 'target_point': target_point_tensor}
 
 
 
 
134
 
135
+ # --- 4. تشغيل النموذج ---
136
  with torch.no_grad():
137
+ outputs = model_to_use(inputs)
138
  traffic, waypoints, is_junction, traffic_light, stop_sign, _ = outputs
139
 
140
+ # --- 5. المعالجة اللاحقة والتصوّر ---
141
  speed, pos, theta = m_dict.get('speed',5.0), [m_dict.get('x',0.0), m_dict.get('y',0.0)], m_dict.get('theta',0.0)
142
  traffic_np, waypoints_np = traffic[0].detach().cpu().numpy().reshape(20,20,-1), waypoints[0].detach().cpu().numpy() * WAYPOINT_SCALE_FACTOR
143
  tracker, controller = Tracker(), InterfuserController(ControllerConfig())
144
  updated_traffic = tracker.update_and_predict(traffic_np.copy(), pos, theta, 0)
145
  steer, throttle, brake, metadata = controller.run_step(speed, waypoints_np, is_junction.sigmoid()[0,1].item(), traffic_light.sigmoid()[0,0].item(), stop_sign.sigmoid()[0,1].item(), updated_traffic)
146
 
147
+ # ... (كود الرسم)
148
  map_t0, counts_t0 = render(updated_traffic, t=0)
149
  map_t1, counts_t1 = render(updated_traffic, t=T1_FUTURE_TIME)
150
  map_t2, counts_t2 = render(updated_traffic, t=T2_FUTURE_TIME)
 
161
  'object_counts': {'t0': counts_t0,'t1': counts_t1,'t2': counts_t2}}
162
  dashboard_image = display.run_interface(interface_data)
163
 
164
+ # --- 6. تجهيز المخرجات ---
165
+ control_commands_dict = {"steer": steer, "throttle": throttle, "brake": bool(brake)}
166
+ return Image.fromarray(dashboard_image), control_commands_dict
167
+
168
+ except gr.Error as e:
169
+ raise e # أعد إظهار أخطاء Gradio كما هي
170
  except Exception as e:
171
  print(traceback.format_exc())
172
+ raise gr.Error(f"حدث خطأ غير متوقع أثناء معالجة الإطار: {e}")
173
 
174
 
175
  # ==============================================================================
176
+ # 5. تعريف واجهة Gradio (لا تغيير هنا)
177
  # ==============================================================================
178
+ # ... (كود الواجهة بالكامل يبقى كما هو من النسخة السابقة) ...
179
  available_models = find_available_models()
180
 
181
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo:
 
187
  with gr.Row():
188
  # -- العمود الأيسر: الإعدادات والمدخلات --
189
  with gr.Column(scale=1):
 
190
  with gr.Group():
191
  gr.Markdown("## ⚙️ الخطوة 1: اختر النموذج")
192
  with gr.Row():
 
197
  )
198
  status_textbox = gr.Textbox(label="حالة النموذج", interactive=False)
199
 
 
200
  with gr.Group():
201
  gr.Markdown("## 🗂️ الخطوة 2: ارفع ملفات السيناريو")
202
 
 
215
 
216
  api_run_button = gr.Button("🚀 شغل المحاكاة", variant="primary", scale=2)
217
 
 
218
  with gr.Group():
219
  gr.Markdown("### ✨ أمثلة جاهزة")
220
+ gr.Markdown("انقر على مثال لتعبئة الحقول تلقائياً (يتطلب وجود مجلد `examples`).")
221
  gr.Examples(
222
  examples=[
223
+ [os.path.join(EXAMPLES_DIR, "/content/drive/MyDrive/model2/examples/sample1", "rgb.png"), os.path.join(EXAMPLES_DIR, "sample1", "measurements.json")],
224
+ [os.path.join(EXAMPLES_DIR, "/content/drive/MyDrive/model2/examples/sample2", "rgb.png"), os.path.join(EXAMPLES_DIR, "sample2", "measurements.json")]
225
  ],
 
226
  inputs=[api_rgb_image_path, api_measurements_path],
227
  label="اختر سيناريو اختبار"
228
  )
 
232
  with gr.Group():
233
  gr.Markdown("## 📊 الخطوة 3: شاهد النتائج")
234
  api_output_image = gr.Image(label="لوحة التحكم المرئية (Dashboard)", type="pil", interactive=False)
235
+ api_control_json = gr.JSON(label="أوامر التحكم (JSON)")
 
236
 
237
  # --- ربط منطق الواجهة ---
238
  if available_models:
 
244
  fn=run_single_frame,
245
  inputs=[model_state, api_rgb_image_path, api_rgb_left_image_path, api_rgb_right_image_path,
246
  api_rgb_center_image_path, api_lidar_image_path, api_measurements_path, api_target_point_list],
247
+ outputs=[api_output_image, api_control_json],
248
  api_name="run_single_frame"
249
  )
250
 
251
  # ==============================================================================
252
+ # 6. تشغيل التطبيق
253
  # ==============================================================================
254
  if __name__ == "__main__":
255
  if not available_models: