Upload model_cgd.py with huggingface_hub
Browse files- model_cgd.py +37 -0
model_cgd.py
CHANGED
|
@@ -139,6 +139,43 @@ class CGDAngleEstimation(pl.LightningModule):
|
|
| 139 |
return model
|
| 140 |
raise FileNotFoundError("Checkpoint file not found")
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
def forward(self, x: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
|
| 143 |
"""Forward pass returning probability distribution over angles."""
|
| 144 |
logits = self.model(x)
|
|
|
|
| 139 |
return model
|
| 140 |
raise FileNotFoundError("Checkpoint file not found")
|
| 141 |
|
| 142 |
+
@classmethod
|
| 143 |
+
def from_pretrained(cls, repo_id, model_name=None):
|
| 144 |
+
"""Load a pretrained model from HuggingFace Hub.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
repo_id: HuggingFace repo ID (e.g. "maxwoe/image-rotation-angle-estimation")
|
| 148 |
+
model_name: Display name or checkpoint filename from config.json.
|
| 149 |
+
Defaults to the default model.
|
| 150 |
+
"""
|
| 151 |
+
import json
|
| 152 |
+
from huggingface_hub import hf_hub_download
|
| 153 |
+
|
| 154 |
+
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
|
| 155 |
+
with open(config_path) as f:
|
| 156 |
+
config = json.load(f)
|
| 157 |
+
|
| 158 |
+
if model_name is None:
|
| 159 |
+
model_name = config["default_model"]
|
| 160 |
+
|
| 161 |
+
# Look up by display name or by filename
|
| 162 |
+
if model_name in config["models"]:
|
| 163 |
+
model_info = config["models"][model_name]
|
| 164 |
+
else:
|
| 165 |
+
model_info = None
|
| 166 |
+
for info in config["models"].values():
|
| 167 |
+
if info["filename"] == model_name:
|
| 168 |
+
model_info = info
|
| 169 |
+
break
|
| 170 |
+
if model_info is None:
|
| 171 |
+
available = [i["filename"] for i in config["models"].values()]
|
| 172 |
+
raise ValueError(f"Unknown model: {model_name}. Available: {available}")
|
| 173 |
+
|
| 174 |
+
ckpt_path = hf_hub_download(repo_id=repo_id, filename=model_info["filename"])
|
| 175 |
+
model = cls.try_load(checkpoint_path=ckpt_path, image_size=model_info["input_size"])
|
| 176 |
+
model.eval()
|
| 177 |
+
return model
|
| 178 |
+
|
| 179 |
def forward(self, x: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
|
| 180 |
"""Forward pass returning probability distribution over angles."""
|
| 181 |
logits = self.model(x)
|