Unconditional Image Generation
Diffusers
Safetensors
English
lightningdit
image-generation
class-conditional
imagenet
flow-matching
Instructions to use BiliSakura/LightningDiT-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/LightningDiT-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/LightningDiT-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Update LightningDit-XL-1-256/scheduler/scheduling_flow_match_lightningdit.py
Browse files
LightningDit-XL-1-256/scheduler/scheduling_flow_match_lightningdit.py
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
# you may not use this file except in compliance with the License.
|
| 5 |
|
| 6 |
from dataclasses import dataclass
|
| 7 |
-
from typing import Optional, Tuple
|
| 8 |
|
| 9 |
import torch
|
| 10 |
|
|
@@ -43,11 +43,17 @@ class LightningDiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
|
|
| 43 |
order = 1
|
| 44 |
|
| 45 |
@register_to_config
|
| 46 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
if path_type not in {"linear", "cosine"}:
|
| 48 |
raise ValueError("path_type must be either 'linear' or 'cosine'.")
|
| 49 |
self.path_type = path_type
|
| 50 |
self.num_train_timesteps = num_train_timesteps
|
|
|
|
| 51 |
self.timesteps = torch.linspace(0.0, 1.0, num_train_timesteps + 1, dtype=torch.float64)
|
| 52 |
|
| 53 |
@staticmethod
|
|
@@ -60,10 +66,11 @@ class LightningDiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
|
|
| 60 |
self,
|
| 61 |
num_inference_steps: int,
|
| 62 |
device: Optional[torch.device] = None,
|
| 63 |
-
timestep_shift: float =
|
| 64 |
):
|
|
|
|
| 65 |
timesteps = torch.linspace(0.0, 1.0, num_inference_steps + 1, dtype=torch.float64)
|
| 66 |
-
timesteps = self._apply_timestep_shift(timesteps,
|
| 67 |
self.timesteps = timesteps.to(device=device)
|
| 68 |
return self.timesteps
|
| 69 |
|
|
@@ -74,7 +81,9 @@ class LightningDiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
|
|
| 74 |
sample: torch.Tensor,
|
| 75 |
next_timestep: torch.Tensor,
|
| 76 |
return_dict: bool = True,
|
|
|
|
| 77 |
) -> LightningDiTFlowMatchSchedulerOutput:
|
|
|
|
| 78 |
sample_dtype = sample.dtype
|
| 79 |
sample = sample.to(dtype=torch.float64)
|
| 80 |
model_output = model_output.to(dtype=torch.float64)
|
|
@@ -94,7 +103,9 @@ class LightningDiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
|
|
| 94 |
sample: torch.Tensor,
|
| 95 |
next_timestep: torch.Tensor,
|
| 96 |
return_dict: bool = True,
|
|
|
|
| 97 |
) -> LightningDiTFlowMatchSchedulerOutput:
|
|
|
|
| 98 |
sample_dtype = sample.dtype
|
| 99 |
sample = sample.to(dtype=torch.float64)
|
| 100 |
model_output = model_output.to(dtype=torch.float64)
|
|
|
|
| 4 |
# you may not use this file except in compliance with the License.
|
| 5 |
|
| 6 |
from dataclasses import dataclass
|
| 7 |
+
from typing import List, Optional, Tuple, Union
|
| 8 |
|
| 9 |
import torch
|
| 10 |
|
|
|
|
| 43 |
order = 1
|
| 44 |
|
| 45 |
@register_to_config
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
path_type: str = "linear",
|
| 49 |
+
num_train_timesteps: int = 1000,
|
| 50 |
+
shift: float = 0.3,
|
| 51 |
+
):
|
| 52 |
if path_type not in {"linear", "cosine"}:
|
| 53 |
raise ValueError("path_type must be either 'linear' or 'cosine'.")
|
| 54 |
self.path_type = path_type
|
| 55 |
self.num_train_timesteps = num_train_timesteps
|
| 56 |
+
self.shift = shift
|
| 57 |
self.timesteps = torch.linspace(0.0, 1.0, num_train_timesteps + 1, dtype=torch.float64)
|
| 58 |
|
| 59 |
@staticmethod
|
|
|
|
| 66 |
self,
|
| 67 |
num_inference_steps: int,
|
| 68 |
device: Optional[torch.device] = None,
|
| 69 |
+
timestep_shift: Optional[float] = None,
|
| 70 |
):
|
| 71 |
+
shift = self.shift if timestep_shift is None else timestep_shift
|
| 72 |
timesteps = torch.linspace(0.0, 1.0, num_inference_steps + 1, dtype=torch.float64)
|
| 73 |
+
timesteps = self._apply_timestep_shift(timesteps, shift)
|
| 74 |
self.timesteps = timesteps.to(device=device)
|
| 75 |
return self.timesteps
|
| 76 |
|
|
|
|
| 81 |
sample: torch.Tensor,
|
| 82 |
next_timestep: torch.Tensor,
|
| 83 |
return_dict: bool = True,
|
| 84 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 85 |
) -> LightningDiTFlowMatchSchedulerOutput:
|
| 86 |
+
del generator
|
| 87 |
sample_dtype = sample.dtype
|
| 88 |
sample = sample.to(dtype=torch.float64)
|
| 89 |
model_output = model_output.to(dtype=torch.float64)
|
|
|
|
| 103 |
sample: torch.Tensor,
|
| 104 |
next_timestep: torch.Tensor,
|
| 105 |
return_dict: bool = True,
|
| 106 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 107 |
) -> LightningDiTFlowMatchSchedulerOutput:
|
| 108 |
+
del generator
|
| 109 |
sample_dtype = sample.dtype
|
| 110 |
sample = sample.to(dtype=torch.float64)
|
| 111 |
model_output = model_output.to(dtype=torch.float64)
|