cuio
/

MiniT2I / pipeline.py
cuio's picture
Duplicate from MiniT2I/MiniT2I
32813be
Raw
History Blame Contribute Delete
28.6 kB
import math
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
def modulate(x, shift, scale):
    """Apply an affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are indexed with ``[:, None, :]``, i.e. a
    singleton axis is inserted at position 1 so they broadcast along
    dimension 1 of ``x``.
    """
    gain = 1 + scale[:, None, :]
    offset = shift[:, None, :]
    return x * gain + offset
def rotate_half(x):
    """Swap the two halves of the last axis, negating the second half.

    The last axis is split into a first half ``a`` and a second half ``b``
    (so it must have even length); the result is ``concat(-b, a)``.
    """
    halves = x.reshape(*x.shape[:-1], 2, -1)
    first = halves[..., 0, :]
    second = halves[..., 1, :]
    return torch.cat((-second, first), dim=-1)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Per-feature gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(dim))
        # Added to the mean square before the reciprocal square root.
        self.eps = eps

    def forward(self, x):
        """Divide ``x`` by the RMS of its last axis, then scale by ``weight``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalised = x * torch.rsqrt(mean_square + self.eps)
        return normalised * self.weight
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps with sinusoidal features followed by an MLP.

    Args:
        hidden_size: Width of the MLP output.
        frequency_embedding_size: Width of the sinusoidal feature vector fed
            to the MLP. Odd values are supported by appending one zero column.
    """

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, t):
        """Return the embedding of a 1-D batch of timesteps ``t``.

        The features are ``[cos(t * f), sin(t * f)]`` for ``size // 2``
        geometrically spaced frequencies ``f`` from 1 down towards 1/10000,
        computed in float32 and cast to the MLP weight dtype.
        """
        size = self.frequency_embedding_size
        half = size // 2
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, device=t.device, dtype=torch.float32)
            / half
        )
        args = t.float()[:, None] * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if size % 2:
            # cos/sin only produce 2 * half columns; pad so the width matches
            # the first Linear layer when the configured size is odd.
            emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
        return self.mlp(emb.to(dtype=self.mlp[0].weight.dtype))
class BottleneckPatchEmbed(nn.Module):
    """Patchify an image through a low-width bottleneck.

    A bias-free strided convolution maps each ``patch_size`` x ``patch_size``
    patch to ``pca_channels`` features; a 1x1 convolution with bias then
    expands them to ``hidden_size``.
    """

    def __init__(self, img_size=512, patch_size=16, in_channels=3, pca_channels=128, hidden_size=1248):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Non-overlapping patches: kernel and stride both equal the patch size.
        self.proj1 = nn.Conv2d(
            in_channels,
            pca_channels,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.proj2 = nn.Conv2d(
            pca_channels,
            hidden_size,
            kernel_size=1,
            stride=1,
            bias=True,
        )

    def forward(self, x):
        """Return patch tokens of shape ``(batch, num_patches, hidden_size)``."""
        bottleneck = self.proj1(x)
        feature_map = self.proj2(bottleneck)
        # Collapse the spatial grid into one axis, then put channels last.
        tokens = feature_map.flatten(start_dim=2)
        return tokens.permute(0, 2, 1)
class SwiGLUMlp(nn.Module):
    """Bias-free gated MLP: ``w2(silu(w1(x)) * w3(x))``.

    The requested hidden width is rounded up to the next multiple of 8.
    """

    def __init__(self, in_features: int, hidden_features: int):
        super().__init__()
        padded_hidden = ((hidden_features + 7) // 8) * 8
        self.w1 = nn.Linear(in_features, padded_hidden, bias=False)
        self.w3 = nn.Linear(in_features, padded_hidden, bias=False)
        self.w2 = nn.Linear(padded_hidden, in_features, bias=False)

    def forward(self, x):
        """Gate the ``w3`` branch with the SiLU-activated ``w1`` branch."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class TextRotaryEmbedding1D(nn.Module):
    """Rotary position embedding over a 1-D token sequence.

    Holds no parameters; the sin/cos tables are rebuilt on every call from
    the shape of the input.
    """

    def __init__(self, head_dim: int, theta: float = 10000.0):
        super().__init__()
        self.head_dim = head_dim
        self.theta = theta

    def forward(self, x):
        """Rotate ``x`` of shape ``(batch, length, heads, dim)`` by position.

        Positions are ``0 .. length - 1``. Angles are computed in float32 and
        cast to ``x.dtype`` before use.
        """
        _, length, _, d = x.shape
        exponents = torch.arange(0, d, 2, device=x.device, dtype=torch.float32) / d
        inv_freq = 1.0 / (self.theta ** exponents)
        positions = torch.arange(length, device=x.device, dtype=torch.float32)
        angles = torch.einsum("l,f->lf", positions, inv_freq)
        # Repeat the table so it lines up with the half-split in rotate_half.
        angles = angles.repeat(1, 2)
        cos_table = angles.cos().to(dtype=x.dtype)[None, :, None, :]
        sin_table = angles.sin().to(dtype=x.dtype)[None, :, None, :]
        return x * cos_table + rotate_half(x) * sin_table
class VisionRotaryEmbeddingFast(nn.Module):
    """Rotary position embedding over a square 2-D token grid.

    Half of each head's rotation angles come from the token's row index and
    half from its column index. Holds no parameters.
    """

    def __init__(self, head_dim: int, theta: float = 10000.0):
        super().__init__()
        # Each spatial axis gets half of the head dimension.
        self.dim = head_dim // 2
        self.theta = theta

    def forward(self, x):
        """Rotate ``x`` (tokens on axis 1, row-major over a square grid).

        Raises:
            ValueError: If the number of tokens is not a perfect square.
        """
        length = x.shape[1]
        side = int(math.sqrt(length))
        if side * side != length:
            raise ValueError(f"image token length must be square, got {length}")

        axis_dim = self.dim
        exponents = torch.arange(0, axis_dim, 2, device=x.device, dtype=torch.float32)
        exponents = exponents[: axis_dim // 2] / axis_dim
        freqs = 1.0 / (self.theta ** exponents)

        coords = torch.arange(side, device=x.device, dtype=torch.float32)
        per_axis = torch.einsum("l,f->lf", coords, freqs)
        # Entry [i, j] of the grid uses row angles per_axis[i] and column
        # angles per_axis[j].
        row_angles = per_axis[:, None, :].expand(side, side, -1)
        col_angles = per_axis[None, :, :].expand(side, side, -1)
        angles = torch.cat([row_angles, col_angles], dim=-1)
        # Repeat so the table lines up with the half-split in rotate_half,
        # then flatten the grid to match the token axis.
        angles = angles.repeat(1, 1, 2).reshape(length, -1)

        cos_table = angles.cos().to(dtype=x.dtype)[None, :, None, :]
        sin_table = angles.sin().to(dtype=x.dtype)[None, :, None, :]
        return x * cos_table + rotate_half(x) * sin_table
class MultiModalRotaryEmbeddingFast(nn.Module):
    """Rotary embedding for a sequence of text tokens followed by image tokens.

    The first ``txt_len`` tokens get the 1-D text rotation; the remaining
    tokens get the 2-D vision rotation.
    """

    def __init__(self, head_dim: int):
        super().__init__()
        self.text_rope = TextRotaryEmbedding1D(head_dim)
        self.vision_rope = VisionRotaryEmbeddingFast(head_dim)

    def forward(self, x, txt_len: int):
        """Rotate each modality separately and re-join along the token axis."""
        text_part = x[:, :txt_len]
        image_part = x[:, txt_len:]
        rotated = [self.text_rope(text_part), self.vision_rope(image_part)]
        return torch.cat(rotated, dim=1)
class PlainTextTransformerBlock(nn.Module):
    """Pre-norm transformer block over text tokens only.

    Self-attention uses per-head RMS-normalised queries/keys with 1-D rotary
    embedding and no attention mask, followed by a SwiGLU MLP. Both
    sub-layers are residual.
    """

    def __init__(self, hidden_size=1248, num_heads=24, head_dim=52, mlp_ratio=2.7):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner_dim = num_heads * head_dim
        self.norm1 = RMSNorm(hidden_size)
        self.norm2 = RMSNorm(hidden_size)
        self.qkv = nn.Linear(hidden_size, inner_dim * 3)
        self.attn_proj = nn.Linear(inner_dim, hidden_size)
        self.mlp = SwiGLUMlp(hidden_size, int(hidden_size * mlp_ratio))
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        self.rope = TextRotaryEmbedding1D(head_dim)

    def forward(self, txt):
        """Return the updated ``(batch, length, hidden)`` text tokens."""
        batch, length, _ = txt.shape

        # Fused projection, split into (batch, length, heads, head_dim) each.
        fused = self.qkv(self.norm1(txt))
        fused = fused.reshape(batch, length, 3, self.num_heads, self.head_dim)
        q, k, v = fused.unbind(dim=2)
        q = self.rope(self.q_norm(q))
        k = self.rope(self.k_norm(k))

        scale = self.head_dim ** -0.5
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
        weights = scores.softmax(dim=-1)
        attended = torch.einsum("bhqk,bkhd->bqhd", weights, v)
        attended = attended.reshape(batch, length, -1)

        txt = txt + self.attn_proj(attended)
        return txt + self.mlp(self.norm2(txt))
class DoubleStreamDiTBlock(nn.Module):
    """Two-stream transformer block with joint text/image attention.

    Image and text tokens keep separate norms, QKV projections, output
    projections and MLPs, but attend over the concatenated
    ``[text, image]`` sequence. The query/key RMS norms are shared by both
    streams.
    """

    def __init__(self, hidden_size=1248, txt_hidden_size=1248, num_heads=24, head_dim=52, mlp_ratio=2.7):
        super().__init__()
        self.hidden_size = hidden_size
        self.txt_hidden_size = txt_hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner_dim = num_heads * head_dim
        self.img_norm1 = RMSNorm(hidden_size)
        self.img_norm2 = RMSNorm(hidden_size)
        self.txt_norm1 = RMSNorm(txt_hidden_size)
        self.txt_norm2 = RMSNorm(txt_hidden_size)
        self.img_qkv = nn.Linear(hidden_size, inner_dim * 3)
        self.txt_qkv = nn.Linear(txt_hidden_size, inner_dim * 3)
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        self.rope = MultiModalRotaryEmbeddingFast(head_dim)
        self.img_attn_proj = nn.Linear(inner_dim, hidden_size)
        self.txt_attn_proj = nn.Linear(inner_dim, txt_hidden_size)
        self.img_mlp = SwiGLUMlp(hidden_size, int(hidden_size * mlp_ratio))
        self.txt_mlp = SwiGLUMlp(txt_hidden_size, int(txt_hidden_size * mlp_ratio))

    def forward(self, x, txt, vec):
        """Update image tokens ``x`` and text tokens ``txt``.

        ``vec`` is accepted for interface compatibility but is not read by
        this block. Returns the ``(x, txt)`` pair.
        """
        batch, n_img, _ = x.shape
        n_txt = txt.shape[1]
        heads, head_dim = self.num_heads, self.head_dim

        img_normed = self.img_norm1(x)
        txt_normed = self.txt_norm1(txt)
        img_fused = self.img_qkv(img_normed).reshape(batch, n_img, 3, heads, head_dim)
        txt_fused = self.txt_qkv(txt_normed).reshape(batch, n_txt, 3, heads, head_dim)
        q_img, k_img, v_img = img_fused.unbind(dim=2)
        q_txt, k_txt, v_txt = txt_fused.unbind(dim=2)

        q_img = self.q_norm(q_img)
        k_img = self.k_norm(k_img)
        q_txt = self.q_norm(q_txt)
        k_txt = self.k_norm(k_txt)

        # Joint sequence is ordered text first, image second.
        q = self.rope(torch.cat([q_txt, q_img], dim=1), txt_len=n_txt)
        k = self.rope(torch.cat([k_txt, k_img], dim=1), txt_len=n_txt)
        v = torch.cat([v_txt, v_img], dim=1)

        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * (head_dim ** -0.5)
        mixed = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)

        img_attended = mixed[:, n_txt:].reshape(batch, n_img, -1)
        txt_attended = mixed[:, :n_txt].reshape(batch, n_txt, -1)
        x = x + self.img_attn_proj(img_attended)
        txt = txt + self.txt_attn_proj(txt_attended)

        x = x + self.img_mlp(self.img_norm2(x))
        txt = txt + self.txt_mlp(self.txt_norm2(txt))
        return x, txt
class FinalLayer(nn.Module):
    """Project each token to a flattened ``patch_size**2 * out_channels`` patch."""

    def __init__(self, hidden_size=1248, patch_size=16, out_channels=3):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.norm_final = RMSNorm(hidden_size)
        patch_features = patch_size * patch_size * out_channels
        self.linear = nn.Linear(hidden_size, patch_features)

    def forward(self, x, vec=None):
        """RMS-normalise then linearly project ``x``; ``vec`` is not used."""
        normed = self.norm_final(x)
        return self.linear(normed)
def get_2d_sincos_pos_embed(embed_dim, grid_size, device, dtype):
    """Build a fixed 2-D sin/cos position table for a square grid.

    Returns a ``(grid_size**2, embed_dim)`` tensor (for even ``embed_dim``)
    in ``dtype``, with rows in row-major grid order. The first half of the
    features encodes the first meshgrid output and the second half encodes
    the second. With ``indexing="xy"`` the first output varies along the last
    (column) axis and the second along the row axis.
    """
    coords_h = torch.arange(grid_size, device=device, dtype=torch.float32)
    coords_w = torch.arange(grid_size, device=device, dtype=torch.float32)
    mesh = torch.meshgrid(coords_w, coords_h, indexing="xy")
    stacked = torch.stack(mesh, dim=0).reshape(2, 1, grid_size, grid_size)
    per_axis_dim = embed_dim // 2
    halves = [
        get_1d_sincos_pos_embed(per_axis_dim, stacked[0]),
        get_1d_sincos_pos_embed(per_axis_dim, stacked[1]),
    ]
    return torch.cat(halves, dim=1).to(dtype=dtype)
def get_1d_sincos_pos_embed(embed_dim, pos):
    """Encode positions as ``[sin, cos]`` features.

    ``pos`` is flattened to one axis. The result has one row per position and
    ``2 * (embed_dim // 2)`` columns: sines first, then cosines, over
    frequencies ``1 / 10000 ** (k / (embed_dim / 2))``.
    """
    num_freqs = embed_dim // 2
    k = torch.arange(num_freqs, device=pos.device, dtype=torch.float32)
    omega = 1.0 / (10000 ** (k / (embed_dim / 2.0)))
    flat_pos = pos.reshape(-1)
    phase = torch.einsum("m,d->md", flat_pos, omega)
    return torch.cat([phase.sin(), phase.cos()], dim=1)
@dataclass
class MMJiTConfig:
    """Hyper-parameters shared by the MMJiT network and its sampler."""

    # --- image / patch geometry ---
    image_size: int = 512  # side of the sampled image; also sets the patch grid
    patch_size: int = 16  # side of one square patch
    in_channels: int = 3  # image channels (input and predicted output)

    # --- widths ---
    txt_input_size: int = 1024  # feature width of the incoming text encodings
    hidden_size: int = 768  # image-token width
    txt_hidden_size: int = 768  # text-token width
    cond_vec_size: int = 768  # width of the timestep / pooled-text vector

    # --- depth and attention ---
    depth_double: int = 17  # number of double-stream blocks
    txt_preamble_depth: int = 2  # number of text-only blocks run first
    num_heads: int = 12
    head_dim: int = 64
    mlp_ratio: float = 2.6667  # MLP hidden width as a multiple of token width
    pca_channels: int = 128  # bottleneck width inside the patch embedder

    # --- text / sampling ---
    prompt_length: int = 256  # tokenizer max_length used when encoding prompts
    n_T: int = 100  # number of Euler steps taken by the sampler
    # NOTE(review): prediction, sampler and cfg_channels are not read by any
    # code visible in this file; their meaning should be confirmed upstream.
    prediction: str = "x"
    sampler: str = "euler"
    cfg_channels: int = 3
    cfg_interval: tuple = (0.0, 1.0)  # inclusive (low, high) range of t where guidance applies
    llm: str = "google/flan-t5-large"  # text encoder / tokenizer identifier
class MMJiT(nn.Module):
    """Text-conditioned image transformer operating directly on pixel patches.

    Text encodings pass through a stack of text-only blocks, then text and
    image tokens are processed jointly by double-stream blocks. The image
    tokens are finally projected back to pixel patches and reassembled.
    """

    def __init__(self, cfg: MMJiTConfig):
        super().__init__()
        self.cfg = cfg
        # Number of patches along one side of the image.
        self.latent_img_size = cfg.image_size // cfg.patch_size
        self.img_embedder = BottleneckPatchEmbed(
            cfg.image_size, cfg.patch_size, cfg.in_channels, cfg.pca_channels, cfg.hidden_size
        )
        self.txt_embedder = nn.Linear(cfg.txt_input_size, cfg.txt_hidden_size, bias=False)
        # Learned replacement for text positions the attention mask disables.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, cfg.txt_input_size))
        self.t_embedder = TimestepEmbedder(cfg.cond_vec_size)
        self.pooled_embedder = nn.Linear(cfg.txt_input_size, cfg.cond_vec_size, bias=False)
        self.txt_preamble_blocks = nn.ModuleList(
            [
                PlainTextTransformerBlock(cfg.txt_hidden_size, cfg.num_heads, cfg.head_dim, cfg.mlp_ratio)
                for _ in range(cfg.txt_preamble_depth)
            ]
        )
        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamDiTBlock(
                    cfg.hidden_size, cfg.txt_hidden_size, cfg.num_heads, cfg.head_dim, cfg.mlp_ratio
                )
                for _ in range(cfg.depth_double)
            ]
        )
        self.final_layer = FinalLayer(cfg.hidden_size, cfg.patch_size, cfg.in_channels)

    def unpatchify(self, x):
        """Reassemble ``(batch, h*w, p*p*c)`` patch tokens into ``(batch, c, h*p, w*p)``.

        The patch grid is assumed square (``h == w == sqrt(num_tokens)``).
        """
        b = x.shape[0]
        p = self.cfg.patch_size
        c = self.cfg.in_channels
        h = w = int(math.sqrt(x.shape[1]))
        x = x.reshape(b, h, w, p, p, c)
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(b, c, h * p, w * p)

    def forward(self, img, t, context, attn_mask):
        """Predict an image tensor from a noisy image, timestep and text.

        Args:
            img: 4-D image batch. If its second axis does not equal
                ``cfg.in_channels`` it is treated as channels-last and
                permuted to channels-first.
            t: 1-D batch of timesteps.
            context: Text encodings, ``(batch, length, txt_input_size)``.
            attn_mask: ``(batch, length)`` mask; positions with a value
                ``<= 0.5`` have their encoding replaced by ``mask_token``.

        Returns:
            Tensor of shape ``(batch, in_channels, H, W)``.
        """
        if img.ndim == 4 and img.shape[1] != self.cfg.in_channels:
            img = img.permute(0, 3, 1, 2)
        attn_mask = attn_mask.to(device=context.device)
        context = torch.where(attn_mask[:, :, None] > 0.5, context, self.mask_token.to(dtype=context.dtype))

        x = self.img_embedder(img)
        pos = get_2d_sincos_pos_embed(self.cfg.hidden_size, self.latent_img_size, x.device, x.dtype)
        x = x + pos[None]

        t_vec = self.t_embedder(t)
        txt = self.txt_embedder(context.to(dtype=self.txt_embedder.weight.dtype))
        # The pooled vector averages over every position, masked ones included.
        pooled_text = context.mean(dim=1)
        vec = t_vec + self.pooled_embedder(pooled_text.to(dtype=self.pooled_embedder.weight.dtype))

        for block in self.txt_preamble_blocks:
            txt = block(txt)
        for block in self.double_blocks:
            x, txt = block(x, txt, vec)

        # FinalLayer acts on each token independently (norm + linear), so it
        # is applied to the image tokens only. Projecting the text tokens and
        # discarding them wasted compute and required both streams to share
        # one hidden width.
        img_out = self.final_layer(x, vec)
        return self.unpatchify(img_out)
class DiffusionModel(nn.Module):
    """Wrap :class:`MMJiT` with velocity prediction, guidance and Euler sampling."""

    def __init__(self, cfg: Optional[MMJiTConfig] = None):
        super().__init__()
        self.cfg = cfg or MMJiTConfig()
        self.net = MMJiT(self.cfg)

    def real_t_to_embed_t(self, t):
        """Map a sampler timestep to the value fed to the network (identity)."""
        return t

    def pred_velocity(self, x, t, text, mask):
        """Convert the network's endpoint prediction into a velocity.

        The network output ``x0`` is treated as the state at ``t = 1``; the
        velocity is ``(x0 - x) / (1 - t)``, with the denominator clamped to at
        least ``0.001`` so it stays finite as ``t`` approaches 1.
        """
        x0 = self.net(x, self.real_t_to_embed_t(t), text, mask)
        return (x0 - x) / torch.clamp(1 - t[:, None, None, None], min=0.001)

    def cfg_velocity(self, x, t, text, mask, cfg_scale: float):
        """Return the classifier-free-guided velocity.

        Conditional and unconditional branches run as one doubled batch. The
        unconditional branch uses an all-zero mask, which makes the network
        replace every text position with its mask token. ``cfg_scale`` is
        applied only to samples whose ``t`` lies inside ``cfg.cfg_interval``
        (inclusive); elsewhere the scale is 1, i.e. the conditional velocity.
        """
        b = x.shape[0]
        xx = torch.cat([x, x], dim=0)
        tt = torch.cat([t, t], dim=0)
        yy = torch.cat([text, text], dim=0)
        mm = torch.cat([mask, torch.zeros_like(mask)], dim=0)
        out = self.pred_velocity(xx, tt, yy, mm)
        cond, uncond = out[:b], out[b:]
        low, high = self.cfg.cfg_interval[0], self.cfg.cfg_interval[1]
        use_cfg = ((t >= low) & (t <= high)).to(out.dtype)
        scale = torch.where(
            use_cfg[:, None, None, None] > 0,
            torch.tensor(cfg_scale, device=x.device, dtype=out.dtype),
            torch.tensor(1.0, device=x.device, dtype=out.dtype),
        )
        return uncond + (cond - uncond) * scale

    @torch.no_grad()
    def sample(self, text, mask, cfg_scale=6.0, generator=None, progress=False):
        """Generate images by Euler integration from ``t = 0`` to ``t = 1``.

        Args:
            text: Text encodings; their device is used for sampling.
            mask: Attention mask matching ``text``.
            cfg_scale: Classifier-free guidance scale.
            generator: Optional ``torch.Generator`` for the initial noise. It
                may live on a different device than ``text``; the noise is
                then drawn on the generator's device and moved.
            progress: Wrap the step loop in a ``tqdm`` progress bar.

        Returns:
            Tensor of shape ``(batch, in_channels, image_size, image_size)``
            in the model's parameter dtype.
        """
        b = text.shape[0]
        device = text.device
        dtype = next(self.parameters()).dtype
        # torch.randn needs generator and output on the same device; a CPU
        # generator with a CUDA model would otherwise raise.
        noise_device = device if generator is None else generator.device
        x = torch.randn(
            b, self.cfg.in_channels, self.cfg.image_size, self.cfg.image_size,
            generator=generator, device=noise_device, dtype=dtype,
        ).to(device) * 2
        timesteps = torch.linspace(0.0, 1.0, self.cfg.n_T + 1, device=device, dtype=dtype)
        # Loop-invariant casts, done once instead of once per step.
        text = text.to(dtype)
        mask = mask.to(dtype)
        iterator = range(self.cfg.n_T)
        if progress:
            from tqdm.auto import tqdm
            iterator = tqdm(iterator)
        for i in iterator:
            t_cur = timesteps[i].expand(b)
            t_next = timesteps[i + 1].expand(b)
            v = self.cfg_velocity(x, t_cur, text, mask, cfg_scale)
            x = x + (t_next - t_cur)[:, None, None, None] * v
        return x
import os
from dataclasses import asdict
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Union
os.environ.setdefault("USE_FLAX", "0")
os.environ.setdefault("TRANSFORMERS_NO_FLAX", "1")
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, T5EncoderModel
from transformers import logging as transformers_logging
from diffusers import DiffusionPipeline, ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from diffusers.schedulers.scheduling_utils import SchedulerMixin
transformers_logging.set_verbosity_error()
class MiniT2IFlowMatchScheduler(SchedulerMixin, ConfigMixin):
    """Timestep schedule holder for training and inference.

    Stores the training timestep distribution and the default number of
    inference steps in its registered config.
    """

    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(
        self,
        train_t_schedule: str = "lognorm",
        t_lognorm_mu: float = -0.8,
        t_lognorm_sigma: float = 0.8,
        num_inference_steps: int = 100,
    ):
        supported = ("uniform", "lognorm")
        if train_t_schedule not in supported:
            raise ValueError(f"Unsupported train_t_schedule: {train_t_schedule}")

    def sample_train_timesteps(self, batch_size, device, dtype=torch.float32, generator=None):
        """Draw ``batch_size`` training timesteps in ``(0, 1)``.

        ``"uniform"`` draws from ``U[0, 1)``. Otherwise a normal sample with
        mean ``t_lognorm_mu`` and standard deviation ``t_lognorm_sigma`` is
        passed through a sigmoid (computed in float32, then cast).
        """
        cfg = self.config
        if cfg.train_t_schedule == "uniform":
            return torch.rand(batch_size, device=device, dtype=dtype, generator=generator)
        gaussian = torch.randn(batch_size, device=device, dtype=torch.float32, generator=generator)
        logits = gaussian * cfg.t_lognorm_sigma + cfg.t_lognorm_mu
        return torch.sigmoid(logits).to(dtype=dtype)

    def get_inference_timesteps(self, num_inference_steps=None, device=None, dtype=torch.float32):
        """Return ``steps + 1`` evenly spaced timesteps from 0.0 to 1.0.

        A falsy ``num_inference_steps`` falls back to the configured default.
        """
        if num_inference_steps:
            steps = int(num_inference_steps)
        else:
            steps = int(self.config.num_inference_steps)
        return torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
class MiniT2IMMJiTModel(ModelMixin, ConfigMixin):
    """diffusers-compatible wrapper around :class:`DiffusionModel`.

    The constructor arguments mirror the fields of :class:`MMJiTConfig` and
    are registered as the model config.
    """

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        image_size: int = 512,
        patch_size: int = 16,
        in_channels: int = 3,
        txt_input_size: int = 1024,
        hidden_size: int = 768,
        txt_hidden_size: int = 768,
        cond_vec_size: int = 768,
        depth_double: int = 17,
        txt_preamble_depth: int = 2,
        num_heads: int = 12,
        head_dim: int = 64,
        mlp_ratio: float = 2.6666666666666665,
        pca_channels: int = 128,
        prompt_length: int = 256,
        n_T: int = 100,
        prediction: str = "x",
        sampler: str = "euler",
        cfg_channels: int = 3,
        cfg_interval: tuple = (0.0, 1.0),
        llm: str = "google/flan-t5-large",
    ):
        super().__init__()
        config_fields = {
            "image_size": image_size,
            "patch_size": patch_size,
            "in_channels": in_channels,
            "txt_input_size": txt_input_size,
            "hidden_size": hidden_size,
            "txt_hidden_size": txt_hidden_size,
            "cond_vec_size": cond_vec_size,
            "depth_double": depth_double,
            "txt_preamble_depth": txt_preamble_depth,
            "num_heads": num_heads,
            "head_dim": head_dim,
            "mlp_ratio": mlp_ratio,
            "pca_channels": pca_channels,
            "prompt_length": prompt_length,
            "n_T": n_T,
            "prediction": prediction,
            "sampler": sampler,
            "cfg_channels": cfg_channels,
            # Normalise to a tuple whatever sequence type was passed in.
            "cfg_interval": tuple(cfg_interval),
            "llm": llm,
        }
        self.model = DiffusionModel(MMJiTConfig(**config_fields))

    @property
    def mmjit_config(self) -> MMJiTConfig:
        """The :class:`MMJiTConfig` instance held by the wrapped model."""
        return self.model.cfg

    def forward(self, img, t, context, attn_mask):
        """Run the underlying network directly (no velocity conversion)."""
        network = self.model.net
        return network(img, t, context, attn_mask)

    def pred_velocity(self, x, t, text, mask):
        """Delegate to :meth:`DiffusionModel.pred_velocity`."""
        return self.model.pred_velocity(x, t, text, mask)

    def sample(self, text, mask, cfg_scale=6.0, generator=None, progress=False):
        """Delegate to :meth:`DiffusionModel.sample`."""
        return self.model.sample(
            text,
            mask,
            cfg_scale=cfg_scale,
            generator=generator,
            progress=progress,
        )
class MiniT2ITextToImagePipeline(nn.Module):
    """End-to-end text-to-image pipeline.

    Combines the transformer, a scheduler, and a T5 tokenizer / encoder. The
    tokenizer and text encoder may be omitted and are then loaded lazily the
    first time a prompt is encoded.
    """

    def __init__(
        self,
        transformer: MiniT2IMMJiTModel,
        scheduler: Optional[MiniT2IFlowMatchScheduler] = None,
        tokenizer=None,
        text_encoder=None,
        text_encoder_name: str = "google/flan-t5-large",
        train_t_schedule: str = "lognorm",
        t_lognorm_mu: float = -0.8,
        t_lognorm_sigma: float = 0.8,
        num_inference_steps: int = 100,
    ):
        super().__init__()
        if isinstance(scheduler, MiniT2IFlowMatchScheduler):
            active_scheduler = scheduler
        else:
            # Anything else (including None) is replaced by a scheduler built
            # from the keyword arguments.
            active_scheduler = MiniT2IFlowMatchScheduler(
                train_t_schedule=train_t_schedule,
                t_lognorm_mu=t_lognorm_mu,
                t_lognorm_sigma=t_lognorm_sigma,
                num_inference_steps=num_inference_steps,
            )
        self.transformer = transformer
        self.scheduler = active_scheduler
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        scheduler_cfg = active_scheduler.config
        self.config = SimpleNamespace(
            text_encoder_name=text_encoder_name,
            train_t_schedule=scheduler_cfg.train_t_schedule,
            t_lognorm_mu=scheduler_cfg.t_lognorm_mu,
            t_lognorm_sigma=scheduler_cfg.t_lognorm_sigma,
            num_inference_steps=scheduler_cfg.num_inference_steps,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        torch_dtype: Optional[torch.dtype] = None,
        text_encoder_dtype: torch.dtype = torch.float32,
        local_files_only: bool = False,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        **kwargs,
    ):
        """Load a pipeline from a local directory or a Hub repository.

        A path that does not exist locally is treated as a repo id and
        downloaded. The directory must contain ``transformer/``; a missing
        ``scheduler/`` falls back to a default scheduler. The tokenizer and
        text encoder are loaded by the name stored in the transformer config.
        Extra ``kwargs`` are forwarded to the transformer loader.
        """
        root = Path(pretrained_model_name_or_path)
        if not root.exists():
            snapshot_path = snapshot_download(
                repo_id=str(pretrained_model_name_or_path),
                revision=revision,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
            )
            root = Path(snapshot_path)

        transformer = MiniT2IMMJiTModel.from_pretrained(
            root / "transformer", torch_dtype=torch_dtype, **kwargs
        )

        scheduler_dir = root / "scheduler"
        scheduler = (
            MiniT2IFlowMatchScheduler.from_pretrained(scheduler_dir)
            if scheduler_dir.exists()
            else MiniT2IFlowMatchScheduler()
        )

        text_encoder_name = transformer.mmjit_config.llm
        tokenizer = AutoTokenizer.from_pretrained(
            text_encoder_name, local_files_only=local_files_only
        )
        text_encoder = T5EncoderModel.from_pretrained(
            text_encoder_name,
            torch_dtype=text_encoder_dtype,
            local_files_only=local_files_only,
        )
        return cls(
            transformer=transformer,
            scheduler=scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            text_encoder_name=text_encoder_name,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write ``transformer/`` and ``scheduler/`` under ``save_directory``.

        The tokenizer and text encoder are not saved. ``kwargs`` are passed
        to the transformer's ``save_pretrained`` only.
        """
        target = Path(save_directory)
        target.mkdir(parents=True, exist_ok=True)
        self.transformer.save_pretrained(target / "transformer", **kwargs)
        self.scheduler.save_pretrained(target / "scheduler")

    def _encode_prompt(self, prompt: Union[str, List[str]], device):
        """Tokenize and encode prompts; return ``(hidden_states, attention_mask)``.

        Prompts are padded / truncated to the configured ``prompt_length``.
        The text encoder is moved to ``device`` if it is not already there.
        """
        prompts = [prompt] if isinstance(prompt, str) else prompt

        # Lazy loading for pipelines constructed without these components.
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_name)
        if self.text_encoder is None:
            self.text_encoder = T5EncoderModel.from_pretrained(self.config.text_encoder_name)

        encoder_device = next(self.text_encoder.parameters()).device
        if encoder_device != device:
            self.text_encoder.to(device)

        max_length = self.transformer.mmjit_config.prompt_length
        encoded = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )
        input_ids = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)
        encoder_out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        return encoder_out.last_hidden_state, attention_mask

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        guidance_scale: float = 6.0,
        num_inference_steps: Optional[int] = None,
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        progress: bool = True,
    ):
        """Generate images for one or more prompts.

        Each prompt is repeated ``num_images_per_prompt`` times. Model output
        is clamped to ``[-1, 1]`` and mapped to ``uint8`` channels-last
        arrays. With ``output_type == "pil"`` the arrays are wrapped as PIL
        images; any other value returns the NumPy array. Returns an
        ``ImagePipelineOutput``, or a 1-tuple when ``return_dict`` is false.
        """
        device = next(self.transformer.parameters()).device

        if isinstance(prompt, str):
            prompt_batch = [prompt] * num_images_per_prompt
        else:
            prompt_batch = [p for p in prompt for _ in range(num_images_per_prompt)]

        # The sampler reads its step count from the model config, so it is
        # overridden for this call and restored afterwards.
        sampler_cfg = self.transformer.model.cfg
        previous_steps = self.transformer.mmjit_config.n_T
        requested_steps = int(num_inference_steps or self.scheduler.config.num_inference_steps)
        sampler_cfg.n_T = requested_steps
        try:
            text, attn = self._encode_prompt(prompt_batch, device)
            model_dtype = next(self.transformer.parameters()).dtype
            images = self.transformer.sample(
                text.to(dtype=model_dtype),
                attn.to(dtype=model_dtype),
                cfg_scale=guidance_scale,
                generator=generator,
                progress=progress,
            )
        finally:
            self.transformer.model.cfg.n_T = previous_steps

        # [-1, 1] -> [0.5, 255.5], then truncate to uint8.
        scaled = images.clamp(-1, 1) * 127.5 + 128.0
        pixels = scaled.clamp(0, 255).to(torch.uint8)
        images = pixels.permute(0, 2, 3, 1).cpu().numpy()
        if output_type == "pil":
            images = [Image.fromarray(array) for array in images]

        if return_dict:
            return ImagePipelineOutput(images=images)
        return (images,)
class MiniT2IPipeline(DiffusionPipeline):
    """Stateless entry point that loads a model variant and runs it per call.

    Every ``__call__`` resolves the requested variant, loads the transformer,
    scheduler, tokenizer and text encoder, builds a
    :class:`MiniT2ITextToImagePipeline` and delegates generation to it.
    """

    # Accepted spellings -> directory name of the variant inside the repo.
    MODEL_ALIASES = {
        "b": "minit2i-b-16",
        "b16": "minit2i-b-16",
        "b-16": "minit2i-b-16",
        "base": "minit2i-b-16",
        "minit2i-b16": "minit2i-b-16",
        "minit2i-b-16": "minit2i-b-16",
        "minit2i-b/16": "minit2i-b-16",
        "l": "minit2i-l-16",
        "l16": "minit2i-l-16",
        "l-16": "minit2i-l-16",
        "large": "minit2i-l-16",
        "minit2i-l16": "minit2i-l-16",
        "minit2i-l-16": "minit2i-l-16",
        "minit2i-l/16": "minit2i-l-16",
    }

    def __init__(self):
        super().__init__()

    @classmethod
    def _resolve_model_type(cls, model_type: str) -> str:
        """Map a user-supplied alias to its model directory name.

        Matching is case-insensitive and treats ``_`` as ``-``.

        Raises:
            ValueError: If the alias is unknown.
        """
        key = model_type.lower().replace("_", "-")
        if key not in cls.MODEL_ALIASES:
            # Dict keys are already unique, so sorting them directly suffices.
            choices = ", ".join(sorted(cls.MODEL_ALIASES))
            raise ValueError(f"Unknown model_type={model_type!r}. Expected one of: {choices}")
        return cls.MODEL_ALIASES[key]

    @staticmethod
    def _resolve_root(
        repo_id_or_path: Union[str, os.PathLike],
        model_dir: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, os.PathLike]],
        local_files_only: bool,
    ) -> Path:
        """Return a local directory that contains ``model_dir``.

        An existing local path is returned unchanged. Otherwise the argument
        is treated as a Hub repo id and only the ``transformer`` and
        ``scheduler`` folders of ``model_dir`` are downloaded.
        """
        root = Path(repo_id_or_path)
        if root.exists():
            return root
        return Path(
            snapshot_download(
                repo_id=str(repo_id_or_path),
                revision=revision,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                allow_patterns=[
                    f"{model_dir}/transformer/*",
                    f"{model_dir}/scheduler/*",
                ],
            )
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        model_type: str = "b16",
        repo_id_or_path: Union[str, os.PathLike] = "MiniT2I/MiniT2I",
        torch_dtype: Optional[torch.dtype] = torch.bfloat16,
        text_encoder_dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
        local_files_only: bool = False,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        **kwargs,
    ):
        """Load the requested variant and generate images for ``prompt``.

        Args:
            prompt: One prompt or a list of prompts.
            model_type: Variant alias (see ``MODEL_ALIASES``).
            repo_id_or_path: Local directory or Hub repo id holding the variants.
            torch_dtype: Dtype for the transformer weights.
            text_encoder_dtype: Dtype for the text encoder weights.
            device: Target device; defaults to CUDA when available, else CPU.
            local_files_only, revision, cache_dir: Hub download options.
            **kwargs: Forwarded to ``MiniT2ITextToImagePipeline.__call__``.

        Returns:
            Whatever the inner pipeline call returns.
        """
        model_dir = self._resolve_model_type(model_type)
        root = self._resolve_root(repo_id_or_path, model_dir, revision, cache_dir, local_files_only)
        model_root = root / model_dir
        transformer = MiniT2IMMJiTModel.from_pretrained(model_root / "transformer", torch_dtype=torch_dtype)
        # Same fallback as MiniT2ITextToImagePipeline.from_pretrained: a
        # checkpoint without a saved scheduler uses the default one.
        scheduler_dir = model_root / "scheduler"
        if scheduler_dir.exists():
            scheduler = MiniT2IFlowMatchScheduler.from_pretrained(scheduler_dir)
        else:
            scheduler = MiniT2IFlowMatchScheduler()
        text_encoder_name = transformer.mmjit_config.llm
        tokenizer = AutoTokenizer.from_pretrained(text_encoder_name, local_files_only=local_files_only)
        text_encoder = T5EncoderModel.from_pretrained(
            text_encoder_name,
            torch_dtype=text_encoder_dtype,
            local_files_only=local_files_only,
        )
        pipe = MiniT2ITextToImagePipeline(
            transformer=transformer,
            scheduler=scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            text_encoder_name=text_encoder_name,
        )
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)
        return pipe(prompt=prompt, **kwargs)
def build_transformer_from_checkpoint(ckpt_path: Union[str, os.PathLike]) -> MiniT2IMMJiTModel:
    """Build a :class:`MiniT2IMMJiTModel` from a raw training checkpoint.

    The checkpoint must be a mapping with a ``"config"`` entry (keyword
    arguments for :class:`MMJiTConfig`) and a ``"state_dict"`` entry. Every
    state-dict key is prefixed with ``"model."`` to match the wrapper's
    ``model`` attribute, and loading is strict.

    Raises:
        KeyError: If ``"config"`` or ``"state_dict"`` is missing.
        RuntimeError: If the prefixed keys do not match the model exactly.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code depending on the torch version's ``weights_only``
    # default. Only load checkpoints from trusted sources.
    payload = torch.load(ckpt_path, map_location="cpu")
    cfg = MMJiTConfig(**payload["config"])
    transformer = MiniT2IMMJiTModel(**asdict(cfg))
    # The earlier if/else on the "net." prefix applied the same "model."
    # prefix in both branches, so one uniform mapping is equivalent.
    state_dict = {f"model.{key}": value for key, value in payload["state_dict"].items()}
    transformer.load_state_dict(state_dict, strict=True)
    return transformer