Update all files for DiffusionSat-Single-256
Browse files- unet/sat_unet.py +17 -2
unet/sat_unet.py
CHANGED
|
@@ -18,8 +18,23 @@ class SatUNet(UNet2DConditionModel):
|
|
| 18 |
|
| 19 |
_supports_gradient_checkpointing = True
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
# Track custom config entries for save/load parity with the base model.
|
| 25 |
self.register_to_config(use_metadata=use_metadata, num_metadata=num_metadata)
|
|
|
|
| 18 |
|
| 19 |
_supports_gradient_checkpointing = True
|
| 20 |
|
| 21 |
+
@classmethod
|
| 22 |
+
def from_config(cls, config: Optional[Dict[str, Any]] = None, **kwargs):
|
| 23 |
+
"""Load config using parent's expected keys so all UNet params (e.g. cross_attention_dim, use_linear_projection) are passed through."""
|
| 24 |
+
if config is None:
|
| 25 |
+
raise ValueError("Please provide a config.")
|
| 26 |
+
if not isinstance(config, dict):
|
| 27 |
+
config, kwargs = cls.load_config(config, return_unused_kwargs=True, **kwargs)
|
| 28 |
+
# Use parent's extract_init_dict so all UNet2DConditionModel params are in init_dict
|
| 29 |
+
init_dict, unused_kwargs, hidden_dict = UNet2DConditionModel.extract_init_dict(config, **kwargs)
|
| 30 |
+
init_dict.setdefault("use_metadata", True)
|
| 31 |
+
init_dict.setdefault("num_metadata", 7)
|
| 32 |
+
model = cls(**init_dict)
|
| 33 |
+
model.register_to_config(**hidden_dict)
|
| 34 |
+
return model
|
| 35 |
+
|
| 36 |
+
def __init__(self, *args, use_metadata: bool = True, num_metadata: int = 7, use_linear_projection: bool = False, **kwargs):
|
| 37 |
+
super().__init__(*args, use_linear_projection=use_linear_projection, **kwargs)
|
| 38 |
|
| 39 |
# Track custom config entries for save/load parity with the base model.
|
| 40 |
self.register_to_config(use_metadata=use_metadata, num_metadata=num_metadata)
|