Text-to-Image
English
PAINE / predictor /models /noisear.py
joonghk's picture
first commit
03de09d
"""
NoiseAR: Autoregressive noise generation model.
Extracted from the NoiseAR codebase (try.py / eval.py) for use as an
evaluation baseline. Only the model class is included — no pipeline or
dataset dependencies.
Architecture: patch-based AR transformer that generates latent noise
conditioned on text embeddings. Uses causal self-attention + cross-attention
to text, predicting each noise patch autoregressively.
Supported pipelines: SDXL, DreamShaper (shared weights), DiT (separate weights + dim_adapter).
"""
import math
import torch
import torch.nn as nn
class NoiseARNet(nn.Module):
def __init__(
self,
patch_size: int = 32,
d_model: int = 2048,
d_ff: int = 4096,
n_heads: int = 1,
n_layers: int = 1,
dropout: float = 0.15,
pretrained_path: str = "",
pipeline: str = "SDXL",
n_prediction_head: int = 0,
):
super(NoiseARNet, self).__init__()
self.patch_size = patch_size
self.n_layers = n_layers
self.d_model = d_model
self.pipeline = pipeline
if self.pipeline == "DiT":
self.dim_adapter = nn.Linear(1024, 2048)
self.latent_patch_encoder = nn.Sequential(
nn.Linear(4 * patch_size * patch_size, 2 * patch_size * patch_size),
nn.SiLU(),
nn.Linear(2 * patch_size * patch_size, d_model),
)
self.register_buffer("learnable_start_token_emb", torch.randn(d_model) * 0.1)
self.self_attn_layers = nn.ModuleList(
[
nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True)
for _ in range(n_layers)
]
)
self.self_attn_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
self.cross_attn_layers = nn.ModuleList(
[
nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout, batch_first=True)
for _ in range(n_layers)
]
)
self.cross_attn_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
self.ffn_layers = nn.ModuleList(
[
nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout),
)
for _ in range(n_layers)
]
)
self.ffn_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
self.prediction_head_mu = nn.Sequential(
nn.Linear(d_model, d_model * 2), nn.GELU(), nn.Linear(d_model * 2, 4 * patch_size * patch_size)
)
self.prediction_head_logvar = nn.Sequential(
nn.Linear(d_model, d_model * 2), nn.GELU(), nn.Linear(d_model * 2, 4 * patch_size * patch_size)
)
if n_prediction_head != 0:
ffn_layers_mu = []
for _ in range(n_prediction_head):
ffn_layers_mu.append(nn.Linear(d_model, d_model * 2))
ffn_layers_mu.append(nn.GELU())
ffn_layers_mu.append(nn.Linear(d_model * 2, d_model))
self.prediction_head_mu = nn.Sequential(
*ffn_layers_mu,
nn.Linear(d_model, 4 * patch_size * patch_size),
)
ffn_layers_logvar = []
for _ in range(n_prediction_head):
ffn_layers_logvar.append(nn.Linear(d_model, d_model * 2))
ffn_layers_logvar.append(nn.GELU())
ffn_layers_logvar.append(nn.Linear(d_model * 2, d_model))
self.prediction_head_logvar = nn.Sequential(
*ffn_layers_logvar,
nn.Linear(d_model, 4 * patch_size * patch_size),
)
nr_patches = 128 * 128 // self.patch_size // self.patch_size
self.seq_len = nr_patches + 1
position = torch.arange(self.seq_len, dtype=torch.float).unsqueeze(1)
pe = torch.zeros(1, self.seq_len, d_model)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pos_encoding", pe)
self._init_weights()
if pretrained_path and ".pth" in pretrained_path:
state = torch.load(pretrained_path, weights_only=True)
self.load_state_dict(state, strict=True)
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.MultiheadAttention):
nn.init.xavier_uniform_(m.in_proj_weight)
nn.init.xavier_uniform_(m.out_proj.weight)
if m.in_proj_bias is not None:
nn.init.zeros_(m.in_proj_bias)
nn.init.zeros_(m.out_proj.bias)
with torch.no_grad():
if hasattr(self, "prediction_head_mu"):
final_layer = self.prediction_head_mu[-1]
nn.init.uniform_(final_layer.weight, -0.001, 0.001)
nn.init.zeros_(final_layer.bias)
if hasattr(self, "prediction_head_logvar"):
final_layer = self.prediction_head_logvar[-1]
nn.init.uniform_(final_layer.weight, -0.001, 0.001)
nn.init.zeros_(final_layer.bias)
@staticmethod
def patchfy_to_origin(x, patch_size):
shp = x.shape
x = x.reshape([shp[0], 128 // patch_size, 128 // patch_size, -1, patch_size, patch_size])
x = x.permute(0, 3, 1, 4, 2, 5).reshape(shp[0], -1, 128, 128)
return x
@staticmethod
def origin_to_patchfy(x, patch_size):
b, c, h, w = x.shape
x = x.reshape([b, c, h // patch_size, patch_size, w // patch_size, patch_size])
x = x.permute(0, 2, 4, 1, 3, 5).reshape(b, -1, c * patch_size * patch_size)
return x
@staticmethod
def sample_noise(mu, std):
eps = torch.randn_like(std)
return mu + eps * std
def forward_network(self, x, causal_mask, y):
for layer_id in range(self.n_layers):
x = self.self_attn_norm[layer_id](x)
x = self.self_attn_layers[layer_id](x, x, x, attn_mask=causal_mask)[0] + x
x = self.cross_attn_norm[layer_id](x)
x = self.cross_attn_layers[layer_id](x, y, y)[0] + x
x = self.ffn_norm[layer_id](x)
x = self.ffn_layers[layer_id](x) + x
mu = self.prediction_head_mu(x)
logvar = self.prediction_head_logvar(x)
return mu, logvar
def forward(self, text_emb, x=None):
if self.pipeline == "DiT":
text_emb = self.dim_adapter(text_emb)
if x is None:
return self.inference(text_emb)
x = self.origin_to_patchfy(x, self.patch_size)
x = self.latent_patch_encoder(x)
start_token_emb = self.learnable_start_token_emb[None, None, :].expand(x.shape[0], -1, -1)
patch_sequence = torch.cat([start_token_emb, x], dim=1) + self.pos_encoding
causal_mask = torch.triu(torch.ones(self.seq_len, self.seq_len, device=x.device) * float("-inf"), diagonal=1)
mu, logvar = self.forward_network(patch_sequence, causal_mask, text_emb)
mu = self.patchfy_to_origin(mu[:, :-1, ...], self.patch_size)
logvar = self.patchfy_to_origin(logvar[:, :-1, ...], self.patch_size)
return mu, logvar
def inference(self, text_emb=None):
bs = text_emb.shape[0]
start_token_emb = self.learnable_start_token_emb[None, None, :].expand(bs, -1, -1)
sequence = start_token_emb
sampled_noise_list = []
for patch_id in range(128**2 // self.patch_size // self.patch_size):
pos_enc = self.pos_encoding[:, : sequence.size(1), :]
sequence_with_pos = sequence + pos_enc
causal_mask = torch.triu(
torch.ones(sequence.size(1), sequence.size(1), device=sequence.device) * float("-inf"), diagonal=1
)
mu, logvar = self.forward_network(sequence_with_pos, causal_mask, text_emb)
std = torch.exp(0.5 * torch.clamp(logvar[:, -1, ...], -20, 10))
sampled_noise = self.sample_noise(mu[:, -1, ...], std * 0.5)
sampled_noise_list.append(sampled_noise)
patch_emb = self.latent_patch_encoder(sampled_noise.reshape(bs, 1, -1))
sequence = torch.cat([sequence, patch_emb], dim=1)
return self.patchfy_to_origin(torch.stack(sampled_noise_list, dim=1), self.patch_size)