"""Sage-T2I inference pipeline — real 1024x1024 generation with pos_embed interpolation.""" import torch import torch.nn.functional as F import numpy as np import math from transformers import CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL from .dit import DiT from .config import DiTConfig class SageT2IPipeline: def __init__(self, model_path=None, device="cpu", dtype=torch.float32): self.device = torch.device(device) self.dtype = dtype self.config = DiTConfig(hidden_size=384, num_layers=12, num_heads=6, image_size=128) self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(self.device) self.vae.eval() self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype).to(self.device) self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") self.text_encoder.eval() self.dit = DiT(self.config).to(self.device, dtype=dtype) if model_path: sd = torch.load(model_path, map_location=self.device, weights_only=True) self.dit.load_state_dict(sd) self.dit.eval() # Store original pos_embed for interpolation self.base_latent_size = self.config.image_size // 8 self.base_pos_embed = self.dit.pos_embed.data.clone() def _interpolate_pos_embed(self, new_latent_size): """Interpolate positional embeddings to support higher/lower resolution generation.""" old_size = self.base_latent_size if new_latent_size == old_size: self.dit.pos_embed.data.copy_(self.base_pos_embed) return old_patches = old_size // self.dit.patch_size new_patches = new_latent_size // self.dit.patch_size pe = self.base_pos_embed.float() # (1, N, D) pe = pe.reshape(1, old_patches, old_patches, -1).permute(0, 3, 1, 2) # (1, D, H, W) pe = F.interpolate(pe, size=(new_patches, new_patches), mode="bicubic", align_corners=False) pe = pe.permute(0, 2, 3, 1).reshape(1, new_patches * new_patches, -1) self.dit.pos_embed.data.copy_(pe.to(self.dtype)) @torch.no_grad() def __call__(self, prompt, num_steps=50, cfg_scale=4.0, seed=None, output_size=1024): """Generate image. output_size: pixel size (must be multiple of 8).""" if seed is not None: torch.manual_seed(seed) np.random.seed(seed) tokens = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt") text_emb = self.text_encoder(tokens.input_ids.to(self.device))[0] latent_size = output_size // 8 self._interpolate_pos_embed(latent_size) latents = torch.randn(1, self.config.in_channels, latent_size, latent_size, device=self.device, dtype=self.dtype) timesteps = torch.linspace(1.0, 0.0, num_steps, device=self.device) dt = 1.0 / num_steps for i, t in enumerate(timesteps): t_batch = t.expand(1) pred = self.dit(latents, t_batch, text_emb) if cfg_scale > 1.0: null_emb = torch.zeros_like(text_emb) pred_uncond = self.dit(latents, t_batch, null_emb) pred = pred_uncond + cfg_scale * (pred - pred_uncond) latents = latents + dt * pred latents = latents / 0.18215 image = self.vae.decode(latents).sample image = (image.clamp(-1, 1) + 1) / 2 image = image.cpu().permute(0, 2, 3, 1).numpy() image = (image * 255).clip(0, 255).astype(np.uint8) return image[0] def to(self, device): self.device = torch.device(device) self.vae.to(device) self.text_encoder.to(device) self.dit.to(device) return self