Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from torch import nn | |
| from refiner import Qwen2Connector | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention over a padded token sequence.

    Queries, keys and values are independent linear projections of the same
    input. An optional 0/1 padding mask removes padded positions both as keys
    (columns) and as queries (rows). Fully masked query rows end up with
    all-zero attention weights instead of NaNs.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim=2560, num_heads=20):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Keep the q, k, v, out construction order: it fixes both the seeded
        # random initialisation and the state-dict layout.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.scale = self.head_dim ** -0.5

    def _split_heads(self, t, batch, length):
        """Reshape [batch, length, embed_dim] -> [batch * num_heads, length, head_dim]."""
        per_head = t.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)
        return per_head.reshape(batch * self.num_heads, length, self.head_dim)

    def forward(self, x, mask=None, return_attention=True):
        """Attend every token to every other (unmasked) token.

        Args:
            x: Input tensor [b, seq_len, embed_dim].
            mask: Optional [b, seq_len] mask; nonzero = attend, zero = ignore.
            return_attention: Also return the attention weights when True.

        Returns:
            ``output`` [b, seq_len, embed_dim], plus the attention weights
            [b * num_heads, seq_len, seq_len] when ``return_attention`` is True.
        """
        batch, length, width = x.shape
        q = self._split_heads(self.q_proj(x), batch, length)
        k = self._split_heads(self.k_proj(x), batch, length)
        v = self._split_heads(self.v_proj(x), batch, length)

        scores = torch.bmm(q, k.transpose(1, 2)) * self.scale

        if mask is not None:
            valid = mask.bool()
            # A (query, key) pair survives only when both positions are real tokens.
            pair_valid = valid[:, None, :, None] & valid[:, None, None, :]  # [b, 1, n, n]
            pair_valid = pair_valid.expand(batch, self.num_heads, length, length)
            pair_valid = pair_valid.reshape(batch * self.num_heads, length, length)
            scores = scores.masked_fill(~pair_valid, float('-inf'))

        # Softmax over an all -inf row (a padded query) yields NaN; map it to 0.
        weights = torch.nan_to_num(F.softmax(scores, dim=-1), nan=0.0)

        context = torch.bmm(weights, v)
        context = context.view(batch, self.num_heads, length, self.head_dim)
        context = context.transpose(1, 2).contiguous().view(batch, length, width)
        output = self.out_proj(context)

        if not return_attention:
            return output
        return output, weights
class ConceptAligner222(nn.Module):
    """Alternative concept aligner: connector -> self-attention resampler -> projection.

    The backbone configuration is selected from ``input_dim``:

    * 2560 (``model_class='gemma3'``): hidden 2560, 20 heads, connector depth 2.
    * 4096 (``model_class='t5'``): hidden 3072, 24 heads, connector depth 1,
      identity-initialised connector.

    The ``hidden_size`` argument is always overridden by this table; it is kept
    only for signature compatibility.

    Args:
        custom_pool: If truthy, the pooled embedding is predicted from the masked
            mean of the resampled features; otherwise the stored
            ``empty_pooled_clip`` buffer is broadcast over the batch.
        input_dim: Width of the incoming text features (2560 or 4096).
        hidden_size: Ignored (see above).

    Raises:
        ValueError: If ``input_dim`` is not one of the supported widths.
    """

    def __init__(self, custom_pool=1, input_dim=2560, hidden_size=2560):
        super().__init__()
        if input_dim == 2560:
            hidden_size = 2560
            self.num_heads = 20
            self.model_class = 'gemma3'
            depth = 2
            identity_mapping = False
        elif input_dim == 4096:
            hidden_size = 3072
            self.num_heads = 24
            self.model_class = 't5'
            depth = 1
            identity_mapping = True
        else:
            # Without this guard `depth` / `identity_mapping` would be unbound below.
            raise ValueError(
                f"Unsupported input_dim={input_dim}; expected 2560 (gemma3) or 4096 (t5)"
            )
        self.text_connector = Qwen2Connector(in_channels=input_dim, hidden_size=hidden_size, heads_num=self.num_heads,
                                             depth=depth, identity_init=identity_mapping)
        self.final_proj = nn.Sequential(nn.Linear(hidden_size, 4096), nn.SiLU(), nn.Linear(4096, 4096))
        self.resampler = MultiHeadSelfAttention(embed_dim=hidden_size, num_heads=self.num_heads)
        # Loaded relative to the current working directory at construction time.
        empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
        self.register_buffer('empty_pooled_clip', empty_pooled_clip)
        # Scale of the Gaussian noise added to the resampled features in training.
        self.learnable_scale_norm = nn.Parameter(torch.ones([1, 1, 1]) * 0.01, requires_grad=True)
        self.proj_norm = nn.LayerNorm(hidden_size)
        self.custom_pool = custom_pool
        if self.custom_pool:
            self.clip_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size * 3), nn.SiLU(),
                                           nn.Linear(hidden_size * 3, 768))
            self.clip_norm = nn.LayerNorm(768)
            print('Using custom pooling for CLIP features.')

    @staticmethod
    def _mean_start_id(model_class):
        """Return the ``mean_start_id`` passed to the text connector.

        Gemma-family encoders get 2, everything else 0 (presumably the first
        token index the connector averages from -- confirm in
        ``refiner.Qwen2Connector``). The class name set in ``__init__`` is
        ``'gemma3'``, so a prefix match is required; an exact comparison with
        ``'gemma'`` never matches.
        """
        return 2 if model_class.startswith('gemma') else 0

    def dtype(self):
        """Return the dtype of the model parameters (hard-coded to bfloat16)."""
        # return next(self.parameters()).dtype
        return torch.bfloat16

    def device(self):
        """Return the device of the model, taken from the ``empty_pooled_clip`` buffer."""
        # return next(self.parameters()).device
        return self.empty_pooled_clip.device

    def forward(self, text_features, text_mask, is_training=False, img_seq_len=1024):
        """Map encoder features to prompt embeddings plus auxiliary outputs.

        Args:
            text_features: Encoder hidden states, presumably [b, seq_len, input_dim].
            text_mask: Padding mask [b, seq_len], 1 for real tokens and 0 for
                padding. It is used arithmetically (``1 - text_mask``), so a
                numeric (not bool) mask is required.
            is_training: When True, add Gaussian noise scaled by the clipped
                ``learnable_scale_norm`` to the resampled features.
            img_seq_len: Number of image positions appended after the text
                positions in the returned additive attention mask.

        Returns:
            ``(prompt_embeds, attention_mask, pooled_prompt_embeds, text_ids, valid_entropy)``:

            * ``prompt_embeds``: [b, seq_len, 4096].
            * ``attention_mask``: [b, 1, 1, seq_len + img_seq_len], dtype of
              ``text_mask``; -10000 at padded text positions, 0 elsewhere.
            * ``pooled_prompt_embeds``: [b, 768] from ``clip_proj`` when
              ``custom_pool`` is set, else the expanded ``empty_pooled_clip``.
            * ``text_ids``: zeros [seq_len, 3].
            * ``valid_entropy``: 1-D per-head attention entropies of all valid
              query tokens.
        """
        text_features = self.text_connector(text_features, mask=text_mask,
                                            mean_start_id=self._mean_start_id(self.model_class))
        text_features = self.proj_norm(text_features)
        aligned_features, attn = self.resampler(text_features, mask=text_mask, return_attention=True)

        if is_training:
            learnable_scale = torch.clip(self.learnable_scale_norm, -1.0, 1.0)
            visual_concepts = aligned_features + learnable_scale * torch.randn_like(aligned_features)
        else:
            visual_concepts = aligned_features
        prompt_embeds = self.final_proj(visual_concepts)

        if self.custom_pool:
            # Masked mean over valid tokens; the epsilon guards all-padding rows.
            mean_features = (aligned_features * text_mask.unsqueeze(-1)).sum(dim=1) / (
                text_mask.sum(dim=1, keepdim=True) + 1e-8)
            pooled_prompt_embeds = self.clip_proj(mean_features)
            pooled_prompt_embeds = self.clip_norm(pooled_prompt_embeds)
        else:
            pooled_prompt_embeds = self.empty_pooled_clip.expand(text_features.shape[0], -1)

        dtype = prompt_embeds.dtype
        device = prompt_embeds.device
        text_ids = torch.zeros(prompt_embeds.shape[1], 3, device=device, dtype=dtype)

        total_seq_len = img_seq_len + prompt_embeds.shape[1]
        text_seq_len = text_mask.shape[1]
        attention_mask = torch.zeros(
            len(text_features), 1, 1, total_seq_len,
            device=text_mask.device,
            dtype=text_mask.dtype,
        )
        # Additive mask: padded text positions get a large negative bias (-10000,
        # not -inf); real text tokens and the trailing image positions stay 0.
        attention_mask[:, :, :, :text_seq_len] = (1 - text_mask).unsqueeze(1).unsqueeze(2) * -10000.0

        # Entropy of each head's attention row; attn is [b * num_heads, seq_len, seq_len].
        entropy = -(attn * torch.log(attn + 1e-8)).sum(dim=-1)
        mask_expanded = text_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
        mask_expanded = mask_expanded.reshape(len(text_features) * self.num_heads, text_seq_len)
        valid_entropy = entropy[mask_expanded.bool()]
        return prompt_embeds, attention_mask, pooled_prompt_embeds, text_ids, valid_entropy
| import torch | |
| import torch.nn as nn | |
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    The statistic is computed in float32 and cast back to the input dtype
    before the gain is applied (so the result follows normal type promotion
    with ``weight``).

    Args:
        dim: Size of the normalised (last) axis.
        eps: Added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last axis."""
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
class AdaLayerNorm(nn.Module):
    """LayerNorm whose shift and scale are predicted from a conditioning vector.

    ``out = LayerNorm(x) * (1 + scale) + shift``, where ``(shift, scale)`` are
    the two halves of ``Linear(SiLU(timestep_embedding))``. The linear layer
    starts with small random weights (std 0.02) and zero bias, so the module
    begins close to a plain, affine-free LayerNorm.

    Args:
        embedding_dim: Width of the last axis of ``x``.
        time_embedding_dim: Width of the conditioning vector; ``None`` means
            "same as ``embedding_dim``".
    """

    def __init__(self, embedding_dim: int, time_embedding_dim=4096):
        super().__init__()
        if time_embedding_dim is None:
            time_embedding_dim = embedding_dim
        self.silu = nn.SiLU()
        self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
        nn.init.normal_(self.linear.weight, mean=0, std=0.02)
        nn.init.zeros_(self.linear.bias)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(
        self, x: torch.Tensor, timestep_embedding: torch.Tensor
    ) -> torch.Tensor:
        """Normalise ``x`` and modulate it with the conditioning vector.

        Args:
            x: Tensor whose last axis is ``embedding_dim``; the modulation is
                broadcast via ``unsqueeze(1)``, so presumably [b, seq, embedding_dim].
            timestep_embedding: Conditioning tensor, presumably [b, time_embedding_dim].

        Returns:
            A single tensor with the same shape as ``x``.
        """
        emb = self.linear(self.silu(timestep_embedding))
        shift, scale = emb.unsqueeze(1).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift
class GateMLP(nn.Module):
    """Predict one gate logit per (head, token).

    For an input of shape [b, heads, seq, input_dim], each gate is computed
    from that head's slice of the token plus the full token vector (all heads
    concatenated, width ``context_dim``). The full token vector is re-injected
    before the second and third hidden layers. Both input paths are detached,
    so no gradient flows back into the input through this module.

    ``final_linear`` has zero-initialised weights, so the initial output is
    the constant bias: 0.0 when ``gate_mode`` contains ``'soft'``, else 1.0.

    Args:
        gate_mode: Gate flavour; only checked for the substring ``'soft'``.
        input_dim: Per-head feature width.
        hidden_dim: Accepted for signature compatibility only; the hidden width
            is fixed at 512 so existing checkpoints keep loading.
        context_dim: Width of the concatenated all-head token vector; must equal
            ``heads * input_dim`` of the inputs. Defaults to 4096.
    """

    _HIDDEN_DIM = 512

    def __init__(self, gate_mode='soft', input_dim=64, hidden_dim=1024, context_dim=4096):
        super().__init__()
        self.gate_mode = gate_mode
        width = self._HIDDEN_DIM
        self.input_norm = nn.LayerNorm(context_dim)
        self.norm0 = nn.LayerNorm(input_dim)
        self.linear1 = nn.Linear(input_dim, width)
        self.activation1 = nn.GELU()
        self.linear2 = nn.Linear(width + context_dim, width)
        self.activation2 = nn.GELU()
        self.linear3 = nn.Linear(width + context_dim, width)
        self.activation3 = nn.GELU()
        self.final_linear = nn.Linear(width, 1)
        # Same initialisation sequence as before (keeps seeded init identical).
        for layer in (self.linear1, self.linear2, self.linear3):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        nn.init.zeros_(self.final_linear.weight)
        bias_val = 0.0 if 'soft' in gate_mode else 1.0
        nn.init.constant_(self.final_linear.bias, bias_val)

    def forward(self, x):
        """Return raw gate logits.

        Args:
            x: Tensor [b, heads, seq, input_dim] with ``heads * input_dim == context_dim``.

        Returns:
            Tensor [b, heads, seq, 1].
        """
        x = x.detach()
        heads = x.shape[1]
        # Full token vector: [b, heads, seq, d] -> [b, seq, heads * d], normalised,
        # then broadcast over the head axis. `expand` avoids the copy `repeat`
        # made; `torch.cat` below materialises what it needs anyway.
        context = self.input_norm(x.transpose(1, 2).flatten(2))
        context = context.unsqueeze(1).expand(-1, heads, -1, -1)

        h = self.activation1(self.linear1(self.norm0(x)))
        h = self.activation2(self.linear2(torch.cat([h, context], dim=-1)))
        h = self.activation3(self.linear3(torch.cat([h, context], dim=-1)))
        return self.final_linear(h)
class CrossAttentionWithInfluence(nn.Module):
    """Per-head gated blend of a "changing" embedding ``y`` into a "shared" embedding ``x``.

    Both inputs go through the same value projection and are split into
    ``num_heads`` heads. A ``GateMLP`` predicts one gate per (head, token) from
    the ``y`` heads, and each head becomes ``g * y_head + (1 - g) * x_head``.
    Padded ``y`` tokens (``y_mask == 0``) fall back to the shared ``x`` head.
    The blend is reassembled and passed through an identity-initialised output
    projection. Despite the name, no query/key attention is computed.

    Args:
        d_model: Model width; must be divisible by ``num_heads``. ``rec_mlp`` is
            built for width 4096, so other widths are not supported as written.
        num_heads: Number of heads the width is split into.
        gate_mode: Contains ``'soft'`` -> gates ``2 * sigmoid`` in (0, 2) with
            values <= 0.1 zeroed; otherwise hard 0/1 gates trained through a
            straight-through estimator.
    """

    def __init__(self, d_model=4096, num_heads=32, gate_mode='hard'):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.gate_mode = gate_mode
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Identity initialisation: initially both projections pass features through unchanged.
        nn.init.eye_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)
        nn.init.eye_(self.v_proj.weight)
        nn.init.zeros_(self.v_proj.bias)
        self.mask_mlp = GateMLP(input_dim=d_model // num_heads, hidden_dim=1024, gate_mode=gate_mode)
        self.scale = self.head_dim ** -0.5  # not used by forward; kept as a public attribute
        # Applied to the blended tokens at unpadded positions; its output is
        # returned together with the matching `y` tokens (presumably for a
        # reconstruction loss computed by the caller).
        self.rec_mlp = nn.Sequential(nn.Linear(4096, 4096), nn.SiLU(),
                                     nn.Linear(4096, 4096), nn.SiLU(),
                                     nn.Linear(4096, 4096)
                                     )

    def forward(self, x, y, y_mask, temperature=None, threshold=None, topk=None):
        """Blend ``y`` into ``x`` head by head.

        Args:
            x: Shared embedding [b, seq, d_model].
            y: Changing embedding [b, seq, d_model].
            y_mask: Padding mask for ``y`` [b, seq], 1 = real token. It is used
                arithmetically (``1 - mask``), so pass a numeric mask, not bool.
            temperature: Multiplier on the gate logits before the sigmoid;
                ``None`` means 1.0.
            threshold: Hard-mode cut-off on the sigmoid gate (default 0.5).
            topk: If given, keep only the ``topk`` largest gates per token
                across heads and zero the rest.

        Returns:
            Tuple of:

            * ``output`` [b, seq, d_model].
            * gates [b, seq, num_heads] actually applied to each head.
            * ``meaningful_gates``: 1-D gates at unpadded positions.
            * ``soft_meaningful_gate``: 1-D soft gate values at unpadded
              positions (in hard mode: the sigmoid before thresholding).
            * ``rec_diagonal`` [n_valid, 4096]: ``rec_mlp`` of the blended
              tokens at unpadded positions.
            * ``tgt_diagonal`` [n_valid, d_model]: the matching ``y`` tokens.
        """
        _, _, d_model = x.shape
        b, _, _ = y.shape
        if temperature is None:
            # `logits * None` would raise a TypeError.
            temperature = 1.0

        def split_heads(t):
            """[b, seq, d_model] -> [b, num_heads, seq, head_dim]."""
            return t.view(b, t.shape[1], self.num_heads, self.head_dim).transpose(1, 2)

        textual_concepts = split_heads(self.v_proj(y))
        shared_concepts = split_heads(self.v_proj(x))
        expand_y_mask = y_mask.unsqueeze(1).unsqueeze(-1)  # [b, 1, seq, 1]

        gate_logits = self.mask_mlp(textual_concepts)  # [b, num_heads, seq, 1]
        if 'soft' in self.gate_mode:
            gates = 2 * torch.sigmoid(gate_logits * temperature)
            gates = (gates > 0.1).to(gates.dtype) * gates  # drop near-zero gates
            soft_influence = gates
        else:
            soft_influence = torch.sigmoid(gate_logits * temperature)
            if threshold is None:
                threshold = 0.5
            else:
                print('Using custom threshold for influence gating:', threshold)
            hard_influence = soft_influence >= threshold
            # Straight-through estimator: hard values forward, sigmoid gradient backward.
            gates = hard_influence + soft_influence - soft_influence.detach()

        if topk is not None:
            print(gates.shape, ' <<< shape before topk ')
            top_k_values, top_k_indices = torch.topk(gates, topk, dim=1)
            gates = torch.zeros_like(gates).scatter_(1, top_k_indices, top_k_values)
            print('Applied top-k sparsification on influence gates with k=', topk)

        blended = textual_concepts * gates + shared_concepts * (1 - gates)  # [b, num_heads, seq, head_dim]

        valid_tokens = y_mask.bool()
        rec_diagonal = self.rec_mlp(blended.transpose(1, 2).flatten(2)[valid_tokens].to(x.dtype))
        tgt_diagonal = y[valid_tokens]

        # Padded `y` positions fall back to the shared representation.
        blended = expand_y_mask * blended + (1 - expand_y_mask) * shared_concepts
        gate_mask = expand_y_mask.expand_as(gates).bool()
        meaningful_gates = gates[gate_mask]
        soft_meaningful_gate = soft_influence[gate_mask]

        full_output = blended.to(x.dtype)
        full_output = full_output.transpose(1, 2).contiguous().view(b, y.shape[1], d_model)
        output = self.out_proj(full_output)
        return (output, gates.squeeze(-1).transpose(1, 2), meaningful_gates,
                soft_meaningful_gate, rec_diagonal, tgt_diagonal)
def init_weights_gaussian(model, mean=0.0, std=0.02):
    """Re-initialise every ``nn.Linear`` inside *model* in place.

    Weights are drawn from N(mean, std**2); biases (when present) are set to 0.
    Layers are visited in ``model.modules()`` order. Returns ``None``.
    """
    linear_layers = (mod for mod in model.modules() if isinstance(mod, nn.Linear))
    for layer in linear_layers:
        nn.init.normal_(layer.weight, mean=mean, std=std)
        if layer.bias is None:
            continue
        nn.init.constant_(layer.bias, 0.0)
class ConceptAligner(nn.Module):
    """Add a learned, low-dimensional latent offset to text embeddings.

    A per-token latent ``eps`` (width ``per_dim``) is decoded by ``eps_proj``
    into a 4096-wide offset and added to ``text_features``. Where ``eps`` comes from:

    * ``image_features`` given: sampled with the reparameterisation trick from
      a Gaussian whose mean / log-variance are predicted from the difference
      between projected image features and (detached) projected text features.
      Any ``eps`` argument is ignored in this case.
    * otherwise: the caller's ``eps``, or the fixed ``test_eps`` buffer.

    The last layer of ``eps_proj`` is zero-initialised, so a freshly built
    module returns the text features unchanged.

    Args:
        per_dim: Latent width per token.
    """

    def __init__(self, per_dim=4):
        super().__init__()
        self.register_buffer('empty_pooled_clip', torch.load('empty_pooled_clip.pt', map_location='cpu'))
        # Fixed latent used when no image is supplied. It is drawn from the global
        # RNG before the layers below are built; keep this order for seeded runs.
        fixed_eps = torch.randn([1, 300, per_dim], dtype=torch.bfloat16).to('cpu') * 0.7
        self.register_buffer('test_eps', fixed_eps)

        self.init_proj = nn.Sequential(nn.Linear(768, 300 * 16), nn.SiLU())
        self.proj = nn.Sequential(nn.Linear(16, 1024), nn.SiLU(),
                                  nn.Linear(1024, 1024), nn.SiLU())
        self.text_proj = nn.Sequential(nn.Linear(4096, 1024), nn.SiLU(),
                                       nn.Linear(1024, 1024), nn.SiLU())
        self.proj_mu = nn.Sequential(nn.Linear(1024, per_dim))
        self.proj_logvar = nn.Sequential(nn.Linear(1024, per_dim))
        self.eps_proj = nn.Sequential(nn.Linear(per_dim, 1024), nn.SiLU(),
                                      nn.LayerNorm(1024),
                                      nn.Linear(1024, 4096))

        init_weights_gaussian(self, mean=0.0, std=0.02)
        decoder_out = self.eps_proj[-1]
        torch.nn.init.constant_(decoder_out.weight, 0.0)
        torch.nn.init.constant_(decoder_out.bias, 0.0)

    def dtype(self):
        """Return the dtype of the model parameters (hard-coded to bfloat16)."""
        # return next(self.parameters()).dtype
        return torch.bfloat16

    def device(self):
        """Return the device of the model, taken from the ``empty_pooled_clip`` buffer."""
        # return next(self.parameters()).device
        return self.empty_pooled_clip.device

    def forward(self, text_features, image_features=None, eps=None):
        """Offset ``text_features`` by the decoded latent.

        Args:
            text_features: Text embeddings; the offset is 4096 wide and the image
                branch reshapes to 300 tokens, so presumably [b, 300, 4096].
            image_features: Optional image embeddings, presumably [b, 768].
            eps: Optional latent [..., 300, per_dim]; used only when
                ``image_features`` is None.

        Returns:
            ``(prompt_embeds, None, pooled_prompt_embeds, text_ids, aux_info)``
            where ``aux_info`` maps ``'mu'``, ``'logvar'`` and ``'eps'`` to tensors.
        """
        dtype, device = text_features.dtype, text_features.device

        if image_features is None:
            if eps is None:
                eps = self.test_eps.to(device=device, dtype=dtype)
            mu = torch.zeros_like(eps)
            logvar = torch.zeros_like(eps)
        else:
            image_tokens = self.init_proj(image_features).view(len(image_features), 300, -1)
            hidden = self.proj(image_tokens) - self.text_proj(text_features.detach())
            mu = self.proj_mu(hidden)
            logvar = self.proj_logvar(hidden)
            # Reparameterisation trick: mu + sigma * N(0, I).
            eps = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

        prompt_embeds = text_features + self.eps_proj(eps)
        pooled_prompt_embeds = self.empty_pooled_clip.expand(text_features.shape[0], -1)
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
        aux_info = {'mu': mu, 'logvar': logvar, 'eps': eps}
        return prompt_embeds, None, pooled_prompt_embeds, text_ids, aux_info
| if __name__ == '__main__': | |
| from transformers import AutoProcessor | |
| from diffusers import FluxPipeline | |
| import os | |
| from PIL import Image | |
def create_image_grid(images, cols):
    """Paste equally sized PIL images into a row-major grid.

    Args:
        images: Non-empty sequence of PIL images; the first image's size is
            used for every cell.
        cols: Requested number of columns. Values below 1 are clamped to 1, so
            ``cols=len(images) // 2`` with a single image no longer divides by zero.

    Returns:
        A new RGB image of size (cols * w, rows * h).

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("create_image_grid() needs at least one image")
    cols = max(1, cols)
    rows = -(-len(images) // cols)  # ceiling division
    w, h = images[0].size
    grid = Image.new('RGB', (cols * w, rows * h))
    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        grid.paste(img, (col * w, row * h))
    return grid
| dim = 4096 | |
| num_heads = 32 | |
| dtype = torch.bfloat16 | |
| model = ConceptAligner().to('cuda').to(dtype) | |
| x = torch.randn([5, 300, dim]).to('cuda').to(dtype) | |
| y = torch.randn([5, 300, dim]).to('cuda').to(dtype) | |
| i = torch.randn([5,768]).to('cuda').to(dtype) | |
| y[1] = y[0] | |
| m = torch.ones([5, 300]).to('cuda').to(dtype) | |
| m[:3,:128] = 0 | |
| prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(x, i) | |
| print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape) | |
| print(prompt_embeds.shape, ' ', pooled_prompt_embeds.shape, ' ', text_ids.shape) | |
| for k in aux_info: | |
| print(k, aux_info[k].shape, aux_info[k].min(), aux_info[k].max(), aux_info[k].mean()) | |
| from text_encoder import LoraT5Embedder | |
| from datasets import load_dataset | |
| dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]') | |
| item = dataset[0:4] | |
| another_item = dataset[0:4] | |
| from diffusers.models.normalization import RMSNorm | |
| clip_processor = AutoProcessor.from_pretrained("./clip-vit-large-patch14") | |
| clip_images = clip_processor(images=item['image'], return_tensors="pt").pixel_values.to('cuda:0').to(dtype) | |
| texts = [] | |
| texts.append("""A heartwarming 3D rendered scene of | |
| an elderly farmer and a tiny orange | |
| kitten. The farmer, with a gentle smile, | |
| walks alongside the kitten in a lush, | |
| green garden filled with thriving plants, | |
| showcasing a fruitful harvest. The | |
| intricate details of the overalls and the | |
| farmer's worn, weathered face tell a | |
| story of years spent tending to the land. the farmer is wearing a blue shirt""") | |
| texts.append("""A unique, intricately detailed creature | |
| resembling a reptile, possibly a lizard or | |
| a gecko. It has a vibrant blue and green | |
| scaled body, with large, round, and | |
| expressive eyes that are a deep shade of | |
| blue. The backdrop is a | |
| soft, blurred forest setting, suggesting a | |
| serene and mystical ambiance. the creature is wearing a golden crown""") | |
| texts.append("""Deep in the enchanted forest lives a woman | |
| who is the moon fairy. Her long blonde hair | |
| shines in the starlight, tangled with her flowers | |
| that glow with a soft blue glow. Her eyes are | |
| the color of the night and shine with the magic | |
| of the night. The fairy wears a dress made of | |
| moon petals, woven with threads of moonlight | |
| that shine with an iridescent glow, a crown of | |
| stars adorns her head, shining with the light of | |
| the full moon that illuminates the forest. Her | |
| wings are translucent like glass, with a pale | |
| glow reminiscent of the glow of the moon. HD, | |
| 6K, photo, cinematic, poster""") | |
| texts.append( | |
| """In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""") | |
| text_encoder = LoraT5Embedder(device='cuda').to(dtype) | |
| text_features, _, _, _, image_features, _ = text_encoder(texts, clip_images) | |
| print(text_features.shape, image_features.shape, ' >>>>>>>>> text input') | |
| images = [] | |
| pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16) | |
| pipe.to('cuda') | |
| for txt_feat, img_feat in zip(text_features, image_features): | |
| prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(txt_feat.unsqueeze(0), img_feat.unsqueeze(0)) | |
| image = pipe( | |
| prompt_embeds=prompt_embeds, | |
| pooled_prompt_embeds=pooled_prompt_embeds, | |
| height=512, | |
| width=512, | |
| guidance_scale=3.5, | |
| num_inference_steps=20, | |
| max_sequence_length=512, | |
| generator=torch.Generator("cuda").manual_seed(1995), | |
| ).images[0] | |
| images.append(image) | |
| aligned_image = create_image_grid(images, cols=len(images) // 2) | |
| os.makedirs('samples', exist_ok=True) | |
| aligned_image.save("samples/image%.jpg") | |
| raise SystemExit | |
| influence_matrix = aux_info['influence'] | |
| bin_influence_matrix = (influence_matrix > 0.1).float() | |
| mean_alive = bin_influence_matrix.sum(dim=-1).mean() | |
| max_alive = bin_influence_matrix.sum(dim=-1).max() | |
| min_alive = bin_influence_matrix.sum(dim=-1).min() | |
| max_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).max() | |
| mean_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).mean() | |
| min_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).min() | |
| print( | |
| f"Mean alive heads per token: {mean_alive:.2f}, Max alive heads per token: {max_alive:.2f}, Min alive heads per token: {min_alive:.2f}") | |
| print( | |
| f"Mean alive tokens: {mean_token_alive:.2f}, Max alive tokens: {max_token_alive:.2f}, Min alive tokens: {min_token_alive:.2f}") | |
| import os | |
| CHECKPOINT_PATH = 'runs/00393/checkpoint-6000' | |
| from safetensors.torch import load_file | |
| # Load adapter (model.safetensors) | |
| adapter_path = os.path.join(CHECKPOINT_PATH, "model_1.safetensors") | |
| if os.path.exists(adapter_path): | |
| adapter_state = load_file(adapter_path) | |
| model.load_state_dict(adapter_state, strict=True) | |
| print("Adapter loaded successfully!") | |
| print(model.influence_net.v_proj.weight, ' <<< weight ') | |
| print(model.influence_net.v_proj.bias, ' <<< bias ') | |
| print(model.influence_net.out_proj.weight, ' <<< out weight ') | |
| print(model.influence_net.out_proj.bias, ' <<< out bias ') | |
| print(model.influence_net.mask_mlp.linear3.weight, ' <<< gate weight 3 ') | |
| print(model.influence_net.mask_mlp.linear3.bias, ' <<< gate bias ') | |
| z = torch.randn([3, num_heads, 300, 4096 // num_heads]).to('cuda').to(dtype) | |
| gate_values = model.influence_net.mask_mlp(z) | |
| gate_values = 2 * (torch.sigmoid(gate_values)) | |
| print(gate_values, ' <<< gate values ', gate_values.shape, ' ', torch.mean(gate_values)) | |
| from diffusers import FluxPipeline | |
| from PIL import Image | |
| reserved_memory = torch.cuda.memory_reserved(0) / (1024 ** 3) | |
| print(f"Reserved GPU memory: {reserved_memory:.2f} GB") | |
| from transformers import T5EncoderModel, T5Tokenizer, CLIPTokenizer, CLIPTextModel | |
| import torch | |
| from text_encoder import LoraT5Embedder | |
| text_encoder = LoraT5Embedder(device='cuda').to(torch.bfloat16) | |
| texts = [] | |
| texts.append("""A heartwarming 3D rendered scene of | |
| an elderly farmer and a tiny orange | |
| kitten. The farmer, with a gentle smile, | |
| walks alongside the kitten in a lush, | |
| green garden filled with thriving plants, | |
| showcasing a fruitful harvest. The | |
| intricate details of the overalls and the | |
| farmer's worn, weathered face tell a | |
| story of years spent tending to the land. the farmer is wearing a blue shirt""") | |
| texts.append("""A unique, intricately detailed creature | |
| resembling a reptile, possibly a lizard or | |
| a gecko. It has a vibrant blue and green | |
| scaled body, with large, round, and | |
| expressive eyes that are a deep shade of | |
| blue. The backdrop is a | |
| soft, blurred forest setting, suggesting a | |
| serene and mystical ambiance. the creature is wearing a golden crown""") | |
| texts.append("""Deep in the enchanted forest lives a woman | |
| who is the moon fairy. Her long blonde hair | |
| shines in the starlight, tangled with her flowers | |
| that glow with a soft blue glow. Her eyes are | |
| the color of the night and shine with the magic | |
| of the night. The fairy wears a dress made of | |
| moon petals, woven with threads of moonlight | |
| that shine with an iridescent glow, a crown of | |
| stars adorns her head, shining with the light of | |
| the full moon that illuminates the forest. Her | |
| wings are translucent like glass, with a pale | |
| glow reminiscent of the glow of the moon. HD, | |
| 6K, photo, cinematic, poster""") | |
| texts.append( | |
| """In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""") | |
| texts.append( | |
| """In the image, a majestic white horse gallops across a misty meadow at sunrise. Its mane and tail flow freely in the golden light, and the air glows softly with early morning haze. The horse’s body is bare, revealing the natural curve of its muscles and the sheen of its coat. Dew sparkles on the grass beneath its hooves, and the distant trees fade into pale gold mist. The scene conveys freedom, grace, and quiet power.""") | |
| INDEX = 0 | |
| text = texts[INDEX] | |
| with torch.no_grad(): | |
| floral_embeds, _,_,_,_,attn_mask = text_encoder(text, ) | |
| print(attn_mask.shape, ' >>>> ', attn_mask) | |
| print(floral_embeds.shape, shared_embeds.shape, ' >>>> floral ') | |
| nopad_floral_embeds, nopad_shared_embeds, nopad_attn_mask = text_encoder(text, padding=False) | |
| print(floral_embeds.shape, shared_embeds.shape, ' >>>> floral ') | |
| """ | |
| _,_,_,_,aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False) | |
| print(aux_info['meaningful_influence'].shape, ' <<< influence shape ', aux_info['meaningful_influence'][:100],' ',torch.mean(aux_info['meaningful_influence'])) | |
| floral_embeds, shared_embeds, attn_mask = text_encoder([""], padding='max_length') | |
| _,_,_,_,aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False) | |
| print(aux_info['meaningful_influence'].shape, ' <<< empty influence shape ', aux_info['meaningful_influence'],' ',torch.mean(aux_info['meaningful_influence'])) | |
| raise SystemExit | |
| """ | |
| text2s = [] | |
| text2s.append("""A heartwarming 3D rendered scene of | |
| an elderly farmer and a tiny orange | |
| kitten. The farmer, with a gentle smile, | |
| walks alongside the kitten in a lush, | |
| green garden filled with thriving plants, | |
| showcasing a fruitful harvest. The | |
| intricate details of the overalls and the | |
| farmer's worn, weathered face tell a | |
| story of years spent tending to the land. the farmer is wearing a red shirt""") | |
# Remaining "second" prompts of each comparison pair. text2s[INDEX] is compared
# against the first prompt (`text`), and key2s[INDEX] further below names the
# keyword whose token position is inspected in it.
text2s.extend([
    (
        "A unique, intricately detailed creature\n"
        "resembling a reptile, possibly a lizard or\n"
        "a gecko. It has a vibrant blue and green\n"
        "scaled body, with large, round, and\n"
        "expressive eyes that are a deep shade of\n"
        "blue. The backdrop is a\n"
        "soft, blurred forest setting, suggesting a\n"
        "serene and mystical ambiance. the creature is wearing a floral crown"
    ),
    (
        "Deep in the enchanted forest lives a woman\n"
        "who is the moon fairy. Her long black hair\n"
        "shines in the starlight, tangled with her flowers\n"
        "that glow with a soft blue glow. Her eyes are\n"
        "the color of the night and shine with the magic\n"
        "of the night. The fairy wears a dress made of\n"
        "moon petals, woven with threads of moonlight\n"
        "that shine with an iridescent glow, a crown of\n"
        "stars adorns her head, shining with the light of\n"
        "the full moon that illuminates the forest. Her\n"
        "wings are translucent like glass, with a pale\n"
        "glow reminiscent of the glow of the moon. HD,\n"
        "6K, photo, cinematic, poster"
    ),
    (
        "In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. "
        "Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. "
        "The window reveals a gray, rainy sky outside, with raindrops streaking down the glass and blurred trees beyond. "
        "The cat’s posture is calm and elegant, its tail curled neatly around its paws. "
        "The atmosphere is serene and introspective, capturing a cozy moment of quiet observation during a rainy afternoon."
    ),
    (
        "In the image, a majestic white horse gallops across a misty meadow at sunrise. "
        "Its mane and tail flow freely in the golden light, and the air glows softly with early morning haze. "
        "The horse’s body is adorned with a bright red saddle, contrasting sharply against its white coat. "
        "Dew sparkles on the grass beneath its hooves, and the distant trees fade into pale gold mist. "
        "The scene conveys freedom, grace, and a striking touch of color that adds visual drama."
    ),
])

# INDEX (set earlier in the script) selects which prompt pair is studied.
text2 = text2s[INDEX]
# Encode the second prompt twice: padded to max_length (the variant consumed by
# the model and pipeline below) and without padding (kept for inspection).
with torch.no_grad():
    # NOTE(review): this assigns `shared_embeds` from `text2`; the later
    # model(floral_embeds, shared_embeds, ...) call therefore receives the
    # shared embedding of the second prompt — confirm that is intended.
    golden_embeds, shared_embeds, golden_mask = text_encoder(text2, padding='max_length')
    print(golden_embeds.shape, shared_embeds.shape, ' >>>> golden ')
    nopad_golden_embeds, nopad_shared_embeds, nopad_golden_mask = text_encoder(text2, padding=False)
    # Report the unpadded shapes (this line used to re-print the padded ones).
    print(nopad_golden_embeds.shape, nopad_shared_embeds.shape, ' >>>> golden nopad ')


def _t5_tokens(prompt):
    """Return the T5 token strings of *prompt*.

    The prompt is truncated to ``text_encoder.max_length`` tokens, and the
    resulting ids are converted back to token strings so keyword positions
    can be located with ``find_token_indices``.
    """
    input_ids = text_encoder.t5_tokenizer(
        prompt,
        truncation=True,
        max_length=text_encoder.max_length,
        return_tensors="pt",
    )["input_ids"][0]
    return text_encoder.t5_tokenizer.convert_ids_to_tokens(input_ids)


# Token strings of the first prompt (`text`) and of the second prompt (`text2`).
tokens_floral = _t5_tokens(text)
tokens_golden = _t5_tokens(text2)
def find_token_indices(tokens, word, tokenizer=None):
    """Return every index in *tokens* at which *word* appears as a token.

    A position matches when either:
      * the token equals the first sub-token the tokenizer produces for *word*
        (for T5 typically the SentencePiece piece with a leading '\u2581'), or
      * the token, with leading '\u2581' word-boundary markers stripped, equals
        *word* case-insensitively.

    The former substring test (``word in token``) also matched unrelated
    tokens, e.g. '\u2581covered' or '\u2581hundred' for 'red'. That could make the
    ``[-1]`` lookups done by callers pick the wrong word.

    Args:
        tokens: sequence of token strings (e.g. from ``convert_ids_to_tokens``).
        word: keyword to locate.
        tokenizer: object with ``encode`` and ``convert_ids_to_tokens``;
            defaults to ``text_encoder.t5_tokenizer``.

    Returns:
        A list of int positions. It is empty when *word* does not occur, or
        when it encodes to no tokens at all (previously an IndexError).
    """
    if tokenizer is None:
        tokenizer = text_encoder.t5_tokenizer
    word_ids = tokenizer.encode(word, add_special_tokens=False)
    if not word_ids:
        return []
    # The tokenizer may split the word into several pieces; match on the first.
    first_piece = tokenizer.convert_ids_to_tokens([word_ids[0]])[0]
    target = word.lower()
    return [
        i
        for i, token in enumerate(tokens)
        if token == first_piece or token.lstrip("\u2581").lower() == target
    ]
# Keyword per prompt pair: key1s[INDEX] is located in the first prompt
# (`text`, tokens_floral) and key2s[INDEX] in the second prompt (`text2`,
# tokens_golden).
key1s = ['blue', 'golden', 'blonde', 'clear', 'horse']
key2s = ['red', 'floral', 'black', 'rainy', 'red']
key1, key2 = key1s[INDEX], key2s[INDEX]


def _last_token_index(tokens, word, prompt_name):
    """Return the last position of *word* in *tokens*; raise ValueError if absent.

    Previously a missing keyword surfaced as a bare IndexError from ``[-1]``.
    """
    matches = find_token_indices(tokens, word)
    if not matches:
        raise ValueError(f"keyword {word!r} not found in the tokens of {prompt_name}")
    return matches[-1]


# The historical names are kept because later code reads them: blue_indices is
# used by the generation loop. Each holds a single token position, whichever
# keyword INDEX selects.
blue_indices = _last_token_index(tokens_floral, key1, 'text')
print(f"Indices for {key1!r}: {blue_indices}")
red_indices = _last_token_index(tokens_golden, key2, 'text2')
print(f"Indices for {key2!r}: {red_indices}")
# FLUX pipeline without its own text encoder: prompt embeddings are passed in
# directly (prompt_embeds / pooled_prompt_embeds) by the generation loop.
pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
pipe.to('cuda')

# Optionally replace the transformer weights with the fine-tuned checkpoint.
adapter_path = os.path.join(CHECKPOINT_PATH, "model.safetensors")
if os.path.exists(adapter_path):
    # Load inline so the (large) CPU state dict is not kept alive by a
    # module-level name once its values have been copied into the model.
    pipe.transformer.load_state_dict(load_file(adapter_path), strict=True)
    print("Transformer loaded successfully!")
else:
    # Previously this case was skipped silently, leaving the weights loaded
    # by from_pretrained in place.
    print(f"WARNING: {adapter_path} not found; using the base FLUX.1-dev transformer weights.")

# Pooled conditioning reused for every image; judging by the file name it is
# the pooled CLIP embedding of an empty prompt (TODO confirm).
empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu').to('cuda', dtype=torch.bfloat16)
print("Generating image with concatenation...")
# Sweep the model's gating settings for both prompt sources, render one image
# per setting, and save all of them as a grid with one row per prompt source.
prompt_sources = ('floral', 'golden')
temperatures = (2.5,)
thresholds = (-1, 0.5, 0.75, 0.85, 0.95)
topks = (None,)

images = []
for emb in prompt_sources:
    # Embeddings and attention mask of the prompt being conditioned on.
    if emb == 'floral':
        src_embeds, src_mask = floral_embeds, attn_mask
    else:
        src_embeds, src_mask = golden_embeds, golden_mask
    for temp in temperatures:
        for thr in thresholds:
            for topk in topks:
                print('>>>> Temperature: ', temp, topk)
                # Inference only: avoid building an autograd graph for the model call.
                with torch.no_grad():
                    inter_embed, _, _, _, new_aux_info = model(
                        src_embeds, shared_embeds, src_mask,
                        is_training=False, temperature=temp,
                        threshold=thr, topk=topk,
                    )
                # NOTE(review): blue_indices is a position in the first prompt's
                # tokens (tokens_floral) but is also used for the 'golden' source;
                # red_indices is the keyword position in text2 — confirm intended.
                influence = new_aux_info['influence'][:, blue_indices]
                print(influence.shape, ' >>>> influence ', influence)
                meaningful = new_aux_info['meaningful_influence']
                print(meaningful, ' >>>> meaningful influence ', torch.mean(meaningful))
                image = pipe(
                    prompt_embeds=inter_embed,
                    pooled_prompt_embeds=empty_pooled_clip,
                    height=512,
                    width=512,
                    guidance_scale=3.5,
                    num_inference_steps=20,
                    max_sequence_length=512,
                    # Same seed for every setting so images differ only by the embeddings.
                    generator=torch.Generator("cuda").manual_seed(1995),
                ).images[0]
                images.append(image)

# One row per prompt source; max(1, ...) prevents cols=0 when only one image exists.
grid_cols = max(1, len(images) // len(prompt_sources))
aligned_image = create_image_grid(images, cols=grid_cols)
os.makedirs('samples', exist_ok=True)
aligned_image.save(os.path.join('samples', f"image{INDEX}.jpg"))