maxwoe commited on
Commit
ed8d2c4
·
verified ·
1 Parent(s): 3a3adaa

Upload model_cgd.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model_cgd.py +37 -0
model_cgd.py CHANGED
@@ -139,6 +139,43 @@ class CGDAngleEstimation(pl.LightningModule):
139
  return model
140
  raise FileNotFoundError("Checkpoint file not found")
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def forward(self, x: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
143
  """Forward pass returning probability distribution over angles."""
144
  logits = self.model(x)
 
139
  return model
140
  raise FileNotFoundError("Checkpoint file not found")
141
 
142
+ @classmethod
143
+ def from_pretrained(cls, repo_id, model_name=None):
144
+ """Load a pretrained model from HuggingFace Hub.
145
+
146
+ Args:
147
+ repo_id: HuggingFace repo ID (e.g. "maxwoe/image-rotation-angle-estimation")
148
+ model_name: Display name or checkpoint filename from config.json.
149
+ Defaults to the default model.
150
+ """
151
+ import json
152
+ from huggingface_hub import hf_hub_download
153
+
154
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
155
+ with open(config_path) as f:
156
+ config = json.load(f)
157
+
158
+ if model_name is None:
159
+ model_name = config["default_model"]
160
+
161
+ # Look up by display name or by filename
162
+ if model_name in config["models"]:
163
+ model_info = config["models"][model_name]
164
+ else:
165
+ model_info = None
166
+ for info in config["models"].values():
167
+ if info["filename"] == model_name:
168
+ model_info = info
169
+ break
170
+ if model_info is None:
171
+ available = [i["filename"] for i in config["models"].values()]
172
+ raise ValueError(f"Unknown model: {model_name}. Available: {available}")
173
+
174
+ ckpt_path = hf_hub_download(repo_id=repo_id, filename=model_info["filename"])
175
+ model = cls.try_load(checkpoint_path=ckpt_path, image_size=model_info["input_size"])
176
+ model.eval()
177
+ return model
178
+
179
  def forward(self, x: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
180
  """Forward pass returning probability distribution over angles."""
181
  logits = self.model(x)