madtune committed on
Commit
10ba1f7
·
verified ·
1 Parent(s): d9d6c63

Delete pixeldit/pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pixeldit/pipeline.py +0 -81
pixeldit/pipeline.py DELETED
@@ -1,81 +0,0 @@
1
- """
2
- PixelDiT T2I Pipeline — thin orchestrator.
3
-
4
- Usage:
5
- from pixeldit.pipeline import PixelDiTPipeline
6
- pipe = PixelDiTPipeline()
7
- images = pipe("a viking at sunset", height=512, width=512)
8
- images[0].save("out.jpg")
9
- """
10
-
11
- import torch
12
- from PIL import Image
13
- from .modeling_pixeldit import load_pixeldit
14
- from .modeling_pixeldit_hf import PixelDiTModel
15
- from .text_encoder_gemma import GemmaEncoder
16
- from .text_encoder_qwen import QwenEncoder
17
- from .scheduling_flow import FlowScheduler
18
-
19
-
20
class PixelDiTPipeline:
    """Text-to-image pipeline tying together a text encoder, the PixelDiT
    model and a flow scheduler.

    Calling an instance encodes the prompt(s), draws Gaussian noise with
    three channels at the requested ``height`` x ``width``, runs the
    scheduler's sampler on it and returns the result as PIL images.
    """

    def __init__(
        self,
        text_encoder="gemma",  # "gemma" | "qwen"
        qwen_proj=None,
        device="cuda",
        dtype=torch.bfloat16,
        cfg=3.5,
        flow_shift=4.0,
        pretrained=None,  # HF dir or repo id — loads via from_pretrained instead of .pth
    ):
        """Build the text encoder, the model and the scheduler.

        Args:
            text_encoder: ``"qwen"`` selects ``QwenEncoder``; any other value
                selects ``GemmaEncoder``.
            qwen_proj: Forwarded to ``QwenEncoder`` as ``proj_path``; not used
                when the Gemma encoder is selected.
            device: Passed to the encoder as ``output_device`` and used to
                place the model; also the device the noise is sampled on.
            dtype: Passed to the encoder as ``output_dtype`` and used for the
                model and the sampled noise.
            cfg: Default guidance value handed to ``FlowScheduler``.
            flow_shift: Forwarded to ``FlowScheduler``.
            pretrained: When not ``None``, the model is loaded with
                ``PixelDiTModel.from_pretrained(pretrained)``; otherwise
                ``load_pixeldit`` is used.
        """
        self.device = torch.device(device)
        self.dtype = dtype

        if text_encoder == "qwen":
            self.encoder = QwenEncoder(
                proj_path=qwen_proj, output_device=device, output_dtype=dtype
            )
        else:
            self.encoder = GemmaEncoder(output_device=device, output_dtype=dtype)

        if pretrained is not None:
            print(f"[pipeline] loading from HF: {pretrained}")
            self.model = (
                PixelDiTModel.from_pretrained(pretrained)
                .to(device).to(dtype).eval()
            )
        else:
            self.model = load_pixeldit(device=device, dtype=dtype)

        self.scheduler = FlowScheduler(self.model, cfg=cfg, flow_shift=flow_shift)

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        height=512,
        width=512,
        steps=20,
        cfg=None,
        seed=None,
    ):
        """Generate one image per prompt.

        Args:
            prompt: A single string or an iterable of strings.
            negative_prompt: If non-empty, it is repeated for every prompt and
                encoded as the unconditional branch; otherwise
                ``encoder.encode_null`` is used.
            height: Height of the sampled noise tensor.
            width: Width of the sampled noise tensor.
            steps: Forwarded to ``scheduler.sample``.
            cfg: Optional guidance override for this call only. The
                scheduler's previous value is restored afterwards.
            seed: If given, seeds the global torch RNG via
                ``torch.manual_seed`` before encoding and sampling.

        Returns:
            A list of ``PIL.Image.Image`` objects, one per prompt.
        """
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        batch = len(prompts)

        # Apply the per-call guidance override, remembering the old value so
        # it does not leak into later calls that rely on the default.
        override_cfg = cfg is not None
        if override_cfg:
            saved_cfg = self.scheduler.cfg
            self.scheduler.cfg = cfg
        try:
            if seed is not None:
                torch.manual_seed(seed)

            cond = self.encoder.encode(prompts)
            if negative_prompt:
                uncond = self.encoder.encode([negative_prompt] * batch)
            else:
                uncond = self.encoder.encode_null(batch)

            noise = torch.randn(
                batch, 3, height, width, device=self.device, dtype=self.dtype
            )
            imgs = self.scheduler.sample(noise, cond, uncond, steps=steps)
        finally:
            if override_cfg:
                self.scheduler.cfg = saved_cfg

        # Map [-1, 1] to [0, 255]. Upcast first so the quantization is not
        # done in a low-precision dtype, and round rather than truncate so
        # pixel values are not biased downward.
        imgs = (imgs.float().clamp(-1, 1) + 1) / 2
        imgs = (imgs * 255).round().byte().permute(0, 2, 3, 1).cpu().numpy()
        return [Image.fromarray(img) for img in imgs]