Alexander Bagus committed
Commit cd08558 · 1 Parent(s): c46c37a
Files changed (2)
  1. custom/pipeline_newbie.py +321 -0
  2. requirements.txt +1 -1
custom/pipeline_newbie.py ADDED
@@ -0,0 +1,321 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from diffusers import DiffusionPipeline
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from diffusers.utils import BaseOutput, deprecate
+
+
+ @dataclass
+ class NewbiePipelineOutput(BaseOutput):
+     images: List["PIL.Image.Image"]
+     latents: Optional[torch.Tensor] = None
+
+
+ class NewbiePipeline(DiffusionPipeline):
+     """
+     NewBie image pipeline (NextDiT + Gemma3 + JinaCLIP + FLUX VAE).
+     - Transformer: `NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP`
+     - Scheduler: `FlowMatchEulerDiscreteScheduler`
+     - VAE: FLUX-style `AutoencoderKL` with scale/shift
+     - Text encoder: Gemma3 (from 🤗 Transformers)
+     - CLIP encoder: JinaCLIPModel (from 🤗 Transformers, ``trust_remote_code=True``)
+     """
+
+     model_cpu_offload_seq = "text_encoder->clip_model->transformer->vae"
+
+     def __init__(
+         self,
+         transformer,
+         text_encoder,
+         tokenizer,
+         clip_model,
+         clip_tokenizer,
+         vae,
+         scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+     ):
+         super().__init__()
+
+         if scheduler is None:
+             scheduler = FlowMatchEulerDiscreteScheduler()
+
+         self.register_modules(
+             transformer=transformer,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             clip_model=clip_model,
+             clip_tokenizer=clip_tokenizer,
+             vae=vae,
+             scheduler=scheduler,
+         )
+
+     # ---------------------------------------------------------------------
+     # helpers
+     # ---------------------------------------------------------------------
+
+     def _get_vae_scale_shift(self) -> Tuple[float, float]:
+         config = getattr(self.vae, "config", None)
+         scale = getattr(config, "scaling_factor", None)
+         shift = getattr(config, "shift_factor", None)
+
+         # Fall back to the FLUX VAE defaults when the config does not define them.
+         if scale is None:
+             scale = 0.3611
+         if shift is None:
+             shift = 0.1159
+
+         return float(scale), float(shift)
+
+     def _prepare_latents(
+         self,
+         batch_size: int,
+         height: int,
+         width: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         latent_h, latent_w = height // 8, width // 8
+         shape = (batch_size, 16, latent_h, latent_w)
+
+         if latents is not None:
+             if latents.shape != shape:
+                 raise ValueError(
+                     f"Unexpected latents shape, got {latents.shape}, expected {shape}."
+                 )
+             return latents.to(device=device, dtype=dtype)
+
+         if isinstance(generator, list):
+             if len(generator) != batch_size:
+                 raise ValueError(
+                     f"Got a list of {len(generator)} generators, but batch_size={batch_size}."
+                 )
+             latents = torch.stack(
+                 [
+                     torch.randn(shape[1:], generator=g, device=device, dtype=dtype)
+                     for g in generator
+                 ],
+                 dim=0,
+             )
+         else:
+             latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+
+         return latents
+
+     @torch.no_grad()
+     def _encode_prompt(
+         self,
+         prompts: List[str],
+         clip_captions: Optional[List[str]] = None,
+         max_length: int = 512,
+         clip_max_length: int = 512,
+     ) -> Tuple[
+         torch.Tensor,
+         torch.Tensor,
+         Optional[torch.Tensor],
+         Optional[torch.Tensor],
+         Optional[torch.Tensor],
+     ]:
+         if clip_captions is None:
+             clip_captions = prompts
+
+         # Gemma tokenizer + encoder
+         text_inputs = self.tokenizer(
+             prompts,
+             padding=True,
+             pad_to_multiple_of=8,
+             max_length=max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         input_ids = text_inputs.input_ids.to(self.text_encoder.device)
+         attn_mask = text_inputs.attention_mask.to(self.text_encoder.device)
+
+         enc_out = self.text_encoder(
+             input_ids=input_ids,
+             attention_mask=attn_mask,
+             output_hidden_states=True,
+         )
+         # The second-to-last hidden state is used as the caption features.
+         cap_feats = enc_out.hidden_states[-2]
+         cap_mask = attn_mask
+
+         # Jina CLIP encoding
+         clip_inputs = self.clip_tokenizer(
+             clip_captions,
+             padding=True,
+             truncation=True,
+             max_length=clip_max_length,
+             return_tensors="pt",
+         ).to(self.clip_model.device)
+
+         clip_feats = self.clip_model.get_text_features(input_ids=clip_inputs)
+
+         # The remote-code JinaCLIP implementation may return either a pooled
+         # tensor or a (pooled, sequence) pair; handle both.
+         clip_text_pooled: Optional[torch.Tensor] = None
+         clip_text_sequence: Optional[torch.Tensor] = None
+
+         if isinstance(clip_feats, (tuple, list)) and len(clip_feats) == 2:
+             clip_text_pooled, clip_text_sequence = clip_feats
+         else:
+             clip_text_pooled = clip_feats
+
+         if clip_text_sequence is not None:
+             clip_text_sequence = clip_text_sequence.clone()
+         if clip_text_pooled is not None:
+             clip_text_pooled = clip_text_pooled.clone()
+
+         clip_mask = clip_inputs.attention_mask
+
+         return cap_feats, cap_mask, clip_text_sequence, clip_text_pooled, clip_mask
+
+     # ---------------------------------------------------------------------
+     # main call
+     # ---------------------------------------------------------------------
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         negative_prompt: Optional[Union[str, List[str]]] = "",
+         height: int = 1024,
+         width: int = 1024,
+         num_inference_steps: int = 28,
+         guidance_scale: float = 5.0,
+         cfg_trunc: float = 1.0,
+         renorm_cfg: bool = True,
+         system_prompt: str = "",
+         num_images_per_prompt: int = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         output_type: str = "pil",
+         return_dict: bool = True,
+         return_latents: bool = False,
+         **kwargs,
+     ) -> Union[NewbiePipelineOutput, Tuple[List["PIL.Image.Image"], torch.Tensor]]:
+
+         # Normalize prompt / negative_prompt into lists of equal length.
+         if isinstance(prompt, str):
+             batch_size = 1
+             prompts = [prompt]
+         else:
+             prompts = list(prompt)
+             batch_size = len(prompts)
+
+         if negative_prompt is None:
+             negative_prompt = ""
+         if isinstance(negative_prompt, str):
+             neg_prompts = [negative_prompt] * batch_size
+         else:
+             neg_prompts = list(negative_prompt)
+             if len(neg_prompts) != batch_size:
+                 raise ValueError(
+                     "negative_prompt must have the same batch size as prompt when provided as a list."
+                 )
+
+         if num_images_per_prompt != 1:
+             deprecate(
+                 "num_images_per_prompt!=1 for NewbiePipeline",
+                 "0.31.0",
+                 "The Newbie architecture currently assumes num_images_per_prompt == 1.",
+             )
+
+         clip_captions_pos = prompts
+         clip_captions_neg = neg_prompts
+
+         if system_prompt:
+             prompts_for_gemma = [system_prompt + p for p in prompts]
+             neg_for_gemma = [system_prompt + p if p else "" for p in neg_prompts]
+         else:
+             prompts_for_gemma = prompts
+             neg_for_gemma = neg_prompts
+
+         device = self._execution_device
+         dtype = self.transformer.dtype
+
+         latents = self._prepare_latents(
+             batch_size=batch_size,
+             height=height,
+             width=width,
+             dtype=dtype,
+             device=device,
+             generator=generator,
+             latents=latents,
+         )
+         latents = latents.to(device=device, dtype=dtype)
+         # Duplicate so the conditional and unconditional branches share the same noise.
+         latents = latents.repeat(2, 1, 1, 1)  # [2B, C, H, W]
+
+         full_gemma_prompts = prompts_for_gemma + neg_for_gemma
+         full_clip_captions = clip_captions_pos + clip_captions_neg
+
+         cap_feats, cap_mask, clip_text_sequence, clip_text_pooled, clip_mask = self._encode_prompt(
+             full_gemma_prompts,
+             clip_captions=full_clip_captions,
+         )
+
+         cap_feats = cap_feats.to(device=device, dtype=dtype)
+         cap_mask = cap_mask.to(device)
+         if clip_text_sequence is not None:
+             clip_text_sequence = clip_text_sequence.to(device=device, dtype=dtype)
+         if clip_text_pooled is not None:
+             clip_text_pooled = clip_text_pooled.to(device=device, dtype=dtype)
+
+         model_kwargs: Dict[str, Any] = dict(
+             cap_feats=cap_feats,
+             cap_mask=cap_mask,
+             cfg_scale=float(guidance_scale),
+             cfg_trunc=float(cfg_trunc),
+             renorm_cfg=bool(renorm_cfg),
+             clip_text_sequence=clip_text_sequence,
+             clip_text_pooled=clip_text_pooled,
+             clip_img_pooled=None,
+         )
+
+         self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         for t in timesteps:
+             timestep = t
+
+             noise_pred = self.transformer.forward_with_cfg(
+                 latents,
+                 timestep,
+                 **model_kwargs,
+             )
+
+             # The transformer predicts velocity with the opposite sign convention
+             # to the scheduler, so negate before stepping.
+             noise_pred = -noise_pred
+
+             latents = self.scheduler.step(
+                 model_output=noise_pred,
+                 timestep=timestep,
+                 sample=latents,
+                 return_dict=False,
+             )[0]
+
+         # Keep only the conditional half of the CFG batch.
+         latents_out = latents[:batch_size]
+
+         # VAE decode
+         vae_scale, vae_shift = self._get_vae_scale_shift()
+         decoded = self.vae.decode(latents_out / vae_scale + vae_shift).sample
+         images = (decoded / 2 + 0.5).clamp(0, 1)
+
+         if output_type == "pil":
+             import numpy as np
+             from PIL import Image
+
+             images_np = images.detach().float().cpu()
+             images_np = images_np.permute(0, 2, 3, 1).numpy()
+             images_np = (images_np * 255).round().astype(np.uint8)
+             images_out = [Image.fromarray(img) for img in images_np]
+         else:
+             images_out = images
+
+         if not return_dict:
+             return images_out, (latents_out if return_latents else None)
+
+         return NewbiePipelineOutput(
+             images=images_out,
+             latents=latents_out if return_latents else None,
+         )
+
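
For orientation, a minimal usage sketch of the new pipeline. It assumes the NextDiT transformer, Gemma3 text encoder and tokenizer, Jina CLIP model and tokenizer, and FLUX-style VAE have already been instantiated by the surrounding app code (those variable names, the prompt, device, and file name below are illustrative and not part of this commit):

import torch

from custom.pipeline_newbie import NewbiePipeline

# transformer, text_encoder, tokenizer, clip_model, clip_tokenizer and vae are
# assumed to be loaded elsewhere; only the wiring and the call are shown here.
pipe = NewbiePipeline(
    transformer=transformer,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    clip_model=clip_model,
    clip_tokenizer=clip_tokenizer,
    vae=vae,
)  # scheduler defaults to FlowMatchEulerDiscreteScheduler when omitted
pipe.to("cuda")

out = pipe(
    prompt="a watercolor fox in a snowy forest",
    negative_prompt="blurry, low quality",
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=5.0,
    generator=torch.Generator("cuda").manual_seed(0),
)
out.images[0].save("newbie_sample.png")
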
requirements.txt CHANGED
@@ -3,4 +3,4 @@ torch
  transformers
  accelerate
  spaces
- https://github.com/E-Anlia/diffusers/archive/refs/heads/add-newbie-pipeline.zip
+ diffusers