Spaces:
Sleeping
Sleeping
new ckpt
Browse files- .gitignore +5 -0
- factories.py +10 -16
.gitignore
CHANGED
|
@@ -1,6 +1,11 @@
|
|
| 1 |
# Mac
|
| 2 |
.DS_Store
|
| 3 |
|
|
|
|
| 4 |
.idea/
|
|
|
|
|
|
|
| 5 |
.ipynb_checkpoints/
|
|
|
|
|
|
|
| 6 |
__pycache__/
|
|
|
|
| 1 |
# Mac
|
| 2 |
.DS_Store
|
| 3 |
|
| 4 |
+
# PyCharm
|
| 5 |
.idea/
|
| 6 |
+
|
| 7 |
+
# Jupyter notebooks
|
| 8 |
.ipynb_checkpoints/
|
| 9 |
+
|
| 10 |
+
# Python
|
| 11 |
__pycache__/
|
factories.py
CHANGED
|
@@ -212,29 +212,24 @@ class EvalModel(torch.nn.Module):
|
|
| 212 |
def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
|
| 213 |
"""Load the model we want to evaluate."""
|
| 214 |
super().__init__()
|
| 215 |
-
self.
|
| 216 |
self.ckpt_pth = ckpt_pth
|
| 217 |
-
self.name
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
if self.base_name == "unext_emb_physics_config_C":
|
| 221 |
if self.ckpt_pth == "":
|
| 222 |
-
self.ckpt_pth = "ckpt/
|
| 223 |
-
self.model = get_model(model_name=self.
|
| 224 |
device='cpu',
|
| 225 |
**DEFAULT_MODEL_PARAMS)
|
| 226 |
|
| 227 |
-
# load model checkpoint
|
| 228 |
-
state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)
|
| 229 |
-
|
| 230 |
self.model.load_state_dict(state_dict)
|
| 231 |
self.model.to(device_str)
|
| 232 |
self.model.eval()
|
| 233 |
|
| 234 |
-
# add epoch in the model name
|
| 235 |
-
epoch = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)['epoch']
|
| 236 |
-
self.name = self.name + f"+{epoch}"
|
| 237 |
-
|
| 238 |
def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
|
| 239 |
return self.model(y, physics=physics)
|
| 240 |
|
|
@@ -250,9 +245,8 @@ class BaselineModel(torch.nn.Module):
|
|
| 250 |
|
| 251 |
def __init__(self, model_name: str, device_str: str = "cpu") -> None:
|
| 252 |
super().__init__()
|
| 253 |
-
self.
|
| 254 |
self.ckpt_pth = ""
|
| 255 |
-
self.name = self.base_name
|
| 256 |
if self.name not in self.all_baselines:
|
| 257 |
raise ValueError(f"{self.name} is unavailable.")
|
| 258 |
elif self.name == "DPIR":
|
|
|
|
| 212 |
def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
|
| 213 |
"""Load the model we want to evaluate."""
|
| 214 |
super().__init__()
|
| 215 |
+
self.name = model_name
|
| 216 |
self.ckpt_pth = ckpt_pth
|
| 217 |
+
if self.name not in self.all_models:
|
| 218 |
+
raise ValueError(f"{self.name} is unavailable.")
|
| 219 |
+
if self.name == "unext_emb_physics_config_C":
|
|
|
|
| 220 |
if self.ckpt_pth == "":
|
| 221 |
+
self.ckpt_pth = "ckpt/ram.pth.tar"
|
| 222 |
+
self.model = get_model(model_name=self.name,
|
| 223 |
device='cpu',
|
| 224 |
**DEFAULT_MODEL_PARAMS)
|
| 225 |
|
| 226 |
+
# load model checkpoint on cpu
|
| 227 |
+
state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)
|
| 228 |
+
|
| 229 |
self.model.load_state_dict(state_dict)
|
| 230 |
self.model.to(device_str)
|
| 231 |
self.model.eval()
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
|
| 234 |
return self.model(y, physics=physics)
|
| 235 |
|
|
|
|
| 245 |
|
| 246 |
def __init__(self, model_name: str, device_str: str = "cpu") -> None:
|
| 247 |
super().__init__()
|
| 248 |
+
self.name = model_name
|
| 249 |
self.ckpt_pth = ""
|
|
|
|
| 250 |
if self.name not in self.all_baselines:
|
| 251 |
raise ValueError(f"{self.name} is unavailable.")
|
| 252 |
elif self.name == "DPIR":
|