# ConceptAligner / aligner.py
import torch
import torch.nn as nn
import torch.nn.functional as F
# from refiner import Qwen2Connector  # needed only if ConceptAligner222 (below) is instantiated
class MultiHeadSelfAttention(nn.Module):
def __init__(self, embed_dim=2560, num_heads=20):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Linear projections for Q, K, V
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
# Output projection
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.scale = self.head_dim ** -0.5
def forward(self, x, mask=None, return_attention=True):
"""
Args:
x: Input tensor of shape [b, seq_len, embed_dim]
mask: Attention mask of shape [b, seq_len], where 1 means attend, 0 means ignore
return_attention: Whether to return attention weights
Returns:
output: [b, seq_len, embed_dim]
attn_weights: [b*num_heads, seq_len, seq_len] (if return_attention=True)
"""
b, seq_len, embed_dim = x.shape
# Project to Q, K, V
Q = self.q_proj(x) # [b, seq_len, embed_dim]
K = self.k_proj(x) # [b, seq_len, embed_dim]
V = self.v_proj(x) # [b, seq_len, embed_dim]
# Reshape and transpose for multi-head attention
# [b, seq_len, embed_dim] -> [b, seq_len, num_heads, head_dim] -> [b, num_heads, seq_len, head_dim]
Q = Q.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# Reshape for batch computation: [b, num_heads, seq_len, head_dim] -> [b*num_heads, seq_len, head_dim]
Q = Q.reshape(b * self.num_heads, seq_len, self.head_dim)
K = K.reshape(b * self.num_heads, seq_len, self.head_dim)
V = V.reshape(b * self.num_heads, seq_len, self.head_dim)
# Compute attention scores: Q @ K^T
attn_scores = torch.bmm(Q, K.transpose(1, 2)) * self.scale # [b*num_heads, seq_len, seq_len]
# Apply mask if provided
if mask is not None:
# Key mask (column masking): which keys can be attended to
key_mask = mask.unsqueeze(1).unsqueeze(2) # [b, 1, 1, seq_len]
# Query mask (row masking): which queries are valid
query_mask = mask.unsqueeze(1).unsqueeze(3) # [b, 1, seq_len, 1]
# Combine both masks: a position can attend only if BOTH query and key are valid
# Shape: [b, 1, seq_len, seq_len]
final_mask = query_mask.bool() & key_mask.bool() # Broadcasting handles the dimensions
# Expand to all heads and reshape
final_mask = final_mask.expand(b, self.num_heads, seq_len, seq_len)
final_mask = final_mask.reshape(b * self.num_heads, seq_len, seq_len)
attn_scores = attn_scores.masked_fill(~final_mask, float('-inf'))
# Apply softmax
attn_weights = F.softmax(attn_scores, dim=-1) # [b*num_heads, seq_len, seq_len]
# Handle NaN from softmax (when entire row is -inf)
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
# Apply attention to values
attn_output = torch.bmm(attn_weights, V) # [b*num_heads, seq_len, head_dim]
# Reshape back: [b*num_heads, seq_len, head_dim] -> [b, num_heads, seq_len, head_dim]
attn_output = attn_output.view(b, self.num_heads, seq_len, self.head_dim)
# Transpose and reshape: [b, num_heads, seq_len, head_dim] -> [b, seq_len, num_heads, head_dim] -> [b, seq_len, embed_dim]
attn_output = attn_output.transpose(1, 2).contiguous().view(b, seq_len, embed_dim)
# Final output projection
output = self.out_proj(attn_output) # [b, seq_len, embed_dim]
if return_attention:
return output, attn_weights # attn_weights is [b*num_heads, seq_len, seq_len]
else:
return output
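# Minimal usage sketch for MultiHeadSelfAttention (shapes and the padding pattern are
# illustrative assumptions, not tied to any particular checkpoint):
#
#     mhsa = MultiHeadSelfAttention(embed_dim=2560, num_heads=20)
#     x = torch.randn(2, 300, 2560)
#     mask = torch.ones(2, 300)        # 1 = attend, 0 = ignore
#     mask[:, 200:] = 0                # mark the tail as padding
#     out, attn = mhsa(x, mask=mask)   # out: [2, 300, 2560], attn: [2*20, 300, 300]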
class ConceptAligner222(nn.Module):
def __init__(self, custom_pool=1, input_dim=2560, hidden_size=2560):
super().__init__()
if input_dim == 2560:
hidden_size = 2560
self.num_heads = 20
self.model_class = 'gemma3'
depth = 2
identity_mapping = False
        elif input_dim == 4096:
            hidden_size = 3072
            self.num_heads = 24
            self.model_class = 't5'
            depth = 1
            identity_mapping = True
        else:
            raise ValueError(f'Unsupported input_dim: {input_dim}; expected 2560 or 4096.')
self.text_connector = Qwen2Connector(in_channels=input_dim, hidden_size=hidden_size, heads_num=self.num_heads,
depth=depth, identity_init=identity_mapping)
self.final_proj = nn.Sequential(nn.Linear(hidden_size, 4096), nn.SiLU(), nn.Linear(4096, 4096))
self.resampler = MultiHeadSelfAttention(embed_dim=hidden_size, num_heads=self.num_heads)
empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
self.register_buffer('empty_pooled_clip', empty_pooled_clip)
self.learnable_scale_norm = nn.Parameter(torch.ones([1, 1, 1]) * 0.01, requires_grad=True)
self.proj_norm = nn.LayerNorm(hidden_size)
self.custom_pool = custom_pool
if self.custom_pool:
self.clip_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size * 3), nn.SiLU(),
nn.Linear(hidden_size * 3, 768))
self.clip_norm = nn.LayerNorm(768)
print('Using custom pooling for CLIP features.')
@property
def dtype(self):
"""Return the dtype of the model parameters."""
# return next(self.parameters()).dtype
return torch.bfloat16
@property
def device(self):
"""Return the device of the model parameters."""
# return next(self.parameters()).device
return self.empty_pooled_clip.device
def forward(self, text_features, text_mask, is_training=False, img_seq_len=1024):
        text_features = self.text_connector(text_features, mask=text_mask,
                                            mean_start_id=2 if 'gemma' in self.model_class else 0)
text_features = self.proj_norm(text_features)
aligned_features, attn = self.resampler(text_features, mask=text_mask, return_attention=True)
if is_training:
learnable_scale = torch.clip(self.learnable_scale_norm, -1.0, 1.0)
visual_concepts = aligned_features + learnable_scale * torch.randn_like(aligned_features)
else:
visual_concepts = aligned_features
prompt_embeds = self.final_proj(visual_concepts)
# prompt_embeds = text_features
if self.custom_pool:
mean_features = (aligned_features * text_mask.unsqueeze(-1)).sum(dim=1) / (
text_mask.sum(dim=1, keepdim=True) + 1e-8)
pooled_prompt_embeds = self.clip_proj(mean_features)
pooled_prompt_embeds = self.clip_norm(pooled_prompt_embeds)
else:
pooled_prompt_embeds = self.empty_pooled_clip.expand(text_features.shape[0], -1)
dtype = prompt_embeds.dtype
device = prompt_embeds.device
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
total_seq_len = img_seq_len + prompt_embeds.shape[1]
text_seq_len = text_mask.shape[1]
attention_mask = torch.zeros(
len(text_features), 1, 1, total_seq_len,
device=text_mask.device,
dtype=text_mask.dtype
)
# Fill in text portion: where text_mask==0, set to -inf
attention_mask[:, :, :, :text_seq_len] = (1 - text_mask).unsqueeze(1).unsqueeze(2) * -10000.0
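        # Per-query attention entropy over the valid (unmasked) tokens, returned as an auxiliary
        # signal (e.g. for regularizing how spread out the resampler attention is).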
entropy = -(attn * torch.log(attn + 1e-8)).sum(dim=-1)
mask_expanded = text_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
mask_expanded = mask_expanded.reshape(len(text_features) * self.num_heads, text_seq_len)
valid_entropy = entropy[mask_expanded.bool()]
return prompt_embeds, attention_mask, pooled_prompt_embeds, text_ids, valid_entropy
# return prompt_embeds, pooled_prompt_embeds, text_ids, None
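# Sketch of how ConceptAligner222 is meant to be driven (hypothetical shapes; assumes
# `Qwen2Connector` is importable from `refiner` and `empty_pooled_clip.pt` is present):
#
#     aligner = ConceptAligner222(custom_pool=1, input_dim=2560)
#     feats = torch.randn(2, 300, 2560)
#     mask = torch.ones(2, 300)
#     prompt_embeds, attn_mask, pooled, text_ids, entropy = aligner(feats, mask, is_training=True)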
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
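# RMSNorm rescales by the root-mean-square of the feature vector instead of centering it:
#     y = x / sqrt(mean(x^2) + eps) * weight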
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim: int, time_embedding_dim=4096):
super().__init__()
if time_embedding_dim is None:
time_embedding_dim = embedding_dim
self.silu = nn.SiLU()
self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
nn.init.normal_(self.linear.weight, mean=0, std=0.02)
nn.init.zeros_(self.linear.bias)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
    def forward(self, x: torch.Tensor, timestep_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(timestep_embedding))
shift, scale = emb.unsqueeze(1).chunk(2, dim=-1)
x = self.norm(x) * (1 + scale) + shift
return x
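# AdaLayerNorm is a FiLM-style modulation: a (shift, scale) pair is regressed from the timestep
# embedding and applied to the normalized activations,
#     y = LayerNorm(x) * (1 + scale) + shift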
class GateMLP(nn.Module):
def __init__(self, gate_mode='soft', input_dim=64, hidden_dim=1024):
super().__init__()
self.gate_mode = gate_mode
        hidden_dim = 512  # fixed hidden width; the hidden_dim argument is effectively ignored
self.input_norm = nn.LayerNorm(4096)
self.norm0 = nn.LayerNorm(input_dim)
self.linear1 = nn.Linear(input_dim, hidden_dim)
self.activation1 = nn.GELU()
self.linear2 = nn.Linear(hidden_dim+4096, hidden_dim)
self.activation2 = nn.GELU()
self.linear3 = nn.Linear(hidden_dim+4096, hidden_dim)
self.activation3 = nn.GELU()
self.final_linear = nn.Linear(hidden_dim, 1)
nn.init.xavier_uniform_(self.linear1.weight)
nn.init.zeros_(self.linear1.bias)
nn.init.xavier_uniform_(self.linear2.weight)
nn.init.zeros_(self.linear2.bias)
nn.init.xavier_uniform_(self.linear3.weight)
nn.init.zeros_(self.linear3.bias)
nn.init.zeros_(self.final_linear.weight)
bias_val = 0.0 if 'soft' in gate_mode else 1.0
nn.init.constant_(self.final_linear.bias, bias_val)
    def forward(self, x):
        # x: per-head features [b, num_heads, seq_len, head_dim]; inputs are detached, so the
        # gate MLP trains on its own gradients only.
        y = x.transpose(1, 2).flatten(2)                                          # [b, seq_len, num_heads * head_dim]
        y = self.input_norm(y.detach()).unsqueeze(1).repeat(1, x.shape[1], 1, 1)  # global context per head: [b, num_heads, seq_len, 4096]
        x = self.linear1(self.norm0(x.detach()))
        x = self.activation1(x)
        x = self.linear2(torch.cat([x, y], dim=-1))
        x = self.activation2(x)
        x = self.linear3(torch.cat([x, y], dim=-1))
        x = self.activation3(x)
        x = self.final_linear(x)                                                  # one gate logit per (head, token): [b, num_heads, seq_len, 1]
        return x
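# GateMLP expects per-head features of shape [b, num_heads, seq_len, head_dim] and emits one gate
# logit per (head, token). Illustrative sketch (dimensions are assumptions):
#
#     gate = GateMLP(gate_mode='hard', input_dim=128)   # head_dim = 4096 // 32
#     z = torch.randn(2, 32, 300, 128)
#     logits = gate(z)                                  # [2, 32, 300, 1]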
class CrossAttentionWithInfluence(nn.Module):
def __init__(self, d_model=4096, num_heads=32, gate_mode='hard'):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.gate_mode = gate_mode
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
# Linear projections for Q, K, V
# self.q_proj = nn.Linear(d_model, d_model)
# self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
# nn.init.normal_(self.q_proj.weight, mean=0, std=0.02)
# nn.init.normal_(self.k_proj.weight, mean=0, std=0.02)
# nn.init.zeros_(self.q_proj.bias)
# nn.init.zeros_(self.k_proj.bias)
nn.init.eye_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
nn.init.eye_(self.v_proj.weight)
nn.init.zeros_(self.v_proj.bias)
self.mask_mlp = GateMLP(input_dim=d_model // num_heads, hidden_dim=1024, gate_mode=gate_mode)
self.scale = self.head_dim ** -0.5
# self.learnable_scale_norm = nn.Parameter(torch.ones([1, 1,1,1])*0.01, requires_grad=True)
self.rec_mlp = nn.Sequential(nn.Linear(4096, 4096), nn.SiLU(),
nn.Linear(4096, 4096), nn.SiLU(),
nn.Linear(4096, 4096)
)
    def forward(self, x, y, y_mask, temperature=None, threshold=None, topk=None):
        """
        Args:
            x: shared embedding [b, 300, 4096]
            y: changing embedding [b, 300, 4096]
            y_mask: validity mask for y [b, 300]
        Returns:
            output: [b, 300, 4096]
            influence: per-head gate values [b, 300, 32]
            meaningful_gates, soft_meaningful_gate: gate values at valid (unmasked) positions
            rec_diagonal, tgt_diagonal: reconstruction of y at valid positions and its target
        """
b, seq_len_x, d_model = x.shape
b, seq_len_y, d_model_y = y.shape
"""
# Q from x only
Q = self.q_proj(x) # [b, 300, 4096]
seq_len = Q.shape[1]
# K, V from concatenation of [x, y]
K = self.k_proj(x) # [b, 300, 4096]
# Reshape for multi-head attention
Q = Q.view(b, Q.shape[1], self.num_heads, self.head_dim).transpose(1, 2) # [b, 32, 300, 128]
K = K.view(b, K.shape[1], self.num_heads, self.head_dim).transpose(1, 2) # [b, 32, 600, 128]
"""
V = self.v_proj(y) # [b, 300, 4096]
shared_V = self.v_proj(x) # [b, 300, 4096]
textual_concepts = V.view(b, V.shape[1], self.num_heads, self.head_dim).transpose(1, 2) # [b, 32, 300, 128]
shared_concepts = shared_V.view(b, shared_V.shape[1], self.num_heads, self.head_dim).transpose(1,
2) # [b, 32, 300, 128]
expand_y_mask = y_mask.unsqueeze(1).unsqueeze(-1) # [b, 1, 300, 1]
# Compute attention scores
"""
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [b, 32, 300, 300]
attn_weights = F.softmax(attn_scores, dim=-1) # [b, 32, 300, 300]
# Compute output
attn_output = torch.matmul(attn_weights, textual_concepts) # [b, 32, 300, 128]
"""
diagonal_influence = self.mask_mlp((textual_concepts))
if 'soft' in self.gate_mode:
diagonal_influence = 2 * (torch.sigmoid(diagonal_influence * temperature)) # [b, 32, 300, 1]
diagonal_influence = (diagonal_influence > 0.1).to(
diagonal_influence.dtype) * diagonal_influence # Thresholding
soft_influence = diagonal_influence
else:
soft_influence = torch.sigmoid(diagonal_influence * temperature)
if threshold is None:
threshold = 0.5
else:
print('Using custom threshold for influence gating:', threshold)
hard_influence = (soft_influence >= threshold)
diagonal_influence = hard_influence + soft_influence - soft_influence.detach() # Straight-through estimator
if topk is not None:
print(diagonal_influence.shape, ' <<< shape before topk ')
top_k_values, top_k_indices = torch.topk(diagonal_influence, topk, dim=1)
result = torch.zeros_like(diagonal_influence)
result.scatter_(1, top_k_indices, top_k_values)
diagonal_influence = result
print('Applied top-k sparsification on influence gates with k=', topk)
diagonal_output = textual_concepts * diagonal_influence + shared_concepts * (
1 - diagonal_influence) # [b, 32, 300, 128]
da,db,dc,dd = diagonal_output.shape
rec_diagonal = self.rec_mlp(diagonal_output.transpose(1,2).flatten(2)[y_mask.bool()].to(x.dtype))
tgt_diagonal = y[y_mask.bool()]
diagonal_output = expand_y_mask * diagonal_output + (1 - expand_y_mask) * shared_concepts # [b, 32, 300, 128]
mask_bool_expanded = expand_y_mask.expand_as(diagonal_influence).bool() # [b, 32, 300, 1]
meaningful_gates = diagonal_influence[mask_bool_expanded]
soft_meaningful_gate = soft_influence[mask_bool_expanded]
# full_output = self.learnable_scale_norm*attn_output + diagonal_output # [b, 32, 300, 128]
full_output = diagonal_output.to(x.dtype)
# Reshape back
full_output = full_output.transpose(1, 2).contiguous().view(b, y.shape[1], d_model) # [b, 300, 4096]
# Final output projection
output = self.out_proj(full_output) # [b, 300, 4096]
return output, diagonal_influence.squeeze(-1).transpose(1, 2), meaningful_gates, soft_meaningful_gate, rec_diagonal, tgt_diagonal
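# Hypothetical shapes for CrossAttentionWithInfluence: it blends per-head "textual" concepts from
# y into the shared concepts from x wherever the learned gates fire (a sketch, not a recipe):
#
#     layer = CrossAttentionWithInfluence(d_model=4096, num_heads=32, gate_mode='hard')
#     x = torch.randn(2, 300, 4096)      # shared embedding
#     y = torch.randn(2, 300, 4096)      # prompt-specific embedding
#     y_mask = torch.ones(2, 300)
#     out, influence, gates, soft_gates, rec, tgt = layer(x, y, y_mask, temperature=2.5)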
def init_weights_gaussian(model, mean=0.0, std=0.02):
"""
Initialize all nn.Linear layers in the model:
- weights with Gaussian(mean, std)
- biases to 0
"""
for m in model.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=mean, std=std)
if m.bias is not None:
nn.init.constant_(m.bias, 0.0)
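# Example (hypothetical module): init_weights_gaussian(my_mlp) re-initializes every nn.Linear in
# `my_mlp` in place with N(0, 0.02) weights and zero biases.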
class ConceptAligner(nn.Module):
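    """Add a small per-token latent on top of T5 text features.

    With image features, (mu, logvar) are regressed from the visual-minus-text hidden states,
    eps is sampled via the reparameterization trick, and a projection of eps is added to the
    text embedding. Without image features, a pre-registered random eps buffer (or a
    caller-supplied eps) is used and mu/logvar are returned as zeros.
    """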
def __init__(self, per_dim=4):
super().__init__()
empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
self.register_buffer('empty_pooled_clip', empty_pooled_clip)
test_eps = torch.randn([1, 300, per_dim], dtype=torch.bfloat16).to('cpu')*0.7
self.register_buffer('test_eps', test_eps)
self.init_proj = nn.Sequential(nn.Linear(768, 300*16), nn.SiLU())
self.proj = nn.Sequential(nn.Linear(16, 1024), nn.SiLU(),
nn.Linear(1024, 1024), nn.SiLU())
self.text_proj = nn.Sequential(nn.Linear(4096, 1024), nn.SiLU(),
nn.Linear(1024, 1024), nn.SiLU())
self.proj_mu = nn.Sequential(nn.Linear(1024, per_dim))
self.proj_logvar = nn.Sequential(nn.Linear(1024, per_dim))
self.eps_proj = nn.Sequential(nn.Linear(per_dim, 1024), nn.SiLU(),
nn.LayerNorm(1024),
nn.Linear(1024, 4096))
init_weights_gaussian(self, mean=0.0, std=0.02)
torch.nn.init.constant_(self.eps_proj[-1].weight, 0.0)
torch.nn.init.constant_(self.eps_proj[-1].bias, 0.0)
@property
def dtype(self):
"""Return the dtype of the model parameters."""
# return next(self.parameters()).dtype
return torch.bfloat16
@property
def device(self):
"""Return the device of the model parameters."""
# return next(self.parameters()).device
return self.empty_pooled_clip.device
def forward(self, text_features, image_features=None, eps=None):
#return text_features, None, self.empty_pooled_clip.expand(text_features.shape[0], -1), torch.zeros(text_features.shape[1], 3).to(device=text_features.device, dtype=text_features.dtype), {'mu': torch.zeros([1,300,1], device=text_features.device, dtype=text_features.dtype), 'logvar': torch.zeros([1,300,1], device=text_features.device, dtype=text_features.dtype)}
dtype = text_features.dtype
device = text_features.device
if image_features is not None:
visual_hidden = self.proj(self.init_proj(image_features).view(len(image_features), 300, -1))
text_hidden = self.text_proj(text_features.detach())
hidden = visual_hidden - text_hidden
mu = self.proj_mu(hidden)
logvar = self.proj_logvar(hidden)
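            # Reparameterization trick: eps = mu + sigma * N(0, 1), keeping gradients w.r.t. mu/logvar.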
eps = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
else:
if eps is None:
eps = self.test_eps.to(device=device, dtype=dtype)
mu = torch.zeros_like(eps)
logvar = torch.zeros_like(eps)
proj_eps = self.eps_proj(eps)
prompt_embeds = text_features + proj_eps
pooled_prompt_embeds = self.empty_pooled_clip.expand(text_features.shape[0], -1)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
aux_info = {
'mu': mu,
'logvar': logvar,
'eps': eps
}
return prompt_embeds, None, pooled_prompt_embeds, text_ids, aux_info
if __name__ == '__main__':
from transformers import AutoProcessor
from diffusers import FluxPipeline
import os
from PIL import Image
def create_image_grid(images, cols):
rows = (len(images) + cols - 1) // cols
w, h = images[0].size
grid = Image.new('RGB', (cols * w, rows * h))
for i, img in enumerate(images):
grid.paste(img, (i % cols * w, i // cols * h))
return grid
dim = 4096
num_heads = 32
dtype = torch.bfloat16
model = ConceptAligner().to('cuda').to(dtype)
x = torch.randn([5, 300, dim]).to('cuda').to(dtype)
y = torch.randn([5, 300, dim]).to('cuda').to(dtype)
i = torch.randn([5,768]).to('cuda').to(dtype)
y[1] = y[0]
m = torch.ones([5, 300]).to('cuda').to(dtype)
m[:3,:128] = 0
prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(x, i)
    print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)
for k in aux_info:
print(k, aux_info[k].shape, aux_info[k].min(), aux_info[k].max(), aux_info[k].mean())
from text_encoder import LoraT5Embedder
from datasets import load_dataset
dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]')
item = dataset[0:4]
another_item = dataset[0:4]
from diffusers.models.normalization import RMSNorm
clip_processor = AutoProcessor.from_pretrained("./clip-vit-large-patch14")
clip_images = clip_processor(images=item['image'], return_tensors="pt").pixel_values.to('cuda:0').to(dtype)
texts = []
texts.append("""A heartwarming 3D rendered scene of
an elderly farmer and a tiny orange
kitten. The farmer, with a gentle smile,
walks alongside the kitten in a lush,
green garden filled with thriving plants,
showcasing a fruitful harvest. The
intricate details of the overalls and the
farmer's worn, weathered face tell a
story of years spent tending to the land. the farmer is wearing a blue shirt""")
texts.append("""A unique, intricately detailed creature
resembling a reptile, possibly a lizard or
a gecko. It has a vibrant blue and green
scaled body, with large, round, and
expressive eyes that are a deep shade of
blue. The backdrop is a
soft, blurred forest setting, suggesting a
serene and mystical ambiance. the creature is wearing a golden crown""")
texts.append("""Deep in the enchanted forest lives a woman
who is the moon fairy. Her long blonde hair
shines in the starlight, tangled with her flowers
that glow with a soft blue glow. Her eyes are
the color of the night and shine with the magic
of the night. The fairy wears a dress made of
moon petals, woven with threads of moonlight
that shine with an iridescent glow, a crown of
stars adorns her head, shining with the light of
the full moon that illuminates the forest. Her
wings are translucent like glass, with a pale
glow reminiscent of the glow of the moon. HD,
6K, photo, cinematic, poster""")
texts.append(
"""In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""")
text_encoder = LoraT5Embedder(device='cuda').to(dtype)
text_features, _, _, _, image_features, _ = text_encoder(texts, clip_images)
print(text_features.shape, image_features.shape, ' >>>>>>>>> text input')
images = []
pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
pipe.to('cuda')
for txt_feat, img_feat in zip(text_features, image_features):
prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(txt_feat.unsqueeze(0), img_feat.unsqueeze(0))
image = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
height=512,
width=512,
guidance_scale=3.5,
num_inference_steps=20,
max_sequence_length=512,
generator=torch.Generator("cuda").manual_seed(1995),
).images[0]
images.append(image)
aligned_image = create_image_grid(images, cols=len(images) // 2)
os.makedirs('samples', exist_ok=True)
    aligned_image.save("samples/image.jpg")
raise SystemExit
influence_matrix = aux_info['influence']
bin_influence_matrix = (influence_matrix > 0.1).float()
mean_alive = bin_influence_matrix.sum(dim=-1).mean()
max_alive = bin_influence_matrix.sum(dim=-1).max()
min_alive = bin_influence_matrix.sum(dim=-1).min()
max_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).max()
mean_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).mean()
min_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).min()
print(
f"Mean alive heads per token: {mean_alive:.2f}, Max alive heads per token: {max_alive:.2f}, Min alive heads per token: {min_alive:.2f}")
print(
f"Mean alive tokens: {mean_token_alive:.2f}, Max alive tokens: {max_token_alive:.2f}, Min alive tokens: {min_token_alive:.2f}")
CHECKPOINT_PATH = 'runs/00393/checkpoint-6000'
from safetensors.torch import load_file
# Load adapter (model.safetensors)
adapter_path = os.path.join(CHECKPOINT_PATH, "model_1.safetensors")
if os.path.exists(adapter_path):
adapter_state = load_file(adapter_path)
model.load_state_dict(adapter_state, strict=True)
print("Adapter loaded successfully!")
print(model.influence_net.v_proj.weight, ' <<< weight ')
print(model.influence_net.v_proj.bias, ' <<< bias ')
print(model.influence_net.out_proj.weight, ' <<< out weight ')
print(model.influence_net.out_proj.bias, ' <<< out bias ')
print(model.influence_net.mask_mlp.linear3.weight, ' <<< gate weight 3 ')
print(model.influence_net.mask_mlp.linear3.bias, ' <<< gate bias ')
z = torch.randn([3, num_heads, 300, 4096 // num_heads]).to('cuda').to(dtype)
gate_values = model.influence_net.mask_mlp(z)
gate_values = 2 * (torch.sigmoid(gate_values))
print(gate_values, ' <<< gate values ', gate_values.shape, ' ', torch.mean(gate_values))
reserved_memory = torch.cuda.memory_reserved(0) / (1024 ** 3)
print(f"Reserved GPU memory: {reserved_memory:.2f} GB")
from transformers import T5EncoderModel, T5Tokenizer, CLIPTokenizer, CLIPTextModel
text_encoder = LoraT5Embedder(device='cuda').to(torch.bfloat16)
texts = []
texts.append("""A heartwarming 3D rendered scene of
an elderly farmer and a tiny orange
kitten. The farmer, with a gentle smile,
walks alongside the kitten in a lush,
green garden filled with thriving plants,
showcasing a fruitful harvest. The
intricate details of the overalls and the
farmer's worn, weathered face tell a
story of years spent tending to the land. the farmer is wearing a blue shirt""")
texts.append("""A unique, intricately detailed creature
resembling a reptile, possibly a lizard or
a gecko. It has a vibrant blue and green
scaled body, with large, round, and
expressive eyes that are a deep shade of
blue. The backdrop is a
soft, blurred forest setting, suggesting a
serene and mystical ambiance. the creature is wearing a golden crown""")
texts.append("""Deep in the enchanted forest lives a woman
who is the moon fairy. Her long blonde hair
shines in the starlight, tangled with her flowers
that glow with a soft blue glow. Her eyes are
the color of the night and shine with the magic
of the night. The fairy wears a dress made of
moon petals, woven with threads of moonlight
that shine with an iridescent glow, a crown of
stars adorns her head, shining with the light of
the full moon that illuminates the forest. Her
wings are translucent like glass, with a pale
glow reminiscent of the glow of the moon. HD,
6K, photo, cinematic, poster""")
texts.append(
"""In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""")
texts.append(
"""In the image, a majestic white horse gallops across a misty meadow at sunrise. Its mane and tail flow freely in the golden light, and the air glows softly with early morning haze. The horse’s body is bare, revealing the natural curve of its muscles and the sheen of its coat. Dew sparkles on the grass beneath its hooves, and the distant trees fade into pale gold mist. The scene conveys freedom, grace, and quiet power.""")
INDEX = 0
text = texts[INDEX]
with torch.no_grad():
floral_embeds, _,_,_,_,attn_mask = text_encoder(text, )
print(attn_mask.shape, ' >>>> ', attn_mask)
        print(floral_embeds.shape, ' >>>> floral ')
nopad_floral_embeds, nopad_shared_embeds, nopad_attn_mask = text_encoder(text, padding=False)
        print(nopad_floral_embeds.shape, ' >>>> floral (no padding) ')
"""
_,_,_,_,aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False)
print(aux_info['meaningful_influence'].shape, ' <<< influence shape ', aux_info['meaningful_influence'][:100],' ',torch.mean(aux_info['meaningful_influence']))
floral_embeds, shared_embeds, attn_mask = text_encoder([""], padding='max_length')
_,_,_,_,aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False)
print(aux_info['meaningful_influence'].shape, ' <<< empty influence shape ', aux_info['meaningful_influence'],' ',torch.mean(aux_info['meaningful_influence']))
raise SystemExit
"""
text2s = []
text2s.append("""A heartwarming 3D rendered scene of
an elderly farmer and a tiny orange
kitten. The farmer, with a gentle smile,
walks alongside the kitten in a lush,
green garden filled with thriving plants,
showcasing a fruitful harvest. The
intricate details of the overalls and the
farmer's worn, weathered face tell a
story of years spent tending to the land. the farmer is wearing a red shirt""")
text2s.append("""A unique, intricately detailed creature
resembling a reptile, possibly a lizard or
a gecko. It has a vibrant blue and green
scaled body, with large, round, and
expressive eyes that are a deep shade of
blue. The backdrop is a
soft, blurred forest setting, suggesting a
serene and mystical ambiance. the creature is wearing a floral crown""")
text2s.append("""Deep in the enchanted forest lives a woman
who is the moon fairy. Her long black hair
shines in the starlight, tangled with her flowers
that glow with a soft blue glow. Her eyes are
the color of the night and shine with the magic
of the night. The fairy wears a dress made of
moon petals, woven with threads of moonlight
that shine with an iridescent glow, a crown of
stars adorns her head, shining with the light of
the full moon that illuminates the forest. Her
wings are translucent like glass, with a pale
glow reminiscent of the glow of the moon. HD,
6K, photo, cinematic, poster""")
text2s.append(
"""In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a gray, rainy sky outside, with raindrops streaking down the glass and blurred trees beyond. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and introspective, capturing a cozy moment of quiet observation during a rainy afternoon.""")
text2s.append(
"""In the image, a majestic white horse gallops across a misty meadow at sunrise. Its mane and tail flow freely in the golden light, and the air glows softly with early morning haze. The horse’s body is adorned with a bright red saddle, contrasting sharply against its white coat. Dew sparkles on the grass beneath its hooves, and the distant trees fade into pale gold mist. The scene conveys freedom, grace, and a striking touch of color that adds visual drama.""")
text2 = text2s[INDEX]
with torch.no_grad():
golden_embeds, shared_embeds, golden_mask = text_encoder(text2, padding='max_length')
print(golden_embeds.shape, shared_embeds.shape, ' >>>> golden ')
nopad_golden_embeds, nopad_shared_embeds, nopad_golden_mask = text_encoder(text2, padding=False)
print(golden_embeds.shape, shared_embeds.shape, ' >>>> golden ')
batch_encoding = text_encoder.t5_tokenizer(
text,
truncation=True,
max_length=text_encoder.max_length,
return_tensors="pt",
)
input_ids = batch_encoding["input_ids"][0] # Get the token IDs
# Convert token IDs back to tokens to see what they are
tokens_floral = text_encoder.t5_tokenizer.convert_ids_to_tokens(input_ids)
batch_encoding = text_encoder.t5_tokenizer(
text2,
truncation=True,
max_length=text_encoder.max_length,
return_tensors="pt",
)
input_ids = batch_encoding["input_ids"][0] # Get the token IDs
tokens_golden = text_encoder.t5_tokenizer.convert_ids_to_tokens(input_ids)
# Convert token IDs back to tokens to see what they are
# Find the index of specific words
def find_token_indices(tokens, word):
"""Find all indices where a word or its token appears"""
indices = []
# T5 tokenizer might split words or add special characters
word_token = text_encoder.t5_tokenizer.encode(word, add_special_tokens=False)[0]
word_token_str = text_encoder.t5_tokenizer.convert_ids_to_tokens([word_token])[0]
for i, token in enumerate(tokens):
if token == word_token_str or word.lower() in token.lower():
indices.append(i)
return indices
key1s = ['blue', 'golden', 'blonde', 'clear', 'horse']
key2s = ['red', 'floral', 'black', 'rainy', 'red']
    # Find the last token index of the first keyword in the original prompt
    blue_indices = find_token_indices(tokens_floral, key1s[INDEX])[-1]
    print(f"Last index for '{key1s[INDEX]}': {blue_indices}")
    # Find the last token index of the second keyword in the edited prompt
    red_indices = find_token_indices(tokens_golden, key2s[INDEX])[-1]
    print(f"Last index for '{key2s[INDEX]}': {red_indices}")
pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
pipe.to('cuda')
adapter_path = os.path.join(CHECKPOINT_PATH, "model.safetensors")
if os.path.exists(adapter_path):
adapter_state = load_file(adapter_path)
pipe.transformer.load_state_dict(adapter_state, strict=True)
print("Transformer loaded successfully!")
images = []
empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu').to('cuda').to(torch.bfloat16)
print("Generating image with concatenation...")
images = []
# for cur_prompt_embed in [floral_embeds, nopad_floral_embeds
# , inter_embed, golden_embeds, nopad_golden_embeds]:
# for (start_dim, end_dim) in [(0,4096), (1024,4096), (2048, 4096), (1024, 2048)]:
for emb in ['floral', 'golden']:
for temp in [2.5]:
for thr in [-1, 0.5, 0.75, 0.85, 0.95]:
for topk in [None]:
print('>>>> Temperature: ', temp, topk)
if 'floral' in emb:
inter_embed, _, _, _, new_aux_info = model(floral_embeds, shared_embeds, attn_mask,
is_training=False, temperature=temp,
threshold=thr, topk=topk)
else:
inter_embed, _, _, _, new_aux_info = model(golden_embeds, shared_embeds, golden_mask,
is_training=False, temperature=temp,
threshold=thr, topk=topk)
print(new_aux_info['influence'][:, blue_indices].shape, ' >>>> influence ',
new_aux_info['influence'][:, blue_indices])
print(new_aux_info['meaningful_influence'], ' >>>> meaningful influence ',
torch.mean(new_aux_info['meaningful_influence']))
# inter_embed = torch.clone(floral_embeds)
# inter_embed[:, blue_indices] = shared_embeds[:, blue_indices]
# inter_embed[:, blue_indices, start_dim:end_dim] = floral_embeds[:, blue_indices, start_dim:end_dim]
image = pipe(
prompt_embeds=inter_embed,
pooled_prompt_embeds=empty_pooled_clip,
height=512,
width=512,
guidance_scale=3.5,
num_inference_steps=20,
max_sequence_length=512,
generator=torch.Generator("cuda").manual_seed(1995),
).images[0]
images.append(image)
aligned_image = create_image_grid(images, cols=len(images) // 2)
os.makedirs('samples', exist_ok=True)
aligned_image.save("samples/image%s.jpg" % INDEX)