"""Aligner modules that turn text features into FLUX prompt embeddings.

The ``__main__`` section is an experiment script that feeds the aligner output
to ``FluxPipeline``.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

# ``Qwen2Connector`` (project-local ``refiner`` module) is imported lazily inside
# ``ConceptAligner222.__init__`` so this file can be imported without ``refiner``.


class MultiHeadSelfAttention(nn.Module):
    """Masked multi-head self-attention that can also return its attention map.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim=2560, num_heads=20):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear projections for Q, K, V and the output projection.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.scale = self.head_dim ** -0.5

    def forward(self, x, mask=None, return_attention=True):
        """Apply self-attention to ``x``.

        Args:
            x: Input tensor of shape ``[b, seq_len, embed_dim]``.
            mask: Optional ``[b, seq_len]`` mask; nonzero means "valid". A
                (query, key) pair is attended only when both positions are valid,
                so padded queries receive all-zero attention weights.
            return_attention: Whether to also return the attention weights.

        Returns:
            ``output`` of shape ``[b, seq_len, embed_dim]``, plus
            ``attn_weights`` of shape ``[b * num_heads, seq_len, seq_len]``
            when ``return_attention`` is true.
        """
        b, seq_len, embed_dim = x.shape

        def split_heads(t):
            # [b, seq_len, embed_dim] -> [b * num_heads, seq_len, head_dim]
            return (
                t.view(b, seq_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
                .reshape(b * self.num_heads, seq_len, self.head_dim)
            )

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        attn_scores = torch.bmm(q, k.transpose(1, 2)) * self.scale  # [b*num_heads, seq, seq]

        if mask is not None:
            valid = mask.bool()
            # Row (query) validity AND column (key) validity: [b, 1, seq, seq].
            pair_mask = valid[:, None, :, None] & valid[:, None, None, :]
            pair_mask = pair_mask.expand(b, self.num_heads, seq_len, seq_len)
            pair_mask = pair_mask.reshape(b * self.num_heads, seq_len, seq_len)
            attn_scores = attn_scores.masked_fill(~pair_mask, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        # Fully masked rows are all -inf and softmax to NaN; turn them into zeros.
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)

        attn_output = torch.bmm(attn_weights, v)  # [b*num_heads, seq, head_dim]
        attn_output = (
            attn_output.view(b, self.num_heads, seq_len, self.head_dim)
            .transpose(1, 2)
            .contiguous()
            .view(b, seq_len, embed_dim)
        )
        output = self.out_proj(attn_output)

        if return_attention:
            return output, attn_weights
        return output


class ConceptAligner222(nn.Module):
    """Legacy aligner: text connector -> self-attention resampler -> 4096-d embeds.

    Only ``input_dim`` 2560 or 4096 is supported. Each value selects a fixed
    hidden size, head count and connector depth, so the ``hidden_size`` argument
    is ignored. Construction needs ``Qwen2Connector`` from the project-local
    ``refiner`` module and ``empty_pooled_clip.pt`` in the working directory.
    """

    def __init__(self, custom_pool=1, input_dim=2560, hidden_size=2560):
        super().__init__()
        if input_dim == 2560:
            hidden_size = 2560
            self.num_heads = 20
            self.model_class = 'gemma3'
            depth = 2
            identity_mapping = False
        elif input_dim == 4096:
            hidden_size = 3072
            self.num_heads = 24
            self.model_class = 't5'
            depth = 1
            identity_mapping = True
        else:
            # Without this branch, ``depth`` etc. would be unbound below.
            raise ValueError(f"Unsupported input_dim {input_dim}; expected 2560 or 4096.")

        # The module-level import of Qwen2Connector was commented out, which left
        # the name undefined here; import it where it is actually needed.
        from refiner import Qwen2Connector

        self.text_connector = Qwen2Connector(
            in_channels=input_dim,
            hidden_size=hidden_size,
            heads_num=self.num_heads,
            depth=depth,
            identity_init=identity_mapping,
        )
        self.final_proj = nn.Sequential(
            nn.Linear(hidden_size, 4096), nn.SiLU(), nn.Linear(4096, 4096)
        )
        self.resampler = MultiHeadSelfAttention(embed_dim=hidden_size, num_heads=self.num_heads)
        empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
        self.register_buffer('empty_pooled_clip', empty_pooled_clip)
        self.learnable_scale_norm = nn.Parameter(torch.ones([1, 1, 1]) * 0.01, requires_grad=True)
        self.proj_norm = nn.LayerNorm(hidden_size)
        self.custom_pool = custom_pool
        if self.custom_pool:
            # Learned 768-d pooled embedding instead of the stored empty CLIP vector.
            self.clip_proj = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 3), nn.SiLU(), nn.Linear(hidden_size * 3, 768)
            )
            self.clip_norm = nn.LayerNorm(768)
            print('Using custom pooling for CLIP features.')

    @property
    def dtype(self):
        """Nominal dtype; hard-coded to ``torch.bfloat16`` (not read from the parameters)."""
        return torch.bfloat16

    @property
    def device(self):
        """Device of the ``empty_pooled_clip`` buffer."""
        return self.empty_pooled_clip.device

    def forward(self, text_features, text_mask, is_training=False, img_seq_len=1024):
        """Map text features to prompt embeddings plus auxiliary outputs.

        Args:
            text_features: Input to ``text_connector``; presumably
                ``[b, seq, input_dim]`` — TODO confirm against ``Qwen2Connector``.
            text_mask: ``[b, seq]`` mask, 1 for valid tokens.
            is_training: If true, add Gaussian noise scaled by the clipped
                ``learnable_scale_norm`` before ``final_proj``.
            img_seq_len: Number of image positions reserved in ``attention_mask``.

        Returns:
            ``(prompt_embeds, attention_mask, pooled_prompt_embeds, text_ids,
            valid_entropy)``. ``attention_mask`` is
            ``[b, 1, 1, img_seq_len + prompt_len]``, with -10000 at padded text
            positions and 0 elsewhere. ``text_ids`` is a zero ``[prompt_len, 3]``
            tensor. ``valid_entropy`` is the per-head attention entropy of every
            valid query token.
        """
        # NOTE(review): ``model_class`` is only ever 'gemma3' or 't5' (see __init__),
        # so this comparison is always False and mean_start_id is always 0. Kept
        # as-is to match existing checkpoints; confirm whether 'gemma3' was meant.
        text_features = self.text_connector(
            text_features, mask=text_mask, mean_start_id=2 if self.model_class == 'gemma' else 0
        )
        text_features = self.proj_norm(text_features)
        aligned_features, attn = self.resampler(text_features, mask=text_mask, return_attention=True)

        if is_training:
            learnable_scale = torch.clip(self.learnable_scale_norm, -1.0, 1.0)
            visual_concepts = aligned_features + learnable_scale * torch.randn_like(aligned_features)
        else:
            visual_concepts = aligned_features
        prompt_embeds = self.final_proj(visual_concepts)

        if self.custom_pool:
            # Masked mean over valid tokens (noise-free features).
            mean_features = (aligned_features * text_mask.unsqueeze(-1)).sum(dim=1) / (
                text_mask.sum(dim=1, keepdim=True) + 1e-8
            )
            pooled_prompt_embeds = self.clip_norm(self.clip_proj(mean_features))
        else:
            pooled_prompt_embeds = self.empty_pooled_clip.expand(text_features.shape[0], -1)

        dtype = prompt_embeds.dtype
        device = prompt_embeds.device
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        total_seq_len = img_seq_len + prompt_embeds.shape[1]
        text_seq_len = text_mask.shape[1]
        attention_mask = torch.zeros(
            len(text_features), 1, 1, total_seq_len, device=text_mask.device, dtype=text_mask.dtype
        )
        # Text portion: a large negative bias (-10000, not -inf) where text_mask == 0.
        attention_mask[:, :, :, :text_seq_len] = (1 - text_mask).unsqueeze(1).unsqueeze(2) * -10000.0

        # Per-query attention entropy, kept only for valid query tokens.
        entropy = -(attn * torch.log(attn + 1e-8)).sum(dim=-1)  # [b*num_heads, seq]
        mask_expanded = text_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
        mask_expanded = mask_expanded.reshape(len(text_features) * self.num_heads, text_seq_len)
        valid_entropy = entropy[mask_expanded.bool()]

        return prompt_embeds, attention_mask, pooled_prompt_embeds, text_ids, valid_entropy


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalise in float32, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class AdaLayerNorm(nn.Module):
    """LayerNorm whose shift/scale are predicted from a conditioning embedding.

    Args:
        embedding_dim: Width of the normalised features.
        time_embedding_dim: Width of the conditioning embedding; ``None`` means
            ``embedding_dim``.
    """

    def __init__(self, embedding_dim: int, time_embedding_dim=4096):
        super().__init__()
        if time_embedding_dim is None:
            time_embedding_dim = embedding_dim
        self.silu = nn.SiLU()
        self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
        nn.init.normal_(self.linear.weight, mean=0, std=0.02)
        nn.init.zeros_(self.linear.bias)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x: torch.Tensor, timestep_embedding: torch.Tensor) -> torch.Tensor:
        """Return ``norm(x) * (1 + scale) + shift``.

        ``timestep_embedding`` is ``[b, time_embedding_dim]``; the predicted
        shift/scale are broadcast over dim 1 of ``x``.
        """
        emb = self.linear(self.silu(timestep_embedding))
        shift, scale = emb.unsqueeze(1).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift


class GateMLP(nn.Module):
    """Per-head, per-token gate logits conditioned on the full token vector.

    Expects ``x`` of shape ``[b, heads, seq, input_dim]`` with
    ``heads * input_dim == 4096``. The hidden width is fixed at 512, so the
    ``hidden_dim`` argument is accepted but ignored.
    """

    def __init__(self, gate_mode='soft', input_dim=64, hidden_dim=1024):
        super().__init__()
        self.gate_mode = gate_mode
        hidden_dim = 512  # fixed width; the ``hidden_dim`` argument is not used
        self.input_norm = nn.LayerNorm(4096)
        self.norm0 = nn.LayerNorm(input_dim)
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.activation1 = nn.GELU()
        self.linear2 = nn.Linear(hidden_dim + 4096, hidden_dim)
        self.activation2 = nn.GELU()
        self.linear3 = nn.Linear(hidden_dim + 4096, hidden_dim)
        self.activation3 = nn.GELU()
        self.final_linear = nn.Linear(hidden_dim, 1)

        for layer in (self.linear1, self.linear2, self.linear3):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        # Zero weight: the gate starts as a constant logit (0 for soft, 1 otherwise).
        nn.init.zeros_(self.final_linear.weight)
        bias_val = 0.0 if 'soft' in gate_mode else 1.0
        nn.init.constant_(self.final_linear.bias, bias_val)

    def forward(self, x):
        """Return gate logits of shape ``[b, heads, seq, 1]``; inputs are detached."""
        # Full token vector [b, seq, heads*input_dim], repeated for every head.
        y = x.transpose(1, 2).flatten(2)
        y = self.input_norm(y.detach()).unsqueeze(1).repeat(1, x.shape[1], 1, 1)
        x = self.activation1(self.linear1(self.norm0(x.detach())))
        x = self.activation2(self.linear2(torch.cat([x, y], dim=-1)))
        x = self.activation3(self.linear3(torch.cat([x, y], dim=-1)))
        return self.final_linear(x)


class CrossAttentionWithInfluence(nn.Module):
    """Gate, per token and per head, between value projections of ``y`` and ``x``.

    ``v_proj`` and ``out_proj`` start as identity maps. ``rec_mlp`` and the gate
    network hard-code a width of 4096, so ``d_model`` must be 4096 in practice.
    """

    def __init__(self, d_model=4096, num_heads=32, gate_mode='hard'):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.gate_mode = gate_mode
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        # Only a value projection is used; the Q/K attention path is disabled.
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        nn.init.eye_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)
        nn.init.eye_(self.v_proj.weight)
        nn.init.zeros_(self.v_proj.bias)

        self.mask_mlp = GateMLP(input_dim=d_model // num_heads, hidden_dim=1024, gate_mode=gate_mode)
        self.scale = self.head_dim ** -0.5
        self.rec_mlp = nn.Sequential(
            nn.Linear(4096, 4096), nn.SiLU(),
            nn.Linear(4096, 4096), nn.SiLU(),
            nn.Linear(4096, 4096),
        )

    def forward(self, x, y, y_mask, temperature=None, threshold=None, topk=None):
        """Mix per-head values of ``y`` and ``x`` using learned gates.

        Args:
            x: Shared embedding ``[b, seq, d_model]``.
            y: Changing embedding ``[b, seq, d_model]``.
            y_mask: ``[b, seq]`` mask; where it is 0 the output falls back to ``x``.
            temperature: Multiplier on the gate logits before the sigmoid.
                ``None`` means 1.0; it previously raised ``TypeError``.
            threshold: Binarisation threshold in non-soft mode (default 0.5).
            topk: If given, keep only the ``topk`` largest gates across heads
                (dim 1) for each token and zero the rest.

        Returns:
            ``(output, influence, meaningful_gates, soft_meaningful_gate,
            rec_diagonal, tgt_diagonal)``:

            - output: ``[b, seq, d_model]`` after ``out_proj``.
            - influence: ``[b, seq, num_heads]`` gate values.
            - meaningful_gates: gate values at positions where ``y_mask`` is set, 1-D.
            - soft_meaningful_gate: the sigmoid (pre-binarisation) gate values at
              the same positions, 1-D.
            - rec_diagonal: ``rec_mlp`` of the gated features at valid positions.
            - tgt_diagonal: ``y`` at valid positions.
        """
        if temperature is None:
            temperature = 1.0
        b, _, d_model = x.shape

        # [b, seq, d_model] -> [b, num_heads, seq, head_dim]
        v_y = self.v_proj(y)
        v_x = self.v_proj(x)
        textual_concepts = v_y.view(b, v_y.shape[1], self.num_heads, self.head_dim).transpose(1, 2)
        shared_concepts = v_x.view(b, v_x.shape[1], self.num_heads, self.head_dim).transpose(1, 2)
        expand_y_mask = y_mask.unsqueeze(1).unsqueeze(-1)  # [b, 1, seq, 1]

        gate_logits = self.mask_mlp(textual_concepts)  # [b, num_heads, seq, 1]
        if 'soft' in self.gate_mode:
            # Soft gate in [0, 2); values not above 0.1 are zeroed.
            diagonal_influence = 2 * torch.sigmoid(gate_logits * temperature)
            diagonal_influence = (diagonal_influence > 0.1).to(diagonal_influence.dtype) * diagonal_influence
            soft_influence = diagonal_influence
        else:
            soft_influence = torch.sigmoid(gate_logits * temperature)
            if threshold is None:
                threshold = 0.5
            else:
                print('Using custom threshold for influence gating:', threshold)
            hard_influence = soft_influence >= threshold
            # Straight-through estimator: forward uses the 0/1 gate, backward the sigmoid.
            diagonal_influence = hard_influence + soft_influence - soft_influence.detach()

        if topk is not None:
            print(diagonal_influence.shape, ' <<< shape before topk ')
            top_k_values, top_k_indices = torch.topk(diagonal_influence, topk, dim=1)
            result = torch.zeros_like(diagonal_influence)
            result.scatter_(1, top_k_indices, top_k_values)
            diagonal_influence = result
            print('Applied top-k sparsification on influence gates with k=', topk)

        diagonal_output = textual_concepts * diagonal_influence + shared_concepts * (1 - diagonal_influence)

        valid = y_mask.bool()
        rec_diagonal = self.rec_mlp(diagonal_output.transpose(1, 2).flatten(2)[valid].to(x.dtype))
        tgt_diagonal = y[valid]

        # Padded y positions fall back to the shared concepts.
        diagonal_output = expand_y_mask * diagonal_output + (1 - expand_y_mask) * shared_concepts

        mask_bool_expanded = expand_y_mask.expand_as(diagonal_influence).bool()
        meaningful_gates = diagonal_influence[mask_bool_expanded]
        soft_meaningful_gate = soft_influence[mask_bool_expanded]

        full_output = diagonal_output.to(x.dtype)
        full_output = full_output.transpose(1, 2).contiguous().view(b, y.shape[1], d_model)
        output = self.out_proj(full_output)

        return (
            output,
            diagonal_influence.squeeze(-1).transpose(1, 2),
            meaningful_gates,
            soft_meaningful_gate,
            rec_diagonal,
            tgt_diagonal,
        )


def init_weights_gaussian(model, mean=0.0, std=0.02):
    """Initialise every ``nn.Linear`` in ``model``: weights ~ N(mean, std), biases 0."""
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=mean, std=std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)


class ConceptAligner(nn.Module):
    """Add a learned per-token offset to text features.

    The offset is decoded from a ``per_dim`` latent. During training the latent
    is sampled from an image-conditioned Gaussian; otherwise a stored
    ``test_eps`` (or a caller-given ``eps``) is used. Construction loads
    ``empty_pooled_clip.pt`` from the working directory.
    """

    def __init__(self, per_dim=4):
        super().__init__()
        empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
        self.register_buffer('empty_pooled_clip', empty_pooled_clip)
        test_eps = torch.randn([1, 300, per_dim], dtype=torch.bfloat16).to('cpu') * 0.7
        self.register_buffer('test_eps', test_eps)
        self.init_proj = nn.Sequential(nn.Linear(768, 300 * 16), nn.SiLU())
        self.proj = nn.Sequential(nn.Linear(16, 1024), nn.SiLU(), nn.Linear(1024, 1024), nn.SiLU())
        self.text_proj = nn.Sequential(nn.Linear(4096, 1024), nn.SiLU(), nn.Linear(1024, 1024), nn.SiLU())
        self.proj_mu = nn.Sequential(nn.Linear(1024, per_dim))
        self.proj_logvar = nn.Sequential(nn.Linear(1024, per_dim))
        self.eps_proj = nn.Sequential(
            nn.Linear(per_dim, 1024), nn.SiLU(), nn.LayerNorm(1024), nn.Linear(1024, 4096)
        )
        init_weights_gaussian(self, mean=0.0, std=0.02)
        # Zero the last layer so the initial offset is exactly 0.
        torch.nn.init.constant_(self.eps_proj[-1].weight, 0.0)
        torch.nn.init.constant_(self.eps_proj[-1].bias, 0.0)

    @property
    def dtype(self):
        """Nominal dtype; hard-coded to ``torch.bfloat16`` (not read from the parameters)."""
        return torch.bfloat16

    @property
    def device(self):
        """Device of the ``empty_pooled_clip`` buffer."""
        return self.empty_pooled_clip.device

    def forward(self, text_features, image_features=None, eps=None):
        """Return text features shifted by the decoded latent offset.

        Args:
            text_features: ``[b, 300, 4096]`` text embeddings. The 300 comes from
                ``init_proj``'s reshape and the 4096 from ``text_proj``'s input width.
            image_features: Optional ``[b, 768]`` image features. When given, the
                latent is ``mu + exp(0.5 * logvar) * noise`` and ``eps`` is ignored.
            eps: Latent used when ``image_features`` is None; defaults to ``test_eps``.

        Returns:
            ``(prompt_embeds, None, pooled_prompt_embeds, text_ids, aux_info)``,
            where ``aux_info`` has keys ``'mu'``, ``'logvar'`` and ``'eps'``.
        """
        dtype = text_features.dtype
        device = text_features.device

        if image_features is not None:
            visual_hidden = self.proj(self.init_proj(image_features).view(len(image_features), 300, -1))
            text_hidden = self.text_proj(text_features.detach())
            hidden = visual_hidden - text_hidden
            mu = self.proj_mu(hidden)
            logvar = self.proj_logvar(hidden)
            # Reparameterisation trick.
            eps = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        else:
            if eps is None:
                eps = self.test_eps.to(device=device, dtype=dtype)
            mu = torch.zeros_like(eps)
            logvar = torch.zeros_like(eps)

        proj_eps = self.eps_proj(eps)
        prompt_embeds = text_features + proj_eps
        pooled_prompt_embeds = self.empty_pooled_clip.expand(text_features.shape[0], -1)
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
        aux_info = {
            'mu': mu,
            'logvar': logvar,
            'eps': eps,
        }
        return prompt_embeds, None, pooled_prompt_embeds, text_ids, aux_info


if __name__ == '__main__':
    from transformers import AutoProcessor
    from diffusers import FluxPipeline
    import os
    from PIL import Image

    def create_image_grid(images, cols):
        """Paste ``images`` (all the size of the first one) into a ``cols``-wide grid."""
        if not images:
            raise ValueError("create_image_grid needs at least one image")
        cols = max(1, cols)  # callers pass len(images) // 2, which is 0 for one image
        rows = (len(images) + cols - 1) // cols
        w, h = images[0].size
        grid = Image.new('RGB', (cols * w, rows * h))
        for i, img in enumerate(images):
            grid.paste(img, (i % cols * w, i // cols * h))
        return grid

    dim = 4096
    num_heads = 32
    dtype = torch.bfloat16
    model = ConceptAligner().to('cuda').to(dtype)
    x = torch.randn([5, 300, dim]).to('cuda').to(dtype)
    y = torch.randn([5, 300, dim]).to('cuda').to(dtype)
    i = torch.randn([5, 768]).to('cuda').to(dtype)
    y[1] = y[0]
    m = torch.ones([5, 300]).to('cuda').to(dtype)
    m[:3, :128] = 0
    prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(x, i)
    print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)
    print(prompt_embeds.shape, ' ', pooled_prompt_embeds.shape, ' ', text_ids.shape)
    for k in aux_info:
        print(k, aux_info[k].shape, aux_info[k].min(), aux_info[k].max(), aux_info[k].mean())

    from text_encoder import LoraT5Embedder
    from datasets import load_dataset
    dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]')
    item = dataset[0:4]
    another_item = dataset[0:4]
    from diffusers.models.normalization import RMSNorm
    clip_processor = AutoProcessor.from_pretrained("./clip-vit-large-patch14")
    clip_images = clip_processor(images=item['image'], return_tensors="pt").pixel_values.to('cuda:0').to(dtype)
    texts = []
    texts.append("""A heartwarming 3D rendered scene of an elderly farmer and a tiny orange kitten. The farmer, with a gentle smile, walks alongside the kitten in a lush, green garden filled with thriving plants, showcasing a fruitful harvest. The intricate details of the overalls and the farmer's worn, weathered face tell a story of years spent tending to the land. the farmer is wearing a blue shirt""")
    texts.append("""A unique, intricately detailed creature resembling a reptile, possibly a lizard or a gecko. It has a vibrant blue and green scaled body, with large, round, and expressive eyes that are a deep shade of blue. The backdrop is a soft, blurred forest setting, suggesting a serene and mystical ambiance. the creature is wearing a golden crown""")
    texts.append("""Deep in the enchanted forest lives a woman who is the moon fairy. Her long blonde hair shines in the starlight, tangled with her flowers that glow with a soft blue glow. Her eyes are the color of the night and shine with the magic of the night. The fairy wears a dress made of moon petals, woven with threads of moonlight that shine with an iridescent glow, a crown of stars adorns her head, shining with the light of the full moon that illuminates the forest. Her wings are translucent like glass, with a pale glow reminiscent of the glow of the moon. HD, 6K, photo, cinematic, poster""")
    texts.append(
        """In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""")
    text_encoder = LoraT5Embedder(device='cuda').to(dtype)
    text_features, _, _, _, image_features, _ = text_encoder(texts, clip_images)
    print(text_features.shape, image_features.shape, ' >>>>>>>>> text input')
    images = []
    pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
    pipe.to('cuda')
    for txt_feat, img_feat in zip(text_features, image_features):
        prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(txt_feat.unsqueeze(0), img_feat.unsqueeze(0))
        image = pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            height=512,
            width=512,
            guidance_scale=3.5,
            num_inference_steps=20,
            max_sequence_length=512,
            generator=torch.Generator("cuda").manual_seed(1995),
        ).images[0]
        images.append(image)
    aligned_image = create_image_grid(images, cols=len(images) // 2)
    os.makedirs('samples', exist_ok=True)
    aligned_image.save("samples/image%.jpg")
    raise SystemExit

    # NOTE(review): everything below is unreachable (``raise SystemExit`` above)
    # and targets an older model API: ``model.influence_net``,
    # ``aux_info['influence']``, a model call taking temperature/threshold/topk,
    # and ``shared_embeds`` used before assignment. Kept for reference only.
    influence_matrix = aux_info['influence']
    bin_influence_matrix = (influence_matrix > 0.1).float()
    mean_alive = bin_influence_matrix.sum(dim=-1).mean()
    max_alive = bin_influence_matrix.sum(dim=-1).max()
    min_alive = bin_influence_matrix.sum(dim=-1).min()
    max_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).max()
    mean_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).mean()
    min_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).min()
    print(
        f"Mean alive heads per token: {mean_alive:.2f}, Max alive heads per token: {max_alive:.2f}, Min alive heads per token: {min_alive:.2f}")
    print(
        f"Mean alive tokens: {mean_token_alive:.2f}, Max alive tokens: {max_token_alive:.2f}, Min alive tokens: {min_token_alive:.2f}")
    import os

    CHECKPOINT_PATH = 'runs/00393/checkpoint-6000'
    from safetensors.torch import load_file

    # Load adapter (model.safetensors)
    adapter_path = os.path.join(CHECKPOINT_PATH, "model_1.safetensors")
    if os.path.exists(adapter_path):
        adapter_state = load_file(adapter_path)
        model.load_state_dict(adapter_state, strict=True)
        print("Adapter loaded successfully!")
    print(model.influence_net.v_proj.weight, ' <<< weight ')
    print(model.influence_net.v_proj.bias, ' <<< bias ')
    print(model.influence_net.out_proj.weight, ' <<< out weight ')
    print(model.influence_net.out_proj.bias, ' <<< out bias ')
    print(model.influence_net.mask_mlp.linear3.weight, ' <<< gate weight 3 ')
    print(model.influence_net.mask_mlp.linear3.bias, ' <<< gate bias ')
    z = torch.randn([3, num_heads, 300, 4096 // num_heads]).to('cuda').to(dtype)
    gate_values = model.influence_net.mask_mlp(z)
    gate_values = 2 * (torch.sigmoid(gate_values))
    print(gate_values, ' <<< gate values ', gate_values.shape, ' ', torch.mean(gate_values))

    from diffusers import FluxPipeline
    from PIL import Image

    reserved_memory = torch.cuda.memory_reserved(0) / (1024 ** 3)
    print(f"Reserved GPU memory: {reserved_memory:.2f} GB")
    from transformers import T5EncoderModel, T5Tokenizer, CLIPTokenizer, CLIPTextModel
    import torch
    from text_encoder import LoraT5Embedder

    text_encoder = LoraT5Embedder(device='cuda').to(torch.bfloat16)
    texts = []
    texts.append("""A heartwarming 3D rendered scene of an elderly farmer and a tiny orange kitten. The farmer, with a gentle smile, walks alongside the kitten in a lush, green garden filled with thriving plants, showcasing a fruitful harvest. The intricate details of the overalls and the farmer's worn, weathered face tell a story of years spent tending to the land. the farmer is wearing a blue shirt""")
    texts.append("""A unique, intricately detailed creature resembling a reptile, possibly a lizard or a gecko. It has a vibrant blue and green scaled body, with large, round, and expressive eyes that are a deep shade of blue. The backdrop is a soft, blurred forest setting, suggesting a serene and mystical ambiance. the creature is wearing a golden crown""")
    texts.append("""Deep in the enchanted forest lives a woman who is the moon fairy. Her long blonde hair shines in the starlight, tangled with her flowers that glow with a soft blue glow. Her eyes are the color of the night and shine with the magic of the night. The fairy wears a dress made of moon petals, woven with threads of moonlight that shine with an iridescent glow, a crown of stars adorns her head, shining with the light of the full moon that illuminates the forest. Her wings are translucent like glass, with a pale glow reminiscent of the glow of the moon. HD, 6K, photo, cinematic, poster""")
    texts.append(
        """In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""")
    texts.append(
        """In the image, a majestic white horse gallops across a misty meadow at sunrise. Its mane and tail flow freely in the golden light, and the air glows softly with early morning haze. The horse’s body is bare, revealing the natural curve of its muscles and the sheen of its coat. Dew sparkles on the grass beneath its hooves, and the distant trees fade into pale gold mist. The scene conveys freedom, grace, and quiet power.""")
    INDEX = 0
    text = texts[INDEX]
    with torch.no_grad():
        floral_embeds, _, _, _, _, attn_mask = text_encoder(text, )
        print(attn_mask.shape, ' >>>> ', attn_mask)
        print(floral_embeds.shape, shared_embeds.shape, ' >>>> floral ')
        nopad_floral_embeds, nopad_shared_embeds, nopad_attn_mask = text_encoder(text, padding=False)
        print(floral_embeds.shape, shared_embeds.shape, ' >>>> floral ')

    # _,_,_,_,aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False)
    # print(aux_info['meaningful_influence'].shape, ' <<< influence shape ', aux_info['meaningful_influence'][:100],' ',torch.mean(aux_info['meaningful_influence']))
    # floral_embeds, shared_embeds, attn_mask = text_encoder([""], padding='max_length')
    # _,_,_,_,aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False)
    # print(aux_info['meaningful_influence'].shape, ' <<< empty influence shape ', aux_info['meaningful_influence'],' ',torch.mean(aux_info['meaningful_influence']))
    # raise SystemExit

    text2s = []
    text2s.append("""A heartwarming 3D rendered scene of an elderly farmer and a tiny orange kitten. The farmer, with a gentle smile, walks alongside the kitten in a lush, green garden filled with thriving plants, showcasing a fruitful harvest. The intricate details of the overalls and the farmer's worn, weathered face tell a story of years spent tending to the land. the farmer is wearing a red shirt""")
    text2s.append("""A unique, intricately detailed creature resembling a reptile, possibly a lizard or a gecko. It has a vibrant blue and green scaled body, with large, round, and expressive eyes that are a deep shade of blue. The backdrop is a soft, blurred forest setting, suggesting a serene and mystical ambiance. the creature is wearing a floral crown""")
    text2s.append("""Deep in the enchanted forest lives a woman who is the moon fairy. Her long black hair shines in the starlight, tangled with her flowers that glow with a soft blue glow. Her eyes are the color of the night and shine with the magic of the night. The fairy wears a dress made of moon petals, woven with threads of moonlight that shine with an iridescent glow, a crown of stars adorns her head, shining with the light of the full moon that illuminates the forest. Her wings are translucent like glass, with a pale glow reminiscent of the glow of the moon. HD, 6K, photo, cinematic, poster""")
    text2s.append(
        """In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a gray, rainy sky outside, with raindrops streaking down the glass and blurred trees beyond. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and introspective, capturing a cozy moment of quiet observation during a rainy afternoon.""")
    text2s.append(
        """In the image, a majestic white horse gallops across a misty meadow at sunrise. Its mane and tail flow freely in the golden light, and the air glows softly with early morning haze. The horse’s body is adorned with a bright red saddle, contrasting sharply against its white coat. Dew sparkles on the grass beneath its hooves, and the distant trees fade into pale gold mist. The scene conveys freedom, grace, and a striking touch of color that adds visual drama.""")
    text2 = text2s[INDEX]
    with torch.no_grad():
        golden_embeds, shared_embeds, golden_mask = text_encoder(text2, padding='max_length')
        print(golden_embeds.shape, shared_embeds.shape, ' >>>> golden ')
        nopad_golden_embeds, nopad_shared_embeds, nopad_golden_mask = text_encoder(text2, padding=False)
        print(golden_embeds.shape, shared_embeds.shape, ' >>>> golden ')

    batch_encoding = text_encoder.t5_tokenizer(
        text,
        truncation=True,
        max_length=text_encoder.max_length,
        return_tensors="pt",
    )
    input_ids = batch_encoding["input_ids"][0]  # Get the token IDs
    # Convert token IDs back to tokens to see what they are
    tokens_floral = text_encoder.t5_tokenizer.convert_ids_to_tokens(input_ids)

    batch_encoding = text_encoder.t5_tokenizer(
        text2,
        truncation=True,
        max_length=text_encoder.max_length,
        return_tensors="pt",
    )
    input_ids = batch_encoding["input_ids"][0]  # Get the token IDs
    # Convert token IDs back to tokens to see what they are
    tokens_golden = text_encoder.t5_tokenizer.convert_ids_to_tokens(input_ids)

    def find_token_indices(tokens, word):
        """Find all indices where a word or its token appears"""
        indices = []
        # T5 tokenizer might split words or add special characters
        word_token = text_encoder.t5_tokenizer.encode(word, add_special_tokens=False)[0]
        word_token_str = text_encoder.t5_tokenizer.convert_ids_to_tokens([word_token])[0]
        for i, token in enumerate(tokens):
            if token == word_token_str or word.lower() in token.lower():
                indices.append(i)
        return indices

    key1s = ['blue', 'golden', 'blonde', 'clear', 'horse']
    key2s = ['red', 'floral', 'black', 'rainy', 'red']
    # Find indices for "blue"
    blue_indices = find_token_indices(tokens_floral, key1s[INDEX])[-1]
    print(f"Indices for 'blue': {blue_indices}")
    # Find indices for "red" (won't be found in this text)
    red_indices = find_token_indices(tokens_golden, key2s[INDEX])[-1]
    print(f"Indices for 'red': {red_indices}")

    pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
    pipe.to('cuda')
    adapter_path = os.path.join(CHECKPOINT_PATH, "model.safetensors")
    if os.path.exists(adapter_path):
        adapter_state = load_file(adapter_path)
        pipe.transformer.load_state_dict(adapter_state, strict=True)
        print("Transformer loaded successfully!")
    images = []
    empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu').to('cuda').to(torch.bfloat16)
    print("Generating image with concatenation...")
    images = []
    # for cur_prompt_embed in [floral_embeds, nopad_floral_embeds
    #     , inter_embed, golden_embeds, nopad_golden_embeds]:
    # for (start_dim, end_dim) in [(0,4096), (1024,4096), (2048, 4096), (1024, 2048)]:
    for emb in ['floral', 'golden']:
        for temp in [2.5]:
            for thr in [-1, 0.5, 0.75, 0.85, 0.95]:
                for topk in [None]:
                    print('>>>> Temperature: ', temp, topk)
                    if 'floral' in emb:
                        inter_embed, _, _, _, new_aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False, temperature=temp, threshold=thr, topk=topk)
                    else:
                        inter_embed, _, _, _, new_aux_info = model(golden_embeds, shared_embeds, golden_mask, is_training=False, temperature=temp, threshold=thr, topk=topk)
                    print(new_aux_info['influence'][:, blue_indices].shape, ' >>>> influence ', new_aux_info['influence'][:, blue_indices])
                    print(new_aux_info['meaningful_influence'], ' >>>> meaningful influence ', torch.mean(new_aux_info['meaningful_influence']))
                    # inter_embed = torch.clone(floral_embeds)
                    # inter_embed[:, blue_indices] = shared_embeds[:, blue_indices]
                    # inter_embed[:, blue_indices, start_dim:end_dim] = floral_embeds[:, blue_indices, start_dim:end_dim]
                    image = pipe(
                        prompt_embeds=inter_embed,
                        pooled_prompt_embeds=empty_pooled_clip,
                        height=512,
                        width=512,
                        guidance_scale=3.5,
                        num_inference_steps=20,
                        max_sequence_length=512,
                        generator=torch.Generator("cuda").manual_seed(1995),
                    ).images[0]
                    images.append(image)
    aligned_image = create_image_grid(images, cols=len(images) // 2)
    os.makedirs('samples', exist_ok=True)
    aligned_image.save("samples/image%s.jpg" % INDEX)