Update modeling_flow_match.py
Browse files- modeling_flow_match.py +10 -0
modeling_flow_match.py
CHANGED
|
@@ -337,10 +337,20 @@ class FlowMatchRelayModel(PreTrainedModel):
|
|
| 337 |
images = model.sample(n_samples=8, class_label=3)
|
| 338 |
"""
|
| 339 |
config_class = FlowMatchRelayConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
def __init__(self, config):
|
| 342 |
super().__init__(config)
|
| 343 |
self.unet = FlowMatchUNet(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
|
| 345 |
def forward(self, x, t, class_labels):
|
| 346 |
"""
|
|
|
|
| 337 |
images = model.sample(n_samples=8, class_label=3)
|
| 338 |
"""
|
| 339 |
config_class = FlowMatchRelayConfig
|
| 340 |
+
_tied_weights_keys = []
|
| 341 |
+
_keys_to_ignore_on_load_missing = []
|
| 342 |
+
_keys_to_ignore_on_load_unexpected = []
|
| 343 |
+
_no_split_modules = []
|
| 344 |
+
supports_gradient_checkpointing = False
|
| 345 |
|
| 346 |
def __init__(self, config):
|
| 347 |
super().__init__(config)
|
| 348 |
self.unet = FlowMatchUNet(config)
|
| 349 |
+
self.post_init()
|
| 350 |
+
|
| 351 |
+
def _init_weights(self, module):
|
| 352 |
+
"""No-op — weights loaded from checkpoint or already initialized."""
|
| 353 |
+
pass
|
| 354 |
|
| 355 |
def forward(self, x, t, class_labels):
|
| 356 |
"""
|