Spaces:
Running
on
Zero
Running
on
Zero
2025-08-01 08:50 🐛
Browse files
app.py
CHANGED
|
@@ -20,8 +20,6 @@ mean = (0.485, 0.456, 0.406)
|
|
| 20 |
std = (0.229, 0.224, 0.225)
|
| 21 |
alpha = 0.8
|
| 22 |
EPS = 1e-8
|
| 23 |
-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
-
device = torch.device("cuda")
|
| 25 |
loaded_model = None
|
| 26 |
current_model_config = {"variant": None, "dataset": None, "metric": None}
|
| 27 |
|
|
@@ -78,18 +76,18 @@ def update_model_if_needed(variant_dataset_metric: str):
|
|
| 78 |
else:
|
| 79 |
return f"Unknown dataset: {dataset}"
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
current_model_config["dataset"] != dataset_name or
|
| 84 |
current_model_config["metric"] != metric):
|
| 85 |
|
| 86 |
-
print(f"
|
| 87 |
-
loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
|
| 88 |
current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
|
| 89 |
-
|
|
|
|
| 90 |
else:
|
| 91 |
-
print(f"
|
| 92 |
-
return f"Model
|
| 93 |
|
| 94 |
|
| 95 |
# -----------------------------
|
|
@@ -305,13 +303,13 @@ def predict(image: Image.Image, variant_dataset_metric: str):
|
|
| 305 |
"""
|
| 306 |
global loaded_model, current_model_config
|
| 307 |
|
|
|
|
|
|
|
|
|
|
| 308 |
# 如果选择的是分割线,返回错误信息
|
| 309 |
if "━━━━━━" in variant_dataset_metric:
|
| 310 |
return image, None, None, "⚠️ Please select a valid model configuration", None, None, None
|
| 311 |
|
| 312 |
-
# 确保模型正确加载
|
| 313 |
-
update_model_if_needed(variant_dataset_metric)
|
| 314 |
-
|
| 315 |
parts = variant_dataset_metric.split(" @ ")
|
| 316 |
if len(parts) != 3:
|
| 317 |
return image, None, None, "❌ Invalid model configuration format", None, None, None
|
|
@@ -329,6 +327,16 @@ def predict(image: Image.Image, variant_dataset_metric: str):
|
|
| 329 |
else:
|
| 330 |
return image, None, None, f"❌ Unknown dataset: {dataset}", None, None, None
|
| 331 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
if not hasattr(loaded_model, "input_size"):
|
| 333 |
if dataset_name == "sha":
|
| 334 |
loaded_model.input_size = 224
|
|
@@ -758,9 +766,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="ZIP Crowd Counting") as d
|
|
| 758 |
outputs=[model_status]
|
| 759 |
)
|
| 760 |
|
| 761 |
-
#
|
| 762 |
demo.load(
|
| 763 |
-
fn=lambda: f"
|
| 764 |
outputs=[model_status]
|
| 765 |
)
|
| 766 |
|
|
|
|
| 20 |
std = (0.229, 0.224, 0.225)
|
| 21 |
alpha = 0.8
|
| 22 |
EPS = 1e-8
|
|
|
|
|
|
|
| 23 |
loaded_model = None
|
| 24 |
current_model_config = {"variant": None, "dataset": None, "metric": None}
|
| 25 |
|
|
|
|
| 76 |
else:
|
| 77 |
return f"Unknown dataset: {dataset}"
|
| 78 |
|
| 79 |
+
# 只更新配置,不在主进程中加载模型
|
| 80 |
+
if (current_model_config["variant"] != variant or
|
| 81 |
current_model_config["dataset"] != dataset_name or
|
| 82 |
current_model_config["metric"] != metric):
|
| 83 |
|
| 84 |
+
print(f"Model configuration updated: {variant} @ {dataset} with {metric} metric")
|
|
|
|
| 85 |
current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
|
| 86 |
+
loaded_model = None # 重置模型,将在GPU进程中重新加载
|
| 87 |
+
return f"Model configuration set: {variant} @ {dataset} ({metric})"
|
| 88 |
else:
|
| 89 |
+
print(f"Model configuration unchanged: {variant} @ {dataset} with {metric} metric")
|
| 90 |
+
return f"Model configuration: {variant} @ {dataset} ({metric})"
|
| 91 |
|
| 92 |
|
| 93 |
# -----------------------------
|
|
|
|
| 303 |
"""
|
| 304 |
global loaded_model, current_model_config
|
| 305 |
|
| 306 |
+
# 在GPU进程中定义device
|
| 307 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 308 |
+
|
| 309 |
# 如果选择的是分割线,返回错误信息
|
| 310 |
if "━━━━━━" in variant_dataset_metric:
|
| 311 |
return image, None, None, "⚠️ Please select a valid model configuration", None, None, None
|
| 312 |
|
|
|
|
|
|
|
|
|
|
| 313 |
parts = variant_dataset_metric.split(" @ ")
|
| 314 |
if len(parts) != 3:
|
| 315 |
return image, None, None, "❌ Invalid model configuration format", None, None, None
|
|
|
|
| 327 |
else:
|
| 328 |
return image, None, None, f"❌ Unknown dataset: {dataset}", None, None, None
|
| 329 |
|
| 330 |
+
# 在GPU进程中加载模型(如果需要)
|
| 331 |
+
if (loaded_model is None or
|
| 332 |
+
current_model_config["variant"] != variant or
|
| 333 |
+
current_model_config["dataset"] != dataset_name or
|
| 334 |
+
current_model_config["metric"] != metric):
|
| 335 |
+
|
| 336 |
+
print(f"Loading model in GPU process: {variant} @ {dataset} with {metric} metric")
|
| 337 |
+
loaded_model = load_model(variant=variant, dataset=dataset_name, metric=metric)
|
| 338 |
+
current_model_config = {"variant": variant, "dataset": dataset_name, "metric": metric}
|
| 339 |
+
|
| 340 |
if not hasattr(loaded_model, "input_size"):
|
| 341 |
if dataset_name == "sha":
|
| 342 |
loaded_model.input_size = 224
|
|
|
|
| 766 |
outputs=[model_status]
|
| 767 |
)
|
| 768 |
|
| 769 |
+
# 页面加载时设置默认模型配置(不在主进程中加载模型)
|
| 770 |
demo.load(
|
| 771 |
+
fn=lambda: f"✅ {update_model_if_needed('ZIP-B @ NWPU-Crowd @ MAE')}",
|
| 772 |
outputs=[model_status]
|
| 773 |
)
|
| 774 |
|