Spaces:
Running
on
Zero
Running
on
Zero
fix model loading bug
Browse files- app.py +11 -10
- metauas.py +2 -2
app.py
CHANGED
|
@@ -49,23 +49,24 @@ metauas_model = MetaUAS(encoder_name,
|
|
| 49 |
fusion_policy
|
| 50 |
)
|
| 51 |
|
| 52 |
-
|
| 53 |
-
model_256 = safely_load_state_dict(metauas_model, "weights/metauas-256.ckpt")
|
| 54 |
-
model_512 = safely_load_state_dict(metauas_model, "weights/metauas-512.ckpt")
|
| 55 |
-
model_256.push_to_hub("csgaobb/MetaUAS-256")
|
| 56 |
-
model_512.push_to_hub("csgaobb/MetaUAS-512")
|
| 57 |
-
|
| 58 |
def process_image(prompt_img, query_img, options):
|
| 59 |
# Load the model based on selected options
|
| 60 |
if 'model-512' in options:
|
| 61 |
-
|
| 62 |
#model = safely_load_state_dict(metauas_model, ckt_path)
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
img_size = 512
|
| 65 |
else:
|
| 66 |
-
|
| 67 |
#model = safely_load_state_dict(metauas_model, ckt_path)
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
| 69 |
img_size = 256
|
| 70 |
|
| 71 |
model.to(device)
|
|
|
|
| 49 |
fusion_policy
|
| 50 |
)
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def process_image(prompt_img, query_img, options):
|
| 53 |
# Load the model based on selected options
|
| 54 |
if 'model-512' in options:
|
| 55 |
+
ckt_path = "weights/metauas-512.ckpt"
|
| 56 |
#model = safely_load_state_dict(metauas_model, ckt_path)
|
| 57 |
+
|
| 58 |
+
ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename=ckt_path)
|
| 59 |
+
metauas_model.load_state_dict(torch.load(ckpt_path))
|
| 60 |
+
|
| 61 |
+
#model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
|
| 62 |
img_size = 512
|
| 63 |
else:
|
| 64 |
+
ckt_path = 'weights/metauas-256.ckpt'
|
| 65 |
#model = safely_load_state_dict(metauas_model, ckt_path)
|
| 66 |
+
|
| 67 |
+
ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename=ckt_path)
|
| 68 |
+
metauas_model.load_state_dict(torch.load(ckpt_path))
|
| 69 |
+
#model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
|
| 70 |
img_size = 256
|
| 71 |
|
| 72 |
model.to(device)
|
metauas.py
CHANGED
|
@@ -132,7 +132,7 @@ class AlignmentLayer(nn.Module):
|
|
| 132 |
return aligned_features
|
| 133 |
|
| 134 |
|
| 135 |
-
class MetaUAS(pl.LightningModule
|
| 136 |
def __init__(self, encoder_name, decoder_name, encoder_depth, decoder_depth, num_alignment_layers, alignment_type, fusion_policy):
|
| 137 |
super().__init__()
|
| 138 |
|
|
@@ -267,7 +267,7 @@ class MetaUAS(pl.LightningModule, PyTorchModelHubMixin):
|
|
| 267 |
stride=1,
|
| 268 |
padding=0,
|
| 269 |
)
|
| 270 |
-
|
| 271 |
def forward(self, batch):
|
| 272 |
query_input = self.preprocess(batch["query_image"])
|
| 273 |
prompt_input = self.preprocess(batch["prompt_image"])
|
|
|
|
| 132 |
return aligned_features
|
| 133 |
|
| 134 |
|
| 135 |
+
class MetaUAS(pl.LightningModule):
|
| 136 |
def __init__(self, encoder_name, decoder_name, encoder_depth, decoder_depth, num_alignment_layers, alignment_type, fusion_policy):
|
| 137 |
super().__init__()
|
| 138 |
|
|
|
|
| 267 |
stride=1,
|
| 268 |
padding=0,
|
| 269 |
)
|
| 270 |
+
|
| 271 |
def forward(self, batch):
|
| 272 |
query_input = self.preprocess(batch["query_image"])
|
| 273 |
prompt_input = self.preprocess(batch["prompt_image"])
|