madtune commited on
Commit
5204049
·
verified ·
1 Parent(s): 10ba1f7

Delete pixeldit/scheduling_flow.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pixeldit/scheduling_flow.py +0 -65
pixeldit/scheduling_flow.py DELETED
@@ -1,65 +0,0 @@
1
- """
2
- Flow-matching DPM-Solver++ sampler for PixelDiT.
3
-
4
- Wraps the original DPMS from the PixelDiT repo.
5
- Order=2 multistep gets quality at 20 steps that Euler needs 100+ for.
6
-
7
- Usage:
8
- from scheduling_flow import FlowScheduler
9
-
10
- scheduler = FlowScheduler(model_fn, cfg=3.5, flow_shift=4.0)
11
- image = scheduler.sample(noise, cond, uncond, steps=20)
12
- """
13
-
14
- import sys
15
- import torch
16
- from tqdm import tqdm
17
-
18
- sys.path.insert(0, "/home/nobus/Raid0/PixelDiT/t2i")
19
- from diffusion.model.flow_dpm import DPMS
20
-
21
- _FLOW_SHIFT = 4.0 # 1024px stage-3 config
22
-
23
-
24
- class FlowScheduler:
25
- def __init__(self, model_fn, cfg=3.5, flow_shift=_FLOW_SHIFT):
26
- """
27
- model_fn: callable(x, t, y) -> velocity [B,3,H,W]
28
- cfg: classifier-free guidance scale
29
- """
30
- # DPMS passes y as [B,1,L,D] but PixDiT_T2I expects [B,L,D] — squeeze here
31
- self.model_fn = lambda x, t, y: model_fn(x, t, y.squeeze(1) if y.dim() == 4 else y)
32
- self.cfg = cfg
33
- self.flow_shift = flow_shift
34
-
35
- @torch.no_grad()
36
- def sample(
37
- self,
38
- noise: torch.Tensor, # [B, 3, H, W] Gaussian noise
39
- cond: torch.Tensor, # [B, 300, 2304]
40
- uncond: torch.Tensor, # [B, 300, 2304]
41
- steps: int = 20,
42
- ) -> torch.Tensor:
43
- """Returns denoised image tensor [B, 3, H, W] in [-1, 1]."""
44
- # DPMS expects [B, 1, L, D]
45
- cond_4d = cond.unsqueeze(1)
46
- uncond_4d = uncond.unsqueeze(1)
47
-
48
- dpm = DPMS(
49
- self.model_fn,
50
- condition=cond_4d,
51
- uncondition=uncond_4d,
52
- cfg_scale=self.cfg,
53
- model_type="flow",
54
- schedule="FLOW",
55
- guidance_type="classifier-free",
56
- interval_guidance=[0, 1],
57
- )
58
- return dpm.sample(
59
- noise,
60
- steps=steps,
61
- order=2,
62
- skip_type="time_uniform_flow",
63
- method="multistep",
64
- flow_shift=self.flow_shift,
65
- )