# sage-t2i / model/pipeline.py
# Uploaded by itriedcoding via huggingface_hub (commit 2d7087a, 3.88 kB).
"""Sage-T2I inference pipeline — real 1024x1024 generation with pos_embed interpolation."""
import torch
import torch.nn.functional as F
import numpy as np
import math
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from .dit import DiT
from .config import DiTConfig
class SageT2IPipeline:
    """Text-to-image pipeline: CLIP text encoder + DiT velocity model + SD VAE decoder.

    The DiT is built for ``config.image_size`` pixels. Other output sizes are
    produced by bicubically resampling the DiT's positional embedding to the
    requested patch grid before sampling.
    """

    def __init__(self, model_path=None, device="cpu", dtype=torch.float32):
        """Load the VAE, CLIP text encoder and DiT.

        Args:
            model_path: Optional path to a DiT ``state_dict`` saved with
                ``torch.save``. When falsy, the DiT keeps its initial weights.
            device: Device that all models are placed on.
            dtype: Parameter dtype used for all models.
        """
        self.device = torch.device(device)
        self.dtype = dtype
        self.config = DiTConfig(hidden_size=384, num_layers=12, num_heads=6, image_size=128)

        self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(self.device)
        self.vae.eval()
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype).to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.text_encoder.eval()

        self.dit = DiT(self.config).to(self.device, dtype=dtype)
        if model_path:
            # weights_only=True refuses arbitrary pickled objects in the checkpoint.
            sd = torch.load(model_path, map_location=self.device, weights_only=True)
            self.dit.load_state_dict(sd)
        self.dit.eval()

        # Keep the native-resolution positional embedding so each call resamples
        # from the original rather than from a previously resampled copy.
        self.base_latent_size = self.config.image_size // 8
        self.base_pos_embed = self.dit.pos_embed.data.clone()

    def _interpolate_pos_embed(self, new_latent_size):
        """Resize ``self.dit.pos_embed`` to the patch grid of ``new_latent_size``.

        The stored base embedding, a square grid of per-patch tokens, is resampled
        bicubically to the new grid and installed on the DiT. The number of
        tokens changes with the grid, so the tensor is swapped via ``.data =``.
        ``copy_`` would require matching (broadcastable) shapes and would fail.

        Args:
            new_latent_size: Side length of the (square) latent to be denoised.
        """
        # NOTE(review): assumes the DiT derives its patch grid from the input at
        # forward time (only pos_embed is resized here). Confirm in dit.py.
        pos_embed = self.dit.pos_embed
        target = {"device": pos_embed.device, "dtype": self.dtype}
        if new_latent_size == self.base_latent_size:
            # clone() so later resizes can never alias/mutate the stored base.
            pos_embed.data = self.base_pos_embed.to(**target).clone()
            return

        patch = self.dit.patch_size
        old_patches = self.base_latent_size // patch
        new_patches = new_latent_size // patch
        pe = self.base_pos_embed.float()  # (1, N, D); interpolate in fp32
        pe = pe.reshape(1, old_patches, old_patches, -1).permute(0, 3, 1, 2)  # (1, D, H, W)
        pe = F.interpolate(pe, size=(new_patches, new_patches), mode="bicubic", align_corners=False)
        pe = pe.permute(0, 2, 3, 1).reshape(1, new_patches * new_patches, -1)  # (1, N', D)
        pos_embed.data = pe.to(**target)

    @staticmethod
    def _euler_timesteps(num_steps, device=None):
        """Return the ``num_steps`` times at which the velocity is evaluated.

        The grid is uniform: 1, 1 - 1/N, ..., 1/N. Each value is the start time
        of one Euler step of size 1/N, so the t passed to the DiT matches the
        time of the current latent state.
        """
        return torch.linspace(1.0, 0.0, num_steps + 1, device=device)[:-1]

    @staticmethod
    def _to_uint8(image):
        """Convert a (B, C, H, W) decoder output in [-1, 1] to a (B, H, W, C) uint8 array."""
        image = (image.clamp(-1, 1) + 1) / 2
        # .float(): numpy has no bfloat16, so .numpy() would raise for bf16 tensors.
        image = image.float().cpu().permute(0, 2, 3, 1).numpy()
        return (image * 255).clip(0, 255).astype(np.uint8)

    @torch.no_grad()
    def __call__(self, prompt, num_steps=50, cfg_scale=4.0, seed=None,
                 output_size=1024):
        """Generate one image for ``prompt``.

        Args:
            prompt: Text prompt. It is CLIP-tokenized, padded or truncated to 77 tokens.
            num_steps: Number of Euler steps from t=1 to t=0. Must be >= 1.
            cfg_scale: Classifier-free guidance scale. Guidance runs only when
                > 1.0, via an extra DiT pass with an all-zeros text embedding.
            seed: If given, seeds the global torch and numpy RNGs.
            output_size: Square output size in pixels. Must be a positive
                multiple of ``8 * dit.patch_size``; the latent is output_size // 8.

        Returns:
            A channels-last (H, W, C) ``uint8`` numpy array.

        Raises:
            ValueError: If ``num_steps`` < 1, or if ``output_size`` is not a
                positive multiple of ``8 * dit.patch_size``.
        """
        multiple = 8 * self.dit.patch_size
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        if output_size <= 0 or output_size % multiple:
            raise ValueError(
                f"output_size must be a positive multiple of {multiple}, got {output_size}"
            )

        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        tokens = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
        text_emb = self.text_encoder(tokens.input_ids.to(self.device))[0]

        latent_size = output_size // 8
        self._interpolate_pos_embed(latent_size)
        latents = torch.randn(1, self.config.in_channels, latent_size, latent_size,
                              device=self.device, dtype=self.dtype)

        # The unconditional embedding is constant, so build it once, not per step.
        null_emb = torch.zeros_like(text_emb) if cfg_scale > 1.0 else None
        dt = 1.0 / num_steps
        for t in self._euler_timesteps(num_steps, self.device):
            t_batch = t.expand(1)
            pred = self.dit(latents, t_batch, text_emb)
            if null_emb is not None:
                pred_uncond = self.dit(latents, t_batch, null_emb)
                pred = pred_uncond + cfg_scale * (pred - pred_uncond)
            latents = latents + dt * pred

        # 0.18215: latent scaling factor of the Stable Diffusion VAE.
        image = self.vae.decode(latents / 0.18215).sample
        return self._to_uint8(image)[0]

    def to(self, device):
        """Move all models and the cached base positional embedding to ``device``.

        Returns:
            ``self``, so calls can be chained.
        """
        self.device = torch.device(device)
        self.vae.to(device)
        self.text_encoder.to(device)
        self.dit.to(device)
        self.base_pos_embed = self.base_pos_embed.to(device)
        return self