mohammed-aljafry commited on
Commit
20bf6fe
·
verified ·
1 Parent(s): 08cd1fb

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +9 -19
app.py CHANGED
@@ -28,21 +28,22 @@ from logic import (
28
  EXAMPLES_DIR = "examples"
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
 
31
  MODELS_ON_HUB = {
32
  "My Interfuser Model (from Hub)": {
33
  "repo_id": "BaseerAI/Interfuser-Baseer-v1",
34
  "filename": "best_model.pth",
35
- "config_key": "interfuser_lightweight"
36
  }
37
  }
38
 
39
- # إعدادات البناء الخاصة بكل نموذج
40
  MODELS_SPECIFIC_CONFIGS = {
41
  "interfuser_baseline": {
42
  "rgb_backbone_name": "r50",
43
  "embed_dim": 256,
44
  "direct_concat": True,
45
- "num_cameras": 1 # <-- أضف هذا السطر
46
  },
47
  "interfuser_lightweight": {
48
  "rgb_backbone_name": "r26",
@@ -50,17 +51,16 @@ MODELS_SPECIFIC_CONFIGS = {
50
  "enc_depth": 4,
51
  "dec_depth": 4,
52
  "direct_concat": True,
53
- "num_cameras": 1 # <-- أضف هذا السطر أيضاً للاحتياط
54
  }
55
  }
 
 
56
  # ==============================================================================
57
- # 2. الدوال الأساسية
58
  # ==============================================================================
59
 
60
  def load_model(model_display_name: str):
61
- """
62
- تحميل النموذج عن طريق تنزيله من Hugging Face Hub.
63
- """
64
  if not model_display_name or model_display_name not in MODELS_ON_HUB:
65
  return None, "الرجاء اختيار نموذج صالح."
66
 
@@ -93,18 +93,10 @@ def load_model(model_display_name: str):
93
  model = build_interfuser_model(model_config)
94
 
95
  try:
96
- # --- التعديل الرئيسي هنا ---
97
- # 1. قم بتحميل ملف "نقطة الحفظ" (checkpoint) الكامل
98
  checkpoint = torch.load(weights_path, map_location=device)
99
-
100
- # 2. استخرج قاموس الأوزان الفعلي من داخل "نقطة الحفظ"
101
  state_dic = checkpoint['model_state_dict']
102
-
103
- # 3. الآن قم بتحميل الأوزان الصحيحة في النموذج
104
  model.load_state_dict(state_dic)
105
-
106
  print(f"تم تحميل أوزان النموذج '{model_display_name}' بنجاح.")
107
-
108
  except KeyError:
109
  error_msg = "فشل تحميل الأوزان. لم يتم العثور على المفتاح 'model_state_dict' في ملف النموذج. قد يكون هيكل الملف مختلفًا."
110
  gr.Warning(error_msg)
@@ -123,7 +115,6 @@ def run_single_frame(
123
  model_from_state, rgb_image_path, rgb_left_image_path, rgb_right_image_path,
124
  rgb_center_image_path, lidar_image_path, measurements_path, target_point_list
125
  ):
126
- # ... (بقية هذا الجزء من الكود لا يحتاج إلى تغيير) ...
127
  if model_from_state is None:
128
  gr.Warning("النموذج لم يتم تحميله. سيتم محاولة تحميل النموذج الافتراضي.")
129
  available_models = list(MODELS_ON_HUB.keys())
@@ -213,9 +204,8 @@ def run_single_frame(
213
  raise gr.Error(f"حدث خطأ غير متوقع أثناء معالجة الإطار: {e}")
214
 
215
  # ==============================================================================
216
- # 5. تعريف واجهة Gradio
217
  # ==============================================================================
218
- # ... (هذا الجزء لا يحتاج إلى تغيير) ...
219
  available_models = list(MODELS_ON_HUB.keys())
220
 
221
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo:
 
28
  EXAMPLES_DIR = "examples"
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
+ # --- التعديل 1: تأكد من أن config_key يشير إلى 'interfuser_baseline' ---
32
  MODELS_ON_HUB = {
33
  "My Interfuser Model (from Hub)": {
34
  "repo_id": "BaseerAI/Interfuser-Baseer-v1",
35
  "filename": "best_model.pth",
36
+ "config_key": "interfuser_baseline" # <-- يجب أن يكون هذا baseline ليتوافق مع الأوزان
37
  }
38
  }
39
 
40
+ # --- التعديل 2: أضف 'num_cameras: 1' إلى إعدادات baseline ---
41
  MODELS_SPECIFIC_CONFIGS = {
42
  "interfuser_baseline": {
43
  "rgb_backbone_name": "r50",
44
  "embed_dim": 256,
45
  "direct_concat": True,
46
+ "num_cameras": 1 # <-- هذا السطر ضروري جداً
47
  },
48
  "interfuser_lightweight": {
49
  "rgb_backbone_name": "r26",
 
51
  "enc_depth": 4,
52
  "dec_depth": 4,
53
  "direct_concat": True,
54
+ "num_cameras": 1
55
  }
56
  }
57
+
58
+
59
  # ==============================================================================
60
+ # 2. الدوال الأساسية (لا تغييرات هنا)
61
  # ==============================================================================
62
 
63
  def load_model(model_display_name: str):
 
 
 
64
  if not model_display_name or model_display_name not in MODELS_ON_HUB:
65
  return None, "الرجاء اختيار نموذج صالح."
66
 
 
93
  model = build_interfuser_model(model_config)
94
 
95
  try:
 
 
96
  checkpoint = torch.load(weights_path, map_location=device)
 
 
97
  state_dic = checkpoint['model_state_dict']
 
 
98
  model.load_state_dict(state_dic)
 
99
  print(f"تم تحميل أوزان النموذج '{model_display_name}' بنجاح.")
 
100
  except KeyError:
101
  error_msg = "فشل تحميل الأوزان. لم يتم العثور على المفتاح 'model_state_dict' في ملف النموذج. قد يكون هيكل الملف مختلفًا."
102
  gr.Warning(error_msg)
 
115
  model_from_state, rgb_image_path, rgb_left_image_path, rgb_right_image_path,
116
  rgb_center_image_path, lidar_image_path, measurements_path, target_point_list
117
  ):
 
118
  if model_from_state is None:
119
  gr.Warning("النموذج لم يتم تحميله. سيتم محاولة تحميل النموذج الافتراضي.")
120
  available_models = list(MODELS_ON_HUB.keys())
 
204
  raise gr.Error(f"حدث خطأ غير متوقع أثناء معالجة الإطار: {e}")
205
 
206
  # ==============================================================================
207
+ # 5. تعريف واجهة Gradio (لا تغييرات هنا)
208
  # ==============================================================================
 
209
  available_models = list(MODELS_ON_HUB.keys())
210
 
211
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo: