| """Sage-T2I inference pipeline — real 1024x1024 generation with pos_embed interpolation.""" |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| from transformers import CLIPTextModel, CLIPTokenizer |
| from diffusers import AutoencoderKL |
| from .dit import DiT |
| from .config import DiTConfig |
|
|
class SageT2IPipeline:
    """Text-to-image sampler built from a CLIP text encoder, a DiT denoiser and an SD VAE.

    Sampling runs fixed-step Euler integration from t=1 (Gaussian noise) to
    t=0, optionally with classifier-free guidance, then decodes the latent with
    the VAE. The DiT positional embedding is resampled on every call so the
    model can be run at output sizes other than the one it was configured for.
    """

    def __init__(self, model_path=None, device="cpu", dtype=torch.float32, config=None):
        """Load all sub-models and place them on ``device`` in ``dtype``.

        Args:
            model_path: Optional path to a DiT ``state_dict`` checkpoint. When
                omitted, the DiT keeps its freshly constructed weights.
            device: Torch device (string or ``torch.device``) for all sub-models.
            dtype: Floating dtype for the VAE, text encoder and DiT.
            config: Optional ``DiTConfig``. Defaults to the configuration this
                pipeline was originally hard-coded with (hidden_size=384,
                num_layers=12, num_heads=6, image_size=128); it must match the
                checkpoint at ``model_path``.
        """
        self.device = torch.device(device)
        self.dtype = dtype
        if config is None:
            config = DiTConfig(hidden_size=384, num_layers=12, num_heads=6, image_size=128)
        self.config = config

        self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(self.device)
        self.vae.eval()

        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype).to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.text_encoder.eval()

        self.dit = DiT(self.config).to(self.device, dtype=dtype)
        if model_path:
            state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
            self.dit.load_state_dict(state_dict)
        self.dit.eval()

        # Kept for backward compatibility; resizing now derives the native grid
        # directly from the stored embedding's token count.
        self.base_latent_size = self.config.image_size // 8
        # Pristine copy taken after weights are loaded, so every resize starts
        # from the trained values rather than from a previous resize.
        self.base_pos_embed = self.dit.pos_embed.data.clone()

    def _interpolate_pos_embed(self, new_latent_size):
        """Resample the DiT positional embedding for a square latent of side ``new_latent_size``.

        The trained embedding in ``self.base_pos_embed`` is viewed as a square
        ``g x g`` token grid (``g = isqrt(num_tokens)``) and bicubically resized
        to ``new_latent_size // patch_size`` tokens per side. The result
        *replaces* ``self.dit.pos_embed.data``; an in-place ``copy_`` cannot be
        used because the token count changes with resolution.

        NOTE(review): this assumes ``DiT.forward`` derives its patch grid from
        the incoming latent rather than from ``config.image_size`` — confirm in
        dit.py.

        Raises:
            ValueError: if the stored embedding's token count is not a perfect square.
        """
        base = self.base_pos_embed
        num_tokens, dim = base.shape[-2], base.shape[-1]
        old_patches = math.isqrt(num_tokens)
        if old_patches * old_patches != num_tokens:
            raise ValueError(f"pos_embed has {num_tokens} tokens, which is not a square grid")
        new_patches = new_latent_size // self.dit.patch_size
        target = self.dit.pos_embed

        if new_patches == old_patches:
            resized = base
        else:
            # (tokens, dim) -> (1, dim, g, g) so F.interpolate sees a 2-D image.
            grid = base.float().reshape(1, old_patches, old_patches, dim).permute(0, 3, 1, 2)
            grid = F.interpolate(grid, size=(new_patches, new_patches), mode="bicubic", align_corners=False)
            # Back to tokens, keeping whatever leading dims the original embedding had.
            resized = grid.permute(0, 2, 3, 1).reshape(*base.shape[:-2], new_patches * new_patches, dim)

        # copy=True so the live parameter never aliases the cached base tensor.
        target.data = resized.to(device=target.device, dtype=target.dtype, copy=True)

    @torch.no_grad()
    def __call__(self, prompt, num_steps=50, cfg_scale=4.0, seed=None,
                 output_size=1024):
        """Generate one square image for ``prompt``.

        Args:
            prompt: Text prompt fed to the CLIP tokenizer (truncated/padded to 77 tokens).
            num_steps: Number of Euler steps; must be >= 1.
            cfg_scale: Classifier-free guidance scale; guidance is skipped when <= 1.0.
            seed: If given, seeds the global torch and NumPy RNGs before sampling.
            output_size: Output side length in pixels. Must be a positive
                multiple of 8 whose latent side (``output_size // 8``) is
                divisible by the DiT patch size.

        Returns:
            An HWC ``uint8`` NumPy array for the single generated image.

        Raises:
            ValueError: if ``num_steps`` or ``output_size`` is invalid.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        if output_size <= 0 or output_size % 8 != 0:
            raise ValueError(f"output_size must be a positive multiple of 8, got {output_size}")
        latent_size = output_size // 8
        patch_size = self.dit.patch_size
        if latent_size % patch_size != 0:
            raise ValueError(
                f"output_size {output_size} gives latent size {latent_size}, "
                f"which is not divisible by patch size {patch_size}"
            )

        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        tokens = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
        text_emb = self.text_encoder(tokens.input_ids.to(self.device))[0]

        self._interpolate_pos_embed(latent_size)
        latents = torch.randn(1, self.config.in_channels, latent_size, latent_size,
                              device=self.device, dtype=self.dtype)

        # num_steps + 1 grid points give num_steps intervals of width dt, so the
        # t passed to the model always equals the time the latent is actually at.
        timesteps = torch.linspace(1.0, 0.0, num_steps + 1, device=self.device)
        dt = 1.0 / num_steps
        # NOTE(review): a zero vector as the unconditional embedding assumes the
        # DiT was trained with zeroed text conditioning for CFG dropout — confirm.
        null_emb = torch.zeros_like(text_emb) if cfg_scale > 1.0 else None

        for t in timesteps[:-1]:
            t_batch = t.expand(1)
            pred = self.dit(latents, t_batch, text_emb)
            if null_emb is not None:
                pred_uncond = self.dit(latents, t_batch, null_emb)
                pred = pred_uncond + cfg_scale * (pred - pred_uncond)
            # Stepping t from 1 toward 0 with "+ dt * pred" presumes the model
            # predicts the (data - noise) direction — verify against training.
            latents = latents + dt * pred

        # Undo the VAE latent scaling (0.18215 is the SD-VAE scaling factor).
        latents = latents / 0.18215
        image = self.vae.decode(latents).sample
        image = (image.clamp(-1, 1) + 1) / 2
        # .float() first: Tensor.numpy() does not support bfloat16.
        image = image.float().cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).clip(0, 255).astype(np.uint8)
        return image[0]

    def to(self, device):
        """Move every sub-model and the cached base embedding to ``device``; returns ``self``."""
        self.device = torch.device(device)
        self.vae.to(self.device)
        self.text_encoder.to(self.device)
        self.dit.to(self.device)
        # The cached embedding belongs to no module, so move it explicitly;
        # otherwise later resizes would compute on the previous device.
        self.base_pos_embed = self.base_pos_embed.to(self.device)
        return self
|
|