Update app.py
Browse files
app.py
CHANGED
|
@@ -34,13 +34,16 @@ with open(config, 'r') as f:
|
|
| 34 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
| 35 |
hf_weight = hf_hub_download(repo_id=f"ZidongC/PanDA", filename=f"panda_large.pth", repo_type="model")
|
| 36 |
state_dict = torch.load(hf_weight, map_location="cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
model = make(config['model'])
|
| 38 |
-
if any(key.startswith('module') for key in state_dict.keys()):
|
| 39 |
-
|
| 40 |
model_state_dict = model.state_dict()
|
| 41 |
-
model.load_state_dict({k: v for k, v in
|
| 42 |
model = model.to(DEVICE).eval()
|
| 43 |
-
model = model.module
|
| 44 |
|
| 45 |
title = "# PanDA"
|
| 46 |
description = """Official demo for **PanDA**.
|
|
|
|
| 34 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
| 35 |
hf_weight = hf_hub_download(repo_id=f"ZidongC/PanDA", filename=f"panda_large.pth", repo_type="model")
|
| 36 |
state_dict = torch.load(hf_weight, map_location="cpu")
|
| 37 |
+
new_state_dict = {}
|
| 38 |
+
for key, value in state_dict.items():
|
| 39 |
+
new_key = key[7:] if key.startswith('module.') else key
|
| 40 |
+
new_state_dict[new_key] = value
|
| 41 |
model = make(config['model'])
|
| 42 |
+
# if any(key.startswith('module') for key in state_dict.keys()):
|
| 43 |
+
# model = nn.DataParallel(model)
|
| 44 |
model_state_dict = model.state_dict()
|
| 45 |
+
model.load_state_dict({k: v for k, v in new_state_dict.items() if k in model_state_dict})
|
| 46 |
model = model.to(DEVICE).eval()
|
|
|
|
| 47 |
|
| 48 |
title = "# PanDA"
|
| 49 |
description = """Official demo for **PanDA**.
|