Spaces:
Sleeping
Sleeping
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -28,21 +28,22 @@ from logic import (
|
|
| 28 |
EXAMPLES_DIR = "examples"
|
| 29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 30 |
|
|
|
|
| 31 |
MODELS_ON_HUB = {
|
| 32 |
"My Interfuser Model (from Hub)": {
|
| 33 |
"repo_id": "BaseerAI/Interfuser-Baseer-v1",
|
| 34 |
"filename": "best_model.pth",
|
| 35 |
-
"config_key": "
|
| 36 |
}
|
| 37 |
}
|
| 38 |
|
| 39 |
-
#
|
| 40 |
MODELS_SPECIFIC_CONFIGS = {
|
| 41 |
"interfuser_baseline": {
|
| 42 |
"rgb_backbone_name": "r50",
|
| 43 |
"embed_dim": 256,
|
| 44 |
"direct_concat": True,
|
| 45 |
-
"num_cameras": 1 # <--
|
| 46 |
},
|
| 47 |
"interfuser_lightweight": {
|
| 48 |
"rgb_backbone_name": "r26",
|
|
@@ -50,17 +51,16 @@ MODELS_SPECIFIC_CONFIGS = {
|
|
| 50 |
"enc_depth": 4,
|
| 51 |
"dec_depth": 4,
|
| 52 |
"direct_concat": True,
|
| 53 |
-
"num_cameras": 1
|
| 54 |
}
|
| 55 |
}
|
|
|
|
|
|
|
| 56 |
# ==============================================================================
|
| 57 |
-
# 2. الدوال الأساسية
|
| 58 |
# ==============================================================================
|
| 59 |
|
| 60 |
def load_model(model_display_name: str):
|
| 61 |
-
"""
|
| 62 |
-
تحميل النموذج عن طريق تنزيله من Hugging Face Hub.
|
| 63 |
-
"""
|
| 64 |
if not model_display_name or model_display_name not in MODELS_ON_HUB:
|
| 65 |
return None, "الرجاء اختيار نموذج صالح."
|
| 66 |
|
|
@@ -93,18 +93,10 @@ def load_model(model_display_name: str):
|
|
| 93 |
model = build_interfuser_model(model_config)
|
| 94 |
|
| 95 |
try:
|
| 96 |
-
# --- التعديل الرئيسي هنا ---
|
| 97 |
-
# 1. قم بتحميل ملف "نقطة الحفظ" (checkpoint) الكامل
|
| 98 |
checkpoint = torch.load(weights_path, map_location=device)
|
| 99 |
-
|
| 100 |
-
# 2. استخرج قاموس الأوزان الفعلي من داخل "نقطة الحفظ"
|
| 101 |
state_dic = checkpoint['model_state_dict']
|
| 102 |
-
|
| 103 |
-
# 3. الآن قم بتحميل الأوزان الصحيحة في النموذج
|
| 104 |
model.load_state_dict(state_dic)
|
| 105 |
-
|
| 106 |
print(f"تم تحميل أوزان النموذج '{model_display_name}' بنجاح.")
|
| 107 |
-
|
| 108 |
except KeyError:
|
| 109 |
error_msg = "فشل تحميل الأوزان. لم يتم العثور على المفتاح 'model_state_dict' في ملف النموذج. قد يكون هيكل الملف مختلفًا."
|
| 110 |
gr.Warning(error_msg)
|
|
@@ -123,7 +115,6 @@ def run_single_frame(
|
|
| 123 |
model_from_state, rgb_image_path, rgb_left_image_path, rgb_right_image_path,
|
| 124 |
rgb_center_image_path, lidar_image_path, measurements_path, target_point_list
|
| 125 |
):
|
| 126 |
-
# ... (بقية هذا الجزء من الكود لا يحتاج إلى تغيير) ...
|
| 127 |
if model_from_state is None:
|
| 128 |
gr.Warning("النموذج لم يتم تحميله. سيتم محاولة تحميل النموذج الافتراضي.")
|
| 129 |
available_models = list(MODELS_ON_HUB.keys())
|
|
@@ -213,9 +204,8 @@ def run_single_frame(
|
|
| 213 |
raise gr.Error(f"حدث خطأ غير متوقع أثناء معالجة الإطار: {e}")
|
| 214 |
|
| 215 |
# ==============================================================================
|
| 216 |
-
# 5. تعريف واجهة Gradio
|
| 217 |
# ==============================================================================
|
| 218 |
-
# ... (هذا الجزء لا يحتاج إلى تغيير) ...
|
| 219 |
available_models = list(MODELS_ON_HUB.keys())
|
| 220 |
|
| 221 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo:
|
|
|
|
| 28 |
EXAMPLES_DIR = "examples"
|
| 29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 30 |
|
| 31 |
+
# --- التعديل 1: تأكد من أن config_key يشير إلى 'interfuser_baseline' ---
|
| 32 |
MODELS_ON_HUB = {
|
| 33 |
"My Interfuser Model (from Hub)": {
|
| 34 |
"repo_id": "BaseerAI/Interfuser-Baseer-v1",
|
| 35 |
"filename": "best_model.pth",
|
| 36 |
+
"config_key": "interfuser_baseline" # <-- يجب أن يكون هذا baseline ليتوافق مع الأوزان
|
| 37 |
}
|
| 38 |
}
|
| 39 |
|
| 40 |
+
# --- التعديل 2: أضف 'num_cameras: 1' إلى إعدادات baseline ---
|
| 41 |
MODELS_SPECIFIC_CONFIGS = {
|
| 42 |
"interfuser_baseline": {
|
| 43 |
"rgb_backbone_name": "r50",
|
| 44 |
"embed_dim": 256,
|
| 45 |
"direct_concat": True,
|
| 46 |
+
"num_cameras": 1 # <-- هذا السطر ضروري جداً
|
| 47 |
},
|
| 48 |
"interfuser_lightweight": {
|
| 49 |
"rgb_backbone_name": "r26",
|
|
|
|
| 51 |
"enc_depth": 4,
|
| 52 |
"dec_depth": 4,
|
| 53 |
"direct_concat": True,
|
| 54 |
+
"num_cameras": 1
|
| 55 |
}
|
| 56 |
}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
# ==============================================================================
|
| 60 |
+
# 2. الدوال الأساسية (لا تغييرات هنا)
|
| 61 |
# ==============================================================================
|
| 62 |
|
| 63 |
def load_model(model_display_name: str):
|
|
|
|
|
|
|
|
|
|
| 64 |
if not model_display_name or model_display_name not in MODELS_ON_HUB:
|
| 65 |
return None, "الرجاء اختيار نموذج صالح."
|
| 66 |
|
|
|
|
| 93 |
model = build_interfuser_model(model_config)
|
| 94 |
|
| 95 |
try:
|
|
|
|
|
|
|
| 96 |
checkpoint = torch.load(weights_path, map_location=device)
|
|
|
|
|
|
|
| 97 |
state_dic = checkpoint['model_state_dict']
|
|
|
|
|
|
|
| 98 |
model.load_state_dict(state_dic)
|
|
|
|
| 99 |
print(f"تم تحميل أوزان النموذج '{model_display_name}' بنجاح.")
|
|
|
|
| 100 |
except KeyError:
|
| 101 |
error_msg = "فشل تحميل الأوزان. لم يتم العثور على المفتاح 'model_state_dict' في ملف النموذج. قد يكون هيكل الملف مختلفًا."
|
| 102 |
gr.Warning(error_msg)
|
|
|
|
| 115 |
model_from_state, rgb_image_path, rgb_left_image_path, rgb_right_image_path,
|
| 116 |
rgb_center_image_path, lidar_image_path, measurements_path, target_point_list
|
| 117 |
):
|
|
|
|
| 118 |
if model_from_state is None:
|
| 119 |
gr.Warning("النموذج لم يتم تحميله. سيتم محاولة تحميل النموذج الافتراضي.")
|
| 120 |
available_models = list(MODELS_ON_HUB.keys())
|
|
|
|
| 204 |
raise gr.Error(f"حدث خطأ غير متوقع أثناء معالجة الإطار: {e}")
|
| 205 |
|
| 206 |
# ==============================================================================
|
| 207 |
+
# 5. تعريف واجهة Gradio (لا تغييرات هنا)
|
| 208 |
# ==============================================================================
|
|
|
|
| 209 |
available_models = list(MODELS_ON_HUB.keys())
|
| 210 |
|
| 211 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"), css=".gradio-container {max-width: 95% !important;}") as demo:
|