Upload scheduler/scheduling_ncsn.py with huggingface_hub
Browse files
scheduler/scheduling_ncsn.py
CHANGED
|
@@ -15,11 +15,13 @@ from einops import rearrange
|
|
| 15 |
|
| 16 |
|
| 17 |
@dataclass
|
| 18 |
-
class
|
| 19 |
-
"""Annealed Langevin
|
| 20 |
|
| 21 |
|
| 22 |
-
class
|
|
|
|
|
|
|
| 23 |
order = 1
|
| 24 |
|
| 25 |
@register_to_config
|
|
@@ -106,13 +108,13 @@ class AnnealedLangevinDynamicScheduler(SchedulerMixin, ConfigMixin): # type: ig
|
|
| 106 |
samples: torch.Tensor,
|
| 107 |
return_dict: bool = True,
|
| 108 |
**kwargs,
|
| 109 |
-
) -> Union[
|
| 110 |
z = torch.randn_like(samples)
|
| 111 |
step_size = self.step_size[timestep]
|
| 112 |
samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
|
| 113 |
|
| 114 |
if return_dict:
|
| 115 |
-
return
|
| 116 |
else:
|
| 117 |
return (samples,)
|
| 118 |
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
@dataclass
|
| 18 |
+
class AnnealedLangevinDynamicsOutput(SchedulerOutput):
|
| 19 |
+
"""Annealed Langevin Dynamics output class."""
|
| 20 |
|
| 21 |
|
| 22 |
+
class AnnealedLangevinDynamicsScheduler(SchedulerMixin, ConfigMixin): # type: ignore
|
| 23 |
+
"""Annealed Langevin Dynamics scheduler for Noise Conditional Score Network (NCSN)."""
|
| 24 |
+
|
| 25 |
order = 1
|
| 26 |
|
| 27 |
@register_to_config
|
|
|
|
| 108 |
samples: torch.Tensor,
|
| 109 |
return_dict: bool = True,
|
| 110 |
**kwargs,
|
| 111 |
+
) -> Union[AnnealedLangevinDynamicsOutput, Tuple]:
|
| 112 |
z = torch.randn_like(samples)
|
| 113 |
step_size = self.step_size[timestep]
|
| 114 |
samples = samples + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
|
| 115 |
|
| 116 |
if return_dict:
|
| 117 |
+
return AnnealedLangevinDynamicsOutput(prev_sample=samples)
|
| 118 |
else:
|
| 119 |
return (samples,)
|
| 120 |
|