from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import PIL.Image
import torch
from diffusers import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import BaseOutput, deprecate

@dataclass
class NewbiePipelineOutput(BaseOutput):
    """Output of `NewbiePipeline`: decoded images plus, optionally, the raw latents."""

    images: List[PIL.Image.Image]
    latents: Optional[torch.Tensor] = None

class NewbiePipeline(DiffusionPipeline):
    """
    NewBie image pipeline (NextDiT + Gemma3 + JinaCLIP + FLUX VAE).

    - Transformer: `NextDiT_3B_GQA_patch2_Adaln_Refiner_WHIT_CLIP`
    - Scheduler: `FlowMatchEulerDiscreteScheduler`
    - VAE: FLUX-style `AutoencoderKL` with scale/shift
    - Text encoder: Gemma3 (from 🤗 Transformers)
    - CLIP encoder: JinaCLIPModel (from 🤗 Transformers, ``trust_remote_code=True``)
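
    Example (a minimal usage sketch; the checkpoint id, dtype, and prompt below
    are illustrative placeholders, not values defined by this module):

    ```py
    >>> import torch
    >>> pipe = NewbiePipeline.from_pretrained("org/newbie", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")
    >>> image = pipe("a watercolor fox in a snowy forest", num_inference_steps=28).images[0]
    >>> image.save("fox.png")
    ```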
| """ | |
| model_cpu_offload_seq = "text_encoder->clip_model->transformer->vae" | |
    def __init__(
        self,
        transformer,
        text_encoder,
        tokenizer,
        clip_model,
        clip_tokenizer,
        vae,
        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
    ):
        super().__init__()
        if scheduler is None:
            scheduler = FlowMatchEulerDiscreteScheduler()
        # register_modules wires every component into the pipeline so that
        # `from_pretrained` / `save_pretrained` and device offloading see them.
        self.register_modules(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            clip_model=clip_model,
            clip_tokenizer=clip_tokenizer,
            vae=vae,
            scheduler=scheduler,
        )
    # ---------------------------------------------------------------------
    # helpers
    # ---------------------------------------------------------------------
    def _get_vae_scale_shift(self) -> Tuple[float, float]:
        # Prefer the values stored in the VAE config; fall back to the
        # FLUX.1 AutoencoderKL defaults (scaling_factor=0.3611, shift_factor=0.1159).
        config = getattr(self.vae, "config", None)
        scale = getattr(config, "scaling_factor", None)
        shift = getattr(config, "shift_factor", None)
        if scale is None:
            scale = 0.3611
        if shift is None:
            shift = 0.1159
        return float(scale), float(shift)
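
    # Note: the decode path in `__call__` (step 7) inverts this FLUX-style
    # normalization with `vae.decode(latents / scale + shift)`.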
    def _prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The FLUX-style VAE downsamples by 8x and uses 16 latent channels.
        latent_h, latent_w = height // 8, width // 8
        shape = (batch_size, 16, latent_h, latent_w)
        if latents is not None:
            if latents.shape != shape:
                raise ValueError(
                    f"Unexpected latents shape, got {latents.shape}, expected {shape}."
                )
            return latents.to(device=device, dtype=dtype)
        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"Got a list of {len(generator)} generators, but batch_size={batch_size}."
                )
            # One generator per sample: draw each latent separately so each
            # image in the batch is independently reproducible.
            latents = torch.stack(
                [
                    torch.randn(shape[1:], generator=g, device=device, dtype=dtype)
                    for g in generator
                ],
                dim=0,
            )
        else:
            latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return latents
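
    # `_encode_prompt` below runs both text towers in one pass: Gemma3
    # provides per-token sequence features for NextDiT cross-attention
    # (taking the second-to-last hidden state, a common choice in diffusion
    # pipelines), while JinaCLIP provides pooled text features and, when the
    # model returns them, per-token sequence features.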
    def _encode_prompt(
        self,
        prompts: List[str],
        clip_captions: Optional[List[str]] = None,
        max_length: int = 512,
        clip_max_length: int = 512,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        if clip_captions is None:
            clip_captions = prompts
        # Gemma tokenizer + encoder
        text_inputs = self.tokenizer(
            prompts,
            padding=True,
            pad_to_multiple_of=8,
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        attn_mask = text_inputs.attention_mask.to(self.text_encoder.device)
        enc_out = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attn_mask,
            output_hidden_states=True,
        )
        # Use the second-to-last hidden state as the caption features.
        cap_feats = enc_out.hidden_states[-2]
        cap_mask = attn_mask
        # Jina CLIP encoding
        clip_inputs = self.clip_tokenizer(
            clip_captions,
            padding=True,
            truncation=True,
            max_length=clip_max_length,
            return_tensors="pt",
        ).to(self.clip_model.device)
        clip_feats = self.clip_model.get_text_features(input_ids=clip_inputs.input_ids)
        clip_text_pooled: Optional[torch.Tensor] = None
        clip_text_sequence: Optional[torch.Tensor] = None
        # Some JinaCLIP variants return (pooled, sequence); others only pooled.
        if isinstance(clip_feats, (tuple, list)) and len(clip_feats) == 2:
            clip_text_pooled, clip_text_sequence = clip_feats
        else:
            clip_text_pooled = clip_feats
        if clip_text_sequence is not None:
            clip_text_sequence = clip_text_sequence.clone()
        if clip_text_pooled is not None:
            clip_text_pooled = clip_text_pooled.clone()
        clip_mask = clip_inputs.attention_mask
        return cap_feats, cap_mask, clip_text_sequence, clip_text_pooled, clip_mask
    # ---------------------------------------------------------------------
    # main call
    # ---------------------------------------------------------------------
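    # Classifier-free guidance uses a doubled batch throughout `__call__`:
    # rows [0:B] are the positive (conditional) samples and rows [B:2B] the
    # negative (unconditional) ones, so slicing `[:batch_size]` after the
    # denoising loop recovers the guided images.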
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = "",
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 28,
        guidance_scale: float = 5.0,
        cfg_trunc: float = 1.0,
        renorm_cfg: bool = True,
        system_prompt: str = "",
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        return_latents: bool = False,
        **kwargs,
    ) -> Union[NewbiePipelineOutput, Tuple[List[PIL.Image.Image], Optional[torch.Tensor]]]:
        # 1. Normalize prompt inputs.
        if isinstance(prompt, str):
            batch_size = 1
            prompts = [prompt]
        else:
            prompts = list(prompt)
            batch_size = len(prompts)
        if negative_prompt is None:
            negative_prompt = ""
        if isinstance(negative_prompt, str):
            neg_prompts = [negative_prompt] * batch_size
        else:
            neg_prompts = list(negative_prompt)
            if len(neg_prompts) != batch_size:
                raise ValueError(
                    "negative_prompt must have the same batch size as prompt when provided as a list."
                )
        if num_images_per_prompt != 1:
            deprecate(
                "num_images_per_prompt!=1 for NewbiePipeline",
                "0.31.0",
                "The Newbie architecture currently assumes num_images_per_prompt == 1.",
            )
        # 2. Build the Gemma prompts (optionally prefixed with a system prompt)
        #    and keep the raw captions for the CLIP tower.
        clip_captions_pos = prompts
        clip_captions_neg = neg_prompts
        if system_prompt:
            prompts_for_gemma = [system_prompt + p for p in prompts]
            neg_for_gemma = [system_prompt + p if p else "" for p in neg_prompts]
        else:
            prompts_for_gemma = prompts
            neg_for_gemma = neg_prompts
        device = self._execution_device
        dtype = self.transformer.dtype
        # 3. Prepare the initial latents and duplicate them for CFG:
        #    rows [0:B] are conditional, rows [B:2B] unconditional.
        latents = self._prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            dtype=dtype,
            device=device,
            generator=generator,
            latents=latents,
        )
        latents = latents.repeat(2, 1, 1, 1)  # [2B, C, H, W]
        # 4. Encode positive and negative prompts in a single batch.
        full_gemma_prompts = prompts_for_gemma + neg_for_gemma
        full_clip_captions = clip_captions_pos + clip_captions_neg
        cap_feats, cap_mask, clip_text_sequence, clip_text_pooled, clip_mask = self._encode_prompt(
            full_gemma_prompts,
            clip_captions=full_clip_captions,
        )
        cap_feats = cap_feats.to(device=device, dtype=dtype)
        cap_mask = cap_mask.to(device)
        if clip_text_sequence is not None:
            clip_text_sequence = clip_text_sequence.to(device=device, dtype=dtype)
        if clip_text_pooled is not None:
            clip_text_pooled = clip_text_pooled.to(device=device, dtype=dtype)
        # 5. Pack conditioning into the transformer's CFG kwargs.
        model_kwargs: Dict[str, Any] = dict(
            cap_feats=cap_feats,
            cap_mask=cap_mask,
            cfg_scale=float(guidance_scale),
            cfg_trunc=float(cfg_trunc),
            renorm_cfg=bool(renorm_cfg),
            clip_text_sequence=clip_text_sequence,
            clip_text_pooled=clip_text_pooled,
            clip_img_pooled=None,
        )
        # 6. Denoising loop (flow-matching Euler steps).
        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
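        # `forward_with_cfg` applies classifier-free guidance inside the
        # transformer (using `cfg_scale`, `cfg_trunc`, and `renorm_cfg` from
        # `model_kwargs`) and returns a prediction for the doubled batch.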
        for t in timesteps:
            noise_pred = self.transformer.forward_with_cfg(
                latents,
                t,
                **model_kwargs,
            )
            # Flip the sign to match the direction the scheduler integrates
            # (as in Lumina-style flow-matching pipelines).
            noise_pred = -noise_pred
            latents = self.scheduler.step(
                model_output=noise_pred,
                timestep=t,
                sample=latents,
                return_dict=False,
            )[0]
        # Keep only the conditional half of the CFG batch.
        latents_out = latents[:batch_size]
        # 7. VAE decode (undo the FLUX-style latent scale/shift first).
        vae_scale, vae_shift = self._get_vae_scale_shift()
        decoded = self.vae.decode(latents_out / vae_scale + vae_shift).sample
        images = (decoded / 2 + 0.5).clamp(0, 1)
        if output_type == "pil":
            import numpy as np
            from PIL import Image

            images_np = images.detach().float().cpu()
            images_np = images_np.permute(0, 2, 3, 1).numpy()
            images_np = (images_np * 255).round().astype(np.uint8)
            images_out = [Image.fromarray(img) for img in images_np]
        else:
            images_out = images
        if not return_dict:
            return images_out, (latents_out if return_latents else None)
        return NewbiePipelineOutput(
            images=images_out,
            latents=latents_out if return_latents else None,
        )
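
# ---------------------------------------------------------------------
# Usage note (illustrative sketch, not part of the pipeline itself; the
# checkpoint id is a placeholder): since `model_cpu_offload_seq` is set on
# the class, diffusers' standard offload hook works on constrained GPUs:
#
#     pipe = NewbiePipeline.from_pretrained("org/newbie", torch_dtype=torch.bfloat16)
#     pipe.enable_model_cpu_offload()  # keeps only the active submodule on the GPU
# ---------------------------------------------------------------------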