BiliSakura commited on
Commit
3d3821e
·
verified ·
1 Parent(s): 4f303d7

Update all files for DiffusionSat-Single-256

Browse files
Files changed (1) hide show
  1. unet/sat_unet.py +17 -2
unet/sat_unet.py CHANGED
@@ -18,8 +18,23 @@ class SatUNet(UNet2DConditionModel):
18
 
19
  _supports_gradient_checkpointing = True
20
 
21
- def __init__(self, *args, use_metadata: bool = True, num_metadata: int = 7, **kwargs):
22
- super().__init__(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Track custom config entries for save/load parity with the base model.
25
  self.register_to_config(use_metadata=use_metadata, num_metadata=num_metadata)
 
18
 
19
  _supports_gradient_checkpointing = True
20
 
21
+ @classmethod
22
+ def from_config(cls, config: Optional[Dict[str, Any]] = None, **kwargs):
23
+ """Load config using parent's expected keys so all UNet params (e.g. cross_attention_dim, use_linear_projection) are passed through."""
24
+ if config is None:
25
+ raise ValueError("Please provide a config.")
26
+ if not isinstance(config, dict):
27
+ config, kwargs = cls.load_config(config, return_unused_kwargs=True, **kwargs)
28
+ # Use parent's extract_init_dict so all UNet2DConditionModel params are in init_dict
29
+ init_dict, unused_kwargs, hidden_dict = UNet2DConditionModel.extract_init_dict(config, **kwargs)
30
+ init_dict.setdefault("use_metadata", True)
31
+ init_dict.setdefault("num_metadata", 7)
32
+ model = cls(**init_dict)
33
+ model.register_to_config(**hidden_dict)
34
+ return model
35
+
36
+ def __init__(self, *args, use_metadata: bool = True, num_metadata: int = 7, use_linear_projection: bool = False, **kwargs):
37
+ super().__init__(*args, use_linear_projection=use_linear_projection, **kwargs)
38
 
39
  # Track custom config entries for save/load parity with the base model.
40
  self.register_to_config(use_metadata=use_metadata, num_metadata=num_metadata)