BiliSakura commited on
Commit
fdc4fb0
·
verified ·
1 Parent(s): 775db63

Update LightningDit-XL-1-256/scheduler/scheduling_flow_match_lightningdit.py

Browse files
LightningDit-XL-1-256/scheduler/scheduling_flow_match_lightningdit.py CHANGED
@@ -4,7 +4,7 @@
4
  # you may not use this file except in compliance with the License.
5
 
6
  from dataclasses import dataclass
7
- from typing import Optional, Tuple
8
 
9
  import torch
10
 
@@ -43,11 +43,17 @@ class LightningDiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
43
  order = 1
44
 
45
  @register_to_config
46
- def __init__(self, path_type: str = "linear", num_train_timesteps: int = 1000):
 
 
 
 
 
47
  if path_type not in {"linear", "cosine"}:
48
  raise ValueError("path_type must be either 'linear' or 'cosine'.")
49
  self.path_type = path_type
50
  self.num_train_timesteps = num_train_timesteps
 
51
  self.timesteps = torch.linspace(0.0, 1.0, num_train_timesteps + 1, dtype=torch.float64)
52
 
53
  @staticmethod
@@ -60,10 +66,11 @@ class LightningDiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
60
  self,
61
  num_inference_steps: int,
62
  device: Optional[torch.device] = None,
63
- timestep_shift: float = 0.0,
64
  ):
 
65
  timesteps = torch.linspace(0.0, 1.0, num_inference_steps + 1, dtype=torch.float64)
66
- timesteps = self._apply_timestep_shift(timesteps, timestep_shift)
67
  self.timesteps = timesteps.to(device=device)
68
  return self.timesteps
69
 
@@ -74,7 +81,9 @@ class LightningDiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
74
  sample: torch.Tensor,
75
  next_timestep: torch.Tensor,
76
  return_dict: bool = True,
 
77
  ) -> LightningDiTFlowMatchSchedulerOutput:
 
78
  sample_dtype = sample.dtype
79
  sample = sample.to(dtype=torch.float64)
80
  model_output = model_output.to(dtype=torch.float64)
@@ -94,7 +103,9 @@ class LightningDiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
94
  sample: torch.Tensor,
95
  next_timestep: torch.Tensor,
96
  return_dict: bool = True,
 
97
  ) -> LightningDiTFlowMatchSchedulerOutput:
 
98
  sample_dtype = sample.dtype
99
  sample = sample.to(dtype=torch.float64)
100
  model_output = model_output.to(dtype=torch.float64)
 
4
  # you may not use this file except in compliance with the License.
5
 
6
  from dataclasses import dataclass
7
+ from typing import List, Optional, Tuple, Union
8
 
9
  import torch
10
 
 
43
  order = 1
44
 
45
  @register_to_config
46
+ def __init__(
47
+ self,
48
+ path_type: str = "linear",
49
+ num_train_timesteps: int = 1000,
50
+ shift: float = 0.3,
51
+ ):
52
  if path_type not in {"linear", "cosine"}:
53
  raise ValueError("path_type must be either 'linear' or 'cosine'.")
54
  self.path_type = path_type
55
  self.num_train_timesteps = num_train_timesteps
56
+ self.shift = shift
57
  self.timesteps = torch.linspace(0.0, 1.0, num_train_timesteps + 1, dtype=torch.float64)
58
 
59
  @staticmethod
 
66
  self,
67
  num_inference_steps: int,
68
  device: Optional[torch.device] = None,
69
+ timestep_shift: Optional[float] = None,
70
  ):
71
+ shift = self.shift if timestep_shift is None else timestep_shift
72
  timesteps = torch.linspace(0.0, 1.0, num_inference_steps + 1, dtype=torch.float64)
73
+ timesteps = self._apply_timestep_shift(timesteps, shift)
74
  self.timesteps = timesteps.to(device=device)
75
  return self.timesteps
76
 
 
81
  sample: torch.Tensor,
82
  next_timestep: torch.Tensor,
83
  return_dict: bool = True,
84
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
85
  ) -> LightningDiTFlowMatchSchedulerOutput:
86
+ del generator
87
  sample_dtype = sample.dtype
88
  sample = sample.to(dtype=torch.float64)
89
  model_output = model_output.to(dtype=torch.float64)
 
103
  sample: torch.Tensor,
104
  next_timestep: torch.Tensor,
105
  return_dict: bool = True,
106
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
107
  ) -> LightningDiTFlowMatchSchedulerOutput:
108
+ del generator
109
  sample_dtype = sample.dtype
110
  sample = sample.to(dtype=torch.float64)
111
  model_output = model_output.to(dtype=torch.float64)