Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files- aligner.py +861 -0
- app.py +210 -0
- empty_pooled_clip.pt +3 -0
- pipeline.py +641 -0
- requirements.txt +8 -0
- requirements.txt.py +8 -0
- text_encoder.py +1188 -0
aligner.py
ADDED
|
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from refiner import Qwen2Connector
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MultiHeadSelfAttention(nn.Module):
|
| 19 |
+
def __init__(self, embed_dim=2560, num_heads=20):
|
| 20 |
+
super().__init__()
|
| 21 |
+
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
| 22 |
+
|
| 23 |
+
self.embed_dim = embed_dim
|
| 24 |
+
self.num_heads = num_heads
|
| 25 |
+
self.head_dim = embed_dim // num_heads
|
| 26 |
+
|
| 27 |
+
# Linear projections for Q, K, V
|
| 28 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 29 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 30 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 31 |
+
|
| 32 |
+
# Output projection
|
| 33 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
| 34 |
+
|
| 35 |
+
self.scale = self.head_dim ** -0.5
|
| 36 |
+
|
| 37 |
+
def forward(self, x, mask=None, return_attention=True):
|
| 38 |
+
"""
|
| 39 |
+
Args:
|
| 40 |
+
x: Input tensor of shape [b, seq_len, embed_dim]
|
| 41 |
+
mask: Attention mask of shape [b, seq_len], where 1 means attend, 0 means ignore
|
| 42 |
+
return_attention: Whether to return attention weights
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
output: [b, seq_len, embed_dim]
|
| 46 |
+
attn_weights: [b*num_heads, seq_len, seq_len] (if return_attention=True)
|
| 47 |
+
"""
|
| 48 |
+
b, seq_len, embed_dim = x.shape
|
| 49 |
+
|
| 50 |
+
# Project to Q, K, V
|
| 51 |
+
Q = self.q_proj(x) # [b, seq_len, embed_dim]
|
| 52 |
+
K = self.k_proj(x) # [b, seq_len, embed_dim]
|
| 53 |
+
V = self.v_proj(x) # [b, seq_len, embed_dim]
|
| 54 |
+
|
| 55 |
+
# Reshape and transpose for multi-head attention
|
| 56 |
+
# [b, seq_len, embed_dim] -> [b, seq_len, num_heads, head_dim] -> [b, num_heads, seq_len, head_dim]
|
| 57 |
+
Q = Q.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 58 |
+
K = K.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 59 |
+
V = V.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 60 |
+
|
| 61 |
+
# Reshape for batch computation: [b, num_heads, seq_len, head_dim] -> [b*num_heads, seq_len, head_dim]
|
| 62 |
+
Q = Q.reshape(b * self.num_heads, seq_len, self.head_dim)
|
| 63 |
+
K = K.reshape(b * self.num_heads, seq_len, self.head_dim)
|
| 64 |
+
V = V.reshape(b * self.num_heads, seq_len, self.head_dim)
|
| 65 |
+
|
| 66 |
+
# Compute attention scores: Q @ K^T
|
| 67 |
+
attn_scores = torch.bmm(Q, K.transpose(1, 2)) * self.scale # [b*num_heads, seq_len, seq_len]
|
| 68 |
+
|
| 69 |
+
# Apply mask if provided
|
| 70 |
+
if mask is not None:
|
| 71 |
+
# Key mask (column masking): which keys can be attended to
|
| 72 |
+
key_mask = mask.unsqueeze(1).unsqueeze(2) # [b, 1, 1, seq_len]
|
| 73 |
+
|
| 74 |
+
# Query mask (row masking): which queries are valid
|
| 75 |
+
query_mask = mask.unsqueeze(1).unsqueeze(3) # [b, 1, seq_len, 1]
|
| 76 |
+
|
| 77 |
+
# Combine both masks: a position can attend only if BOTH query and key are valid
|
| 78 |
+
# Shape: [b, 1, seq_len, seq_len]
|
| 79 |
+
final_mask = query_mask.bool() & key_mask.bool() # Broadcasting handles the dimensions
|
| 80 |
+
|
| 81 |
+
# Expand to all heads and reshape
|
| 82 |
+
final_mask = final_mask.expand(b, self.num_heads, seq_len, seq_len)
|
| 83 |
+
final_mask = final_mask.reshape(b * self.num_heads, seq_len, seq_len)
|
| 84 |
+
|
| 85 |
+
attn_scores = attn_scores.masked_fill(~final_mask, float('-inf'))
|
| 86 |
+
|
| 87 |
+
# Apply softmax
|
| 88 |
+
attn_weights = F.softmax(attn_scores, dim=-1) # [b*num_heads, seq_len, seq_len]
|
| 89 |
+
|
| 90 |
+
# Handle NaN from softmax (when entire row is -inf)
|
| 91 |
+
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
|
| 92 |
+
|
| 93 |
+
# Apply attention to values
|
| 94 |
+
attn_output = torch.bmm(attn_weights, V) # [b*num_heads, seq_len, head_dim]
|
| 95 |
+
|
| 96 |
+
# Reshape back: [b*num_heads, seq_len, head_dim] -> [b, num_heads, seq_len, head_dim]
|
| 97 |
+
attn_output = attn_output.view(b, self.num_heads, seq_len, self.head_dim)
|
| 98 |
+
|
| 99 |
+
# Transpose and reshape: [b, num_heads, seq_len, head_dim] -> [b, seq_len, num_heads, head_dim] -> [b, seq_len, embed_dim]
|
| 100 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(b, seq_len, embed_dim)
|
| 101 |
+
|
| 102 |
+
# Final output projection
|
| 103 |
+
output = self.out_proj(attn_output) # [b, seq_len, embed_dim]
|
| 104 |
+
|
| 105 |
+
if return_attention:
|
| 106 |
+
return output, attn_weights # attn_weights is [b*num_heads, seq_len, seq_len]
|
| 107 |
+
else:
|
| 108 |
+
return output
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class ConceptAligner222(nn.Module):
|
| 112 |
+
def __init__(self, custom_pool=1, input_dim=2560, hidden_size=2560):
|
| 113 |
+
super().__init__()
|
| 114 |
+
if input_dim == 2560:
|
| 115 |
+
hidden_size = 2560
|
| 116 |
+
self.num_heads = 20
|
| 117 |
+
self.model_class = 'gemma3'
|
| 118 |
+
depth = 2
|
| 119 |
+
identity_mapping = False
|
| 120 |
+
|
| 121 |
+
elif input_dim == 4096:
|
| 122 |
+
hidden_size = 3072
|
| 123 |
+
self.num_heads = 24
|
| 124 |
+
self.model_class = 't5'
|
| 125 |
+
depth = 1
|
| 126 |
+
identity_mapping = True
|
| 127 |
+
|
| 128 |
+
self.text_connector = Qwen2Connector(in_channels=input_dim, hidden_size=hidden_size, heads_num=self.num_heads,
|
| 129 |
+
depth=depth, identity_init=identity_mapping)
|
| 130 |
+
self.final_proj = nn.Sequential(nn.Linear(hidden_size, 4096), nn.SiLU(), nn.Linear(4096, 4096))
|
| 131 |
+
self.resampler = MultiHeadSelfAttention(embed_dim=hidden_size, num_heads=self.num_heads)
|
| 132 |
+
empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
|
| 133 |
+
self.register_buffer('empty_pooled_clip', empty_pooled_clip)
|
| 134 |
+
self.learnable_scale_norm = nn.Parameter(torch.ones([1, 1, 1]) * 0.01, requires_grad=True)
|
| 135 |
+
self.proj_norm = nn.LayerNorm(hidden_size)
|
| 136 |
+
self.custom_pool = custom_pool
|
| 137 |
+
if self.custom_pool:
|
| 138 |
+
self.clip_proj = nn.Sequential(nn.Linear(hidden_size, hidden_size * 3), nn.SiLU(),
|
| 139 |
+
nn.Linear(hidden_size * 3, 768))
|
| 140 |
+
self.clip_norm = nn.LayerNorm(768)
|
| 141 |
+
print('Using custom pooling for CLIP features.')
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def dtype(self):
|
| 145 |
+
"""Return the dtype of the model parameters."""
|
| 146 |
+
# return next(self.parameters()).dtype
|
| 147 |
+
return torch.bfloat16
|
| 148 |
+
|
| 149 |
+
@property
|
| 150 |
+
def device(self):
|
| 151 |
+
"""Return the device of the model parameters."""
|
| 152 |
+
# return next(self.parameters()).device
|
| 153 |
+
return self.empty_pooled_clip.device
|
| 154 |
+
|
| 155 |
+
def forward(self, text_features, text_mask, is_training=False, img_seq_len=1024):
|
| 156 |
+
text_features = self.text_connector(text_features, mask=text_mask,
|
| 157 |
+
mean_start_id=2 if self.model_class == 'gemma' else 0)
|
| 158 |
+
text_features = self.proj_norm(text_features)
|
| 159 |
+
aligned_features, attn = self.resampler(text_features, mask=text_mask, return_attention=True)
|
| 160 |
+
if is_training:
|
| 161 |
+
learnable_scale = torch.clip(self.learnable_scale_norm, -1.0, 1.0)
|
| 162 |
+
visual_concepts = aligned_features + learnable_scale * torch.randn_like(aligned_features)
|
| 163 |
+
else:
|
| 164 |
+
visual_concepts = aligned_features
|
| 165 |
+
prompt_embeds = self.final_proj(visual_concepts)
|
| 166 |
+
# prompt_embeds = text_features
|
| 167 |
+
if self.custom_pool:
|
| 168 |
+
mean_features = (aligned_features * text_mask.unsqueeze(-1)).sum(dim=1) / (
|
| 169 |
+
text_mask.sum(dim=1, keepdim=True) + 1e-8)
|
| 170 |
+
pooled_prompt_embeds = self.clip_proj(mean_features)
|
| 171 |
+
pooled_prompt_embeds = self.clip_norm(pooled_prompt_embeds)
|
| 172 |
+
else:
|
| 173 |
+
pooled_prompt_embeds = self.empty_pooled_clip.expand(text_features.shape[0], -1)
|
| 174 |
+
dtype = prompt_embeds.dtype
|
| 175 |
+
device = prompt_embeds.device
|
| 176 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 177 |
+
|
| 178 |
+
total_seq_len = img_seq_len + prompt_embeds.shape[1]
|
| 179 |
+
text_seq_len = text_mask.shape[1]
|
| 180 |
+
attention_mask = torch.zeros(
|
| 181 |
+
len(text_features), 1, 1, total_seq_len,
|
| 182 |
+
device=text_mask.device,
|
| 183 |
+
dtype=text_mask.dtype
|
| 184 |
+
)
|
| 185 |
+
# Fill in text portion: where text_mask==0, set to -inf
|
| 186 |
+
attention_mask[:, :, :, :text_seq_len] = (1 - text_mask).unsqueeze(1).unsqueeze(2) * -10000.0
|
| 187 |
+
|
| 188 |
+
entropy = -(attn * torch.log(attn + 1e-8)).sum(dim=-1)
|
| 189 |
+
mask_expanded = text_mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
| 190 |
+
mask_expanded = mask_expanded.reshape(len(text_features) * self.num_heads, text_seq_len)
|
| 191 |
+
valid_entropy = entropy[mask_expanded.bool()]
|
| 192 |
+
|
| 193 |
+
return prompt_embeds, attention_mask, pooled_prompt_embeds, text_ids, valid_entropy
|
| 194 |
+
# return prompt_embeds, pooled_prompt_embeds, text_ids, None
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
import torch
|
| 198 |
+
import torch.nn as nn
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class RMSNorm(nn.Module):
|
| 202 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 203 |
+
super().__init__()
|
| 204 |
+
self.eps = eps
|
| 205 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 206 |
+
|
| 207 |
+
def _norm(self, x):
|
| 208 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 209 |
+
|
| 210 |
+
def forward(self, x):
|
| 211 |
+
output = self._norm(x.float()).type_as(x)
|
| 212 |
+
return output * self.weight
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class AdaLayerNorm(nn.Module):
|
| 216 |
+
def __init__(self, embedding_dim: int, time_embedding_dim=4096):
|
| 217 |
+
super().__init__()
|
| 218 |
+
|
| 219 |
+
if time_embedding_dim is None:
|
| 220 |
+
time_embedding_dim = embedding_dim
|
| 221 |
+
|
| 222 |
+
self.silu = nn.SiLU()
|
| 223 |
+
self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
|
| 224 |
+
nn.init.normal_(self.linear.weight, mean=0, std=0.02)
|
| 225 |
+
nn.init.zeros_(self.linear.bias)
|
| 226 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
| 227 |
+
|
| 228 |
+
def forward(
|
| 229 |
+
self, x: torch.Tensor, timestep_embedding: torch.Tensor
|
| 230 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 231 |
+
emb = self.linear(self.silu(timestep_embedding))
|
| 232 |
+
shift, scale = emb.unsqueeze(1).chunk(2, dim=-1)
|
| 233 |
+
x = self.norm(x) * (1 + scale) + shift
|
| 234 |
+
return x
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class GateMLP(nn.Module):
|
| 238 |
+
def __init__(self, gate_mode='soft', input_dim=64, hidden_dim=1024):
|
| 239 |
+
super().__init__()
|
| 240 |
+
self.gate_mode = gate_mode
|
| 241 |
+
hidden_dim = max(input_dim, min(hidden_dim, 512))
|
| 242 |
+
hidden_dim = 512
|
| 243 |
+
self.input_norm = nn.LayerNorm(4096)
|
| 244 |
+
|
| 245 |
+
self.norm0 = nn.LayerNorm(input_dim)
|
| 246 |
+
self.linear1 = nn.Linear(input_dim, hidden_dim)
|
| 247 |
+
self.activation1 = nn.GELU()
|
| 248 |
+
self.linear2 = nn.Linear(hidden_dim+4096, hidden_dim)
|
| 249 |
+
self.activation2 = nn.GELU()
|
| 250 |
+
self.linear3 = nn.Linear(hidden_dim+4096, hidden_dim)
|
| 251 |
+
self.activation3 = nn.GELU()
|
| 252 |
+
self.final_linear = nn.Linear(hidden_dim, 1)
|
| 253 |
+
|
| 254 |
+
nn.init.xavier_uniform_(self.linear1.weight)
|
| 255 |
+
nn.init.zeros_(self.linear1.bias)
|
| 256 |
+
|
| 257 |
+
nn.init.xavier_uniform_(self.linear2.weight)
|
| 258 |
+
nn.init.zeros_(self.linear2.bias)
|
| 259 |
+
|
| 260 |
+
nn.init.xavier_uniform_(self.linear3.weight)
|
| 261 |
+
nn.init.zeros_(self.linear3.bias)
|
| 262 |
+
|
| 263 |
+
nn.init.zeros_(self.final_linear.weight)
|
| 264 |
+
bias_val = 0.0 if 'soft' in gate_mode else 1.0
|
| 265 |
+
nn.init.constant_(self.final_linear.bias, bias_val)
|
| 266 |
+
|
| 267 |
+
def forward(self, x):
|
| 268 |
+
y = x.transpose(1, 2).flatten(2)
|
| 269 |
+
y = self.input_norm(y.detach()).unsqueeze(1).repeat(1, x.shape[1],1,1)
|
| 270 |
+
x = self.linear1(self.norm0(x.detach()))
|
| 271 |
+
x = self.activation1(x)
|
| 272 |
+
x = self.linear2(torch.cat([x, y], dim=-1))
|
| 273 |
+
x = self.activation2(x)
|
| 274 |
+
x = self.linear3(torch.cat([x,y], dim=-1))
|
| 275 |
+
x = self.activation3(x)
|
| 276 |
+
x = self.final_linear(x)
|
| 277 |
+
return x
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class CrossAttentionWithInfluence(nn.Module):
|
| 281 |
+
def __init__(self, d_model=4096, num_heads=32, gate_mode='hard'):
|
| 282 |
+
super().__init__()
|
| 283 |
+
self.d_model = d_model
|
| 284 |
+
self.num_heads = num_heads
|
| 285 |
+
self.head_dim = d_model // num_heads
|
| 286 |
+
self.gate_mode = gate_mode
|
| 287 |
+
|
| 288 |
+
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
| 289 |
+
|
| 290 |
+
# Linear projections for Q, K, V
|
| 291 |
+
# self.q_proj = nn.Linear(d_model, d_model)
|
| 292 |
+
# self.k_proj = nn.Linear(d_model, d_model)
|
| 293 |
+
self.v_proj = nn.Linear(d_model, d_model)
|
| 294 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 295 |
+
|
| 296 |
+
# nn.init.normal_(self.q_proj.weight, mean=0, std=0.02)
|
| 297 |
+
# nn.init.normal_(self.k_proj.weight, mean=0, std=0.02)
|
| 298 |
+
# nn.init.zeros_(self.q_proj.bias)
|
| 299 |
+
# nn.init.zeros_(self.k_proj.bias)
|
| 300 |
+
nn.init.eye_(self.out_proj.weight)
|
| 301 |
+
nn.init.zeros_(self.out_proj.bias)
|
| 302 |
+
nn.init.eye_(self.v_proj.weight)
|
| 303 |
+
nn.init.zeros_(self.v_proj.bias)
|
| 304 |
+
|
| 305 |
+
self.mask_mlp = GateMLP(input_dim=d_model // num_heads, hidden_dim=1024, gate_mode=gate_mode)
|
| 306 |
+
|
| 307 |
+
self.scale = self.head_dim ** -0.5
|
| 308 |
+
# self.learnable_scale_norm = nn.Parameter(torch.ones([1, 1,1,1])*0.01, requires_grad=True)
|
| 309 |
+
|
| 310 |
+
self.rec_mlp = nn.Sequential(nn.Linear(4096, 4096), nn.SiLU(),
|
| 311 |
+
nn.Linear(4096, 4096), nn.SiLU(),
|
| 312 |
+
nn.Linear(4096, 4096)
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
def forward(self, x, y, y_mask, temperature=None, threshold=None, topk=None):
|
| 316 |
+
"""
|
| 317 |
+
Args:
|
| 318 |
+
x: shared embedding [b, 300, 4096]
|
| 319 |
+
y: changing embedding [b, 300, 4096]
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
output: [b, 300, 4096]
|
| 323 |
+
y_influence: [b, 32, 300, 300] - influence from y to x
|
| 324 |
+
"""
|
| 325 |
+
b, seq_len_x, d_model = x.shape
|
| 326 |
+
b, seq_len_y, d_model_y = y.shape
|
| 327 |
+
|
| 328 |
+
"""
|
| 329 |
+
# Q from x only
|
| 330 |
+
Q = self.q_proj(x) # [b, 300, 4096]
|
| 331 |
+
seq_len = Q.shape[1]
|
| 332 |
+
|
| 333 |
+
# K, V from concatenation of [x, y]
|
| 334 |
+
K = self.k_proj(x) # [b, 300, 4096]
|
| 335 |
+
# Reshape for multi-head attention
|
| 336 |
+
Q = Q.view(b, Q.shape[1], self.num_heads, self.head_dim).transpose(1, 2) # [b, 32, 300, 128]
|
| 337 |
+
K = K.view(b, K.shape[1], self.num_heads, self.head_dim).transpose(1, 2) # [b, 32, 600, 128]
|
| 338 |
+
|
| 339 |
+
"""
|
| 340 |
+
V = self.v_proj(y) # [b, 300, 4096]
|
| 341 |
+
shared_V = self.v_proj(x) # [b, 300, 4096]
|
| 342 |
+
|
| 343 |
+
textual_concepts = V.view(b, V.shape[1], self.num_heads, self.head_dim).transpose(1, 2) # [b, 32, 300, 128]
|
| 344 |
+
shared_concepts = shared_V.view(b, shared_V.shape[1], self.num_heads, self.head_dim).transpose(1,
|
| 345 |
+
2) # [b, 32, 300, 128]
|
| 346 |
+
expand_y_mask = y_mask.unsqueeze(1).unsqueeze(-1) # [b, 1, 300, 1]
|
| 347 |
+
# Compute attention scores
|
| 348 |
+
"""
|
| 349 |
+
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [b, 32, 300, 300]
|
| 350 |
+
attn_weights = F.softmax(attn_scores, dim=-1) # [b, 32, 300, 300]
|
| 351 |
+
|
| 352 |
+
# Compute output
|
| 353 |
+
attn_output = torch.matmul(attn_weights, textual_concepts) # [b, 32, 300, 128]
|
| 354 |
+
"""
|
| 355 |
+
|
| 356 |
+
diagonal_influence = self.mask_mlp((textual_concepts))
|
| 357 |
+
if 'soft' in self.gate_mode:
|
| 358 |
+
diagonal_influence = 2 * (torch.sigmoid(diagonal_influence * temperature)) # [b, 32, 300, 1]
|
| 359 |
+
diagonal_influence = (diagonal_influence > 0.1).to(
|
| 360 |
+
diagonal_influence.dtype) * diagonal_influence # Thresholding
|
| 361 |
+
soft_influence = diagonal_influence
|
| 362 |
+
else:
|
| 363 |
+
soft_influence = torch.sigmoid(diagonal_influence * temperature)
|
| 364 |
+
if threshold is None:
|
| 365 |
+
threshold = 0.5
|
| 366 |
+
else:
|
| 367 |
+
print('Using custom threshold for influence gating:', threshold)
|
| 368 |
+
hard_influence = (soft_influence >= threshold)
|
| 369 |
+
diagonal_influence = hard_influence + soft_influence - soft_influence.detach() # Straight-through estimator
|
| 370 |
+
|
| 371 |
+
if topk is not None:
|
| 372 |
+
print(diagonal_influence.shape, ' <<< shape before topk ')
|
| 373 |
+
top_k_values, top_k_indices = torch.topk(diagonal_influence, topk, dim=1)
|
| 374 |
+
result = torch.zeros_like(diagonal_influence)
|
| 375 |
+
result.scatter_(1, top_k_indices, top_k_values)
|
| 376 |
+
diagonal_influence = result
|
| 377 |
+
print('Applied top-k sparsification on influence gates with k=', topk)
|
| 378 |
+
|
| 379 |
+
diagonal_output = textual_concepts * diagonal_influence + shared_concepts * (
|
| 380 |
+
1 - diagonal_influence) # [b, 32, 300, 128]
|
| 381 |
+
da,db,dc,dd = diagonal_output.shape
|
| 382 |
+
rec_diagonal = self.rec_mlp(diagonal_output.transpose(1,2).flatten(2)[y_mask.bool()].to(x.dtype))
|
| 383 |
+
tgt_diagonal = y[y_mask.bool()]
|
| 384 |
+
|
| 385 |
+
diagonal_output = expand_y_mask * diagonal_output + (1 - expand_y_mask) * shared_concepts # [b, 32, 300, 128]
|
| 386 |
+
|
| 387 |
+
mask_bool_expanded = expand_y_mask.expand_as(diagonal_influence).bool() # [b, 32, 300, 1]
|
| 388 |
+
meaningful_gates = diagonal_influence[mask_bool_expanded]
|
| 389 |
+
soft_meaningful_gate = soft_influence[mask_bool_expanded]
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# full_output = self.learnable_scale_norm*attn_output + diagonal_output # [b, 32, 300, 128]
|
| 393 |
+
full_output = diagonal_output.to(x.dtype)
|
| 394 |
+
|
| 395 |
+
# Reshape back
|
| 396 |
+
full_output = full_output.transpose(1, 2).contiguous().view(b, y.shape[1], d_model) # [b, 300, 4096]
|
| 397 |
+
full_output = full_output # Residual connection
|
| 398 |
+
|
| 399 |
+
# Final output projection
|
| 400 |
+
output = self.out_proj(full_output) # [b, 300, 4096]
|
| 401 |
+
|
| 402 |
+
return output, diagonal_influence.squeeze(-1).transpose(1, 2), meaningful_gates, soft_meaningful_gate, rec_diagonal, tgt_diagonal
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def init_weights_gaussian(model, mean=0.0, std=0.02):
|
| 411 |
+
"""
|
| 412 |
+
Initialize all nn.Linear layers in the model:
|
| 413 |
+
- weights with Gaussian(mean, std)
|
| 414 |
+
- biases to 0
|
| 415 |
+
"""
|
| 416 |
+
for m in model.modules():
|
| 417 |
+
if isinstance(m, nn.Linear):
|
| 418 |
+
nn.init.normal_(m.weight, mean=mean, std=std)
|
| 419 |
+
if m.bias is not None:
|
| 420 |
+
nn.init.constant_(m.bias, 0.0)
|
| 421 |
+
|
| 422 |
+
class ConceptAligner(nn.Module):
|
| 423 |
+
def __init__(self, per_dim=4):
|
| 424 |
+
super().__init__()
|
| 425 |
+
empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu')
|
| 426 |
+
self.register_buffer('empty_pooled_clip', empty_pooled_clip)
|
| 427 |
+
|
| 428 |
+
test_eps = torch.randn([1, 300, per_dim], dtype=torch.bfloat16).to('cpu')*0.7
|
| 429 |
+
self.register_buffer('test_eps', test_eps)
|
| 430 |
+
|
| 431 |
+
self.init_proj = nn.Sequential(nn.Linear(768, 300*16), nn.SiLU())
|
| 432 |
+
self.proj = nn.Sequential(nn.Linear(16, 1024), nn.SiLU(),
|
| 433 |
+
nn.Linear(1024, 1024), nn.SiLU())
|
| 434 |
+
self.text_proj = nn.Sequential(nn.Linear(4096, 1024), nn.SiLU(),
|
| 435 |
+
nn.Linear(1024, 1024), nn.SiLU())
|
| 436 |
+
self.proj_mu = nn.Sequential(nn.Linear(1024, per_dim))
|
| 437 |
+
self.proj_logvar = nn.Sequential(nn.Linear(1024, per_dim))
|
| 438 |
+
|
| 439 |
+
self.eps_proj = nn.Sequential(nn.Linear(per_dim, 1024), nn.SiLU(),
|
| 440 |
+
nn.LayerNorm(1024),
|
| 441 |
+
nn.Linear(1024, 4096))
|
| 442 |
+
|
| 443 |
+
init_weights_gaussian(self, mean=0.0, std=0.02)
|
| 444 |
+
torch.nn.init.constant_(self.eps_proj[-1].weight, 0.0)
|
| 445 |
+
torch.nn.init.constant_(self.eps_proj[-1].bias, 0.0)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
@property
|
| 449 |
+
def dtype(self):
|
| 450 |
+
"""Return the dtype of the model parameters."""
|
| 451 |
+
# return next(self.parameters()).dtype
|
| 452 |
+
return torch.bfloat16
|
| 453 |
+
|
| 454 |
+
@property
|
| 455 |
+
def device(self):
|
| 456 |
+
"""Return the device of the model parameters."""
|
| 457 |
+
# return next(self.parameters()).device
|
| 458 |
+
return self.empty_pooled_clip.device
|
| 459 |
+
|
| 460 |
+
def forward(self, text_features, image_features=None, eps=None):
|
| 461 |
+
|
| 462 |
+
#return text_features, None, self.empty_pooled_clip.expand(text_features.shape[0], -1), torch.zeros(text_features.shape[1], 3).to(device=text_features.device, dtype=text_features.dtype), {'mu': torch.zeros([1,300,1], device=text_features.device, dtype=text_features.dtype), 'logvar': torch.zeros([1,300,1], device=text_features.device, dtype=text_features.dtype)}
|
| 463 |
+
|
| 464 |
+
dtype = text_features.dtype
|
| 465 |
+
device = text_features.device
|
| 466 |
+
|
| 467 |
+
if image_features is not None:
|
| 468 |
+
visual_hidden = self.proj(self.init_proj(image_features).view(len(image_features), 300, -1))
|
| 469 |
+
text_hidden = self.text_proj(text_features.detach())
|
| 470 |
+
hidden = visual_hidden - text_hidden
|
| 471 |
+
mu = self.proj_mu(hidden)
|
| 472 |
+
logvar = self.proj_logvar(hidden)
|
| 473 |
+
eps = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
|
| 474 |
+
else:
|
| 475 |
+
if eps is None:
|
| 476 |
+
eps = self.test_eps.to(device=device, dtype=dtype)
|
| 477 |
+
mu = torch.zeros_like(eps)
|
| 478 |
+
logvar = torch.zeros_like(eps)
|
| 479 |
+
|
| 480 |
+
proj_eps = self.eps_proj(eps)
|
| 481 |
+
prompt_embeds = text_features + proj_eps
|
| 482 |
+
pooled_prompt_embeds = self.empty_pooled_clip.expand(text_features.shape[0], -1)
|
| 483 |
+
|
| 484 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 485 |
+
aux_info = {
|
| 486 |
+
'mu': mu,
|
| 487 |
+
'logvar': logvar,
|
| 488 |
+
'eps': eps
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
return prompt_embeds, None, pooled_prompt_embeds, text_ids, aux_info
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
if __name__ == '__main__':
|
| 499 |
+
from transformers import AutoProcessor
|
| 500 |
+
from diffusers import FluxPipeline
|
| 501 |
+
import os
|
| 502 |
+
from PIL import Image
|
| 503 |
+
def create_image_grid(images, cols):
|
| 504 |
+
rows = (len(images) + cols - 1) // cols
|
| 505 |
+
w, h = images[0].size
|
| 506 |
+
grid = Image.new('RGB', (cols * w, rows * h))
|
| 507 |
+
for i, img in enumerate(images):
|
| 508 |
+
grid.paste(img, (i % cols * w, i // cols * h))
|
| 509 |
+
return grid
|
| 510 |
+
|
| 511 |
+
dim = 4096
|
| 512 |
+
num_heads = 32
|
| 513 |
+
dtype = torch.bfloat16
|
| 514 |
+
model = ConceptAligner().to('cuda').to(dtype)
|
| 515 |
+
x = torch.randn([5, 300, dim]).to('cuda').to(dtype)
|
| 516 |
+
y = torch.randn([5, 300, dim]).to('cuda').to(dtype)
|
| 517 |
+
i = torch.randn([5,768]).to('cuda').to(dtype)
|
| 518 |
+
y[1] = y[0]
|
| 519 |
+
m = torch.ones([5, 300]).to('cuda').to(dtype)
|
| 520 |
+
m[:3,:128] = 0
|
| 521 |
+
prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(x, i)
|
| 522 |
+
print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)
|
| 523 |
+
print(prompt_embeds.shape, ' ', pooled_prompt_embeds.shape, ' ', text_ids.shape)
|
| 524 |
+
for k in aux_info:
|
| 525 |
+
print(k, aux_info[k].shape, aux_info[k].min(), aux_info[k].max(), aux_info[k].mean())
|
| 526 |
+
|
| 527 |
+
from text_encoder import LoraT5Embedder
|
| 528 |
+
from datasets import load_dataset
|
| 529 |
+
dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]')
|
| 530 |
+
item = dataset[0:4]
|
| 531 |
+
another_item = dataset[0:4]
|
| 532 |
+
from diffusers.models.normalization import RMSNorm
|
| 533 |
+
clip_processor = AutoProcessor.from_pretrained("./clip-vit-large-patch14")
|
| 534 |
+
clip_images = clip_processor(images=item['image'], return_tensors="pt").pixel_values.to('cuda:0').to(dtype)
|
| 535 |
+
texts = []
|
| 536 |
+
texts.append("""A heartwarming 3D rendered scene of
|
| 537 |
+
an elderly farmer and a tiny orange
|
| 538 |
+
kitten. The farmer, with a gentle smile,
|
| 539 |
+
walks alongside the kitten in a lush,
|
| 540 |
+
green garden filled with thriving plants,
|
| 541 |
+
showcasing a fruitful harvest. The
|
| 542 |
+
intricate details of the overalls and the
|
| 543 |
+
farmer's worn, weathered face tell a
|
| 544 |
+
story of years spent tending to the land. the farmer is wearing a blue shirt""")
|
| 545 |
+
texts.append("""A unique, intricately detailed creature
|
| 546 |
+
resembling a reptile, possibly a lizard or
|
| 547 |
+
a gecko. It has a vibrant blue and green
|
| 548 |
+
scaled body, with large, round, and
|
| 549 |
+
expressive eyes that are a deep shade of
|
| 550 |
+
blue. The backdrop is a
|
| 551 |
+
soft, blurred forest setting, suggesting a
|
| 552 |
+
serene and mystical ambiance. the creature is wearing a golden crown""")
|
| 553 |
+
texts.append("""Deep in the enchanted forest lives a woman
|
| 554 |
+
who is the moon fairy. Her long blonde hair
|
| 555 |
+
shines in the starlight, tangled with her flowers
|
| 556 |
+
that glow with a soft blue glow. Her eyes are
|
| 557 |
+
the color of the night and shine with the magic
|
| 558 |
+
of the night. The fairy wears a dress made of
|
| 559 |
+
moon petals, woven with threads of moonlight
|
| 560 |
+
that shine with an iridescent glow, a crown of
|
| 561 |
+
stars adorns her head, shining with the light of
|
| 562 |
+
the full moon that illuminates the forest. Her
|
| 563 |
+
wings are translucent like glass, with a pale
|
| 564 |
+
glow reminiscent of the glow of the moon. HD,
|
| 565 |
+
6K, photo, cinematic, poster""")
|
| 566 |
+
|
| 567 |
+
texts.append(
|
| 568 |
+
"""In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""")
|
| 569 |
+
|
| 570 |
+
text_encoder = LoraT5Embedder(device='cuda').to(dtype)
|
| 571 |
+
text_features, _, _, _, image_features, _ = text_encoder(texts, clip_images)
|
| 572 |
+
print(text_features.shape, image_features.shape, ' >>>>>>>>> text input')
|
| 573 |
+
images = []
|
| 574 |
+
pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
|
| 575 |
+
pipe.to('cuda')
|
| 576 |
+
|
| 577 |
+
for txt_feat, img_feat in zip(text_features, image_features):
|
| 578 |
+
|
| 579 |
+
prompt_embeds, _, pooled_prompt_embeds, text_ids, aux_info = model(txt_feat.unsqueeze(0), img_feat.unsqueeze(0))
|
| 580 |
+
image = pipe(
|
| 581 |
+
prompt_embeds=prompt_embeds,
|
| 582 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 583 |
+
height=512,
|
| 584 |
+
width=512,
|
| 585 |
+
guidance_scale=3.5,
|
| 586 |
+
num_inference_steps=20,
|
| 587 |
+
max_sequence_length=512,
|
| 588 |
+
generator=torch.Generator("cuda").manual_seed(1995),
|
| 589 |
+
).images[0]
|
| 590 |
+
images.append(image)
|
| 591 |
+
|
| 592 |
+
aligned_image = create_image_grid(images, cols=len(images) // 2)
|
| 593 |
+
os.makedirs('samples', exist_ok=True)
|
| 594 |
+
aligned_image.save("samples/image%.jpg")
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
raise SystemExit
|
| 598 |
+
|
| 599 |
+
influence_matrix = aux_info['influence']
|
| 600 |
+
bin_influence_matrix = (influence_matrix > 0.1).float()
|
| 601 |
+
mean_alive = bin_influence_matrix.sum(dim=-1).mean()
|
| 602 |
+
max_alive = bin_influence_matrix.sum(dim=-1).max()
|
| 603 |
+
min_alive = bin_influence_matrix.sum(dim=-1).min()
|
| 604 |
+
max_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).max()
|
| 605 |
+
mean_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).mean()
|
| 606 |
+
min_token_alive = ((bin_influence_matrix.sum(dim=-1) > 0).float().sum(dim=-1)).min()
|
| 607 |
+
|
| 608 |
+
print(
|
| 609 |
+
f"Mean alive heads per token: {mean_alive:.2f}, Max alive heads per token: {max_alive:.2f}, Min alive heads per token: {min_alive:.2f}")
|
| 610 |
+
print(
|
| 611 |
+
f"Mean alive tokens: {mean_token_alive:.2f}, Max alive tokens: {max_token_alive:.2f}, Min alive tokens: {min_token_alive:.2f}")
|
| 612 |
+
|
| 613 |
+
import os
|
| 614 |
+
|
| 615 |
+
CHECKPOINT_PATH = 'runs/00393/checkpoint-6000'
|
| 616 |
+
from safetensors.torch import load_file
|
| 617 |
+
|
| 618 |
+
# Load adapter (model.safetensors)
|
| 619 |
+
adapter_path = os.path.join(CHECKPOINT_PATH, "model_1.safetensors")
|
| 620 |
+
if os.path.exists(adapter_path):
|
| 621 |
+
adapter_state = load_file(adapter_path)
|
| 622 |
+
model.load_state_dict(adapter_state, strict=True)
|
| 623 |
+
print("Adapter loaded successfully!")
|
| 624 |
+
|
| 625 |
+
print(model.influence_net.v_proj.weight, ' <<< weight ')
|
| 626 |
+
print(model.influence_net.v_proj.bias, ' <<< bias ')
|
| 627 |
+
print(model.influence_net.out_proj.weight, ' <<< out weight ')
|
| 628 |
+
print(model.influence_net.out_proj.bias, ' <<< out bias ')
|
| 629 |
+
print(model.influence_net.mask_mlp.linear3.weight, ' <<< gate weight 3 ')
|
| 630 |
+
print(model.influence_net.mask_mlp.linear3.bias, ' <<< gate bias ')
|
| 631 |
+
|
| 632 |
+
z = torch.randn([3, num_heads, 300, 4096 // num_heads]).to('cuda').to(dtype)
|
| 633 |
+
gate_values = model.influence_net.mask_mlp(z)
|
| 634 |
+
gate_values = 2 * (torch.sigmoid(gate_values))
|
| 635 |
+
|
| 636 |
+
print(gate_values, ' <<< gate values ', gate_values.shape, ' ', torch.mean(gate_values))
|
| 637 |
+
|
| 638 |
+
from diffusers import FluxPipeline
|
| 639 |
+
from PIL import Image
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
reserved_memory = torch.cuda.memory_reserved(0) / (1024 ** 3)
|
| 645 |
+
print(f"Reserved GPU memory: {reserved_memory:.2f} GB")
|
| 646 |
+
|
| 647 |
+
from transformers import T5EncoderModel, T5Tokenizer, CLIPTokenizer, CLIPTextModel
|
| 648 |
+
import torch
|
| 649 |
+
from text_encoder import LoraT5Embedder
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
text_encoder = LoraT5Embedder(device='cuda').to(torch.bfloat16)
|
| 653 |
+
texts = []
|
| 654 |
+
texts.append("""A heartwarming 3D rendered scene of
|
| 655 |
+
an elderly farmer and a tiny orange
|
| 656 |
+
kitten. The farmer, with a gentle smile,
|
| 657 |
+
walks alongside the kitten in a lush,
|
| 658 |
+
green garden filled with thriving plants,
|
| 659 |
+
showcasing a fruitful harvest. The
|
| 660 |
+
intricate details of the overalls and the
|
| 661 |
+
farmer's worn, weathered face tell a
|
| 662 |
+
story of years spent tending to the land. the farmer is wearing a blue shirt""")
|
| 663 |
+
texts.append("""A unique, intricately detailed creature
|
| 664 |
+
resembling a reptile, possibly a lizard or
|
| 665 |
+
a gecko. It has a vibrant blue and green
|
| 666 |
+
scaled body, with large, round, and
|
| 667 |
+
expressive eyes that are a deep shade of
|
| 668 |
+
blue. The backdrop is a
|
| 669 |
+
soft, blurred forest setting, suggesting a
|
| 670 |
+
serene and mystical ambiance. the creature is wearing a golden crown""")
|
| 671 |
+
texts.append("""Deep in the enchanted forest lives a woman
|
| 672 |
+
who is the moon fairy. Her long blonde hair
|
| 673 |
+
shines in the starlight, tangled with her flowers
|
| 674 |
+
that glow with a soft blue glow. Her eyes are
|
| 675 |
+
the color of the night and shine with the magic
|
| 676 |
+
of the night. The fairy wears a dress made of
|
| 677 |
+
moon petals, woven with threads of moonlight
|
| 678 |
+
that shine with an iridescent glow, a crown of
|
| 679 |
+
stars adorns her head, shining with the light of
|
| 680 |
+
the full moon that illuminates the forest. Her
|
| 681 |
+
wings are translucent like glass, with a pale
|
| 682 |
+
glow reminiscent of the glow of the moon. HD,
|
| 683 |
+
6K, photo, cinematic, poster""")
|
| 684 |
+
|
| 685 |
+
texts.append(
|
| 686 |
+
"""In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a clear blue sky outside, with the silhouettes of trees swaying gently in the distance. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and homey, capturing a tranquil afternoon moment of quiet observation.""")
|
| 687 |
+
texts.append(
|
| 688 |
+
"""In the image, a majestic white horse gallops across a misty meadow at sunrise. Its mane and tail flow freely in the golden light, and the air glows softly with early morning haze. The horse’s body is bare, revealing the natural curve of its muscles and the sheen of its coat. Dew sparkles on the grass beneath its hooves, and the distant trees fade into pale gold mist. The scene conveys freedom, grace, and quiet power.""")
|
| 689 |
+
INDEX = 0
|
| 690 |
+
text = texts[INDEX]
|
| 691 |
+
with torch.no_grad():
|
| 692 |
+
floral_embeds, _,_,_,_,attn_mask = text_encoder(text, )
|
| 693 |
+
print(attn_mask.shape, ' >>>> ', attn_mask)
|
| 694 |
+
print(floral_embeds.shape, shared_embeds.shape, ' >>>> floral ')
|
| 695 |
+
nopad_floral_embeds, nopad_shared_embeds, nopad_attn_mask = text_encoder(text, padding=False)
|
| 696 |
+
print(floral_embeds.shape, shared_embeds.shape, ' >>>> floral ')
|
| 697 |
+
|
| 698 |
+
"""
|
| 699 |
+
_,_,_,_,aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False)
|
| 700 |
+
print(aux_info['meaningful_influence'].shape, ' <<< influence shape ', aux_info['meaningful_influence'][:100],' ',torch.mean(aux_info['meaningful_influence']))
|
| 701 |
+
floral_embeds, shared_embeds, attn_mask = text_encoder([""], padding='max_length')
|
| 702 |
+
_,_,_,_,aux_info = model(floral_embeds, shared_embeds, attn_mask, is_training=False)
|
| 703 |
+
print(aux_info['meaningful_influence'].shape, ' <<< empty influence shape ', aux_info['meaningful_influence'],' ',torch.mean(aux_info['meaningful_influence']))
|
| 704 |
+
raise SystemExit
|
| 705 |
+
"""
|
| 706 |
+
|
| 707 |
+
text2s = []
|
| 708 |
+
text2s.append("""A heartwarming 3D rendered scene of
|
| 709 |
+
an elderly farmer and a tiny orange
|
| 710 |
+
kitten. The farmer, with a gentle smile,
|
| 711 |
+
walks alongside the kitten in a lush,
|
| 712 |
+
green garden filled with thriving plants,
|
| 713 |
+
showcasing a fruitful harvest. The
|
| 714 |
+
intricate details of the overalls and the
|
| 715 |
+
farmer's worn, weathered face tell a
|
| 716 |
+
story of years spent tending to the land. the farmer is wearing a red shirt""")
|
| 717 |
+
text2s.append("""A unique, intricately detailed creature
|
| 718 |
+
resembling a reptile, possibly a lizard or
|
| 719 |
+
a gecko. It has a vibrant blue and green
|
| 720 |
+
scaled body, with large, round, and
|
| 721 |
+
expressive eyes that are a deep shade of
|
| 722 |
+
blue. The backdrop is a
|
| 723 |
+
soft, blurred forest setting, suggesting a
|
| 724 |
+
serene and mystical ambiance. the creature is wearing a floral crown""")
|
| 725 |
+
text2s.append("""Deep in the enchanted forest lives a woman
|
| 726 |
+
who is the moon fairy. Her long black hair
|
| 727 |
+
shines in the starlight, tangled with her flowers
|
| 728 |
+
that glow with a soft blue glow. Her eyes are
|
| 729 |
+
the color of the night and shine with the magic
|
| 730 |
+
of the night. The fairy wears a dress made of
|
| 731 |
+
moon petals, woven with threads of moonlight
|
| 732 |
+
that shine with an iridescent glow, a crown of
|
| 733 |
+
stars adorns her head, shining with the light of
|
| 734 |
+
the full moon that illuminates the forest. Her
|
| 735 |
+
wings are translucent like glass, with a pale
|
| 736 |
+
glow reminiscent of the glow of the moon. HD,
|
| 737 |
+
6K, photo, cinematic, poster""")
|
| 738 |
+
text2s.append(
|
| 739 |
+
"""In the image, a fluffy white cat sits peacefully on a windowsill surrounded by potted green plants. Sunlight filters through sheer white curtains, casting soft golden patterns across its fur. The window reveals a gray, rainy sky outside, with raindrops streaking down the glass and blurred trees beyond. The cat’s posture is calm and elegant, its tail curled neatly around its paws. The atmosphere is serene and introspective, capturing a cozy moment of quiet observation during a rainy afternoon.""")
|
| 740 |
+
text2s.append(
|
| 741 |
+
"""In the image, a majestic white horse gallops across a misty meadow at sunrise. Its mane and tail flow freely in the golden light, and the air glows softly with early morning haze. The horse’s body is adorned with a bright red saddle, contrasting sharply against its white coat. Dew sparkles on the grass beneath its hooves, and the distant trees fade into pale gold mist. The scene conveys freedom, grace, and a striking touch of color that adds visual drama.""")
|
| 742 |
+
text2 = text2s[INDEX]
|
| 743 |
+
|
| 744 |
+
with torch.no_grad():
|
| 745 |
+
golden_embeds, shared_embeds, golden_mask = text_encoder(text2, padding='max_length')
|
| 746 |
+
print(golden_embeds.shape, shared_embeds.shape, ' >>>> golden ')
|
| 747 |
+
nopad_golden_embeds, nopad_shared_embeds, nopad_golden_mask = text_encoder(text2, padding=False)
|
| 748 |
+
print(golden_embeds.shape, shared_embeds.shape, ' >>>> golden ')
|
| 749 |
+
|
| 750 |
+
batch_encoding = text_encoder.t5_tokenizer(
|
| 751 |
+
text,
|
| 752 |
+
truncation=True,
|
| 753 |
+
max_length=text_encoder.max_length,
|
| 754 |
+
return_tensors="pt",
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
input_ids = batch_encoding["input_ids"][0] # Get the token IDs
|
| 758 |
+
|
| 759 |
+
# Convert token IDs back to tokens to see what they are
|
| 760 |
+
tokens_floral = text_encoder.t5_tokenizer.convert_ids_to_tokens(input_ids)
|
| 761 |
+
|
| 762 |
+
batch_encoding = text_encoder.t5_tokenizer(
|
| 763 |
+
text2,
|
| 764 |
+
truncation=True,
|
| 765 |
+
max_length=text_encoder.max_length,
|
| 766 |
+
return_tensors="pt",
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
input_ids = batch_encoding["input_ids"][0] # Get the token IDs
|
| 770 |
+
tokens_golden = text_encoder.t5_tokenizer.convert_ids_to_tokens(input_ids)
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
# Convert token IDs back to tokens to see what they are
|
| 774 |
+
|
| 775 |
+
# Find the index of specific words
|
| 776 |
+
def find_token_indices(tokens, word):
|
| 777 |
+
"""Find all indices where a word or its token appears"""
|
| 778 |
+
indices = []
|
| 779 |
+
# T5 tokenizer might split words or add special characters
|
| 780 |
+
word_token = text_encoder.t5_tokenizer.encode(word, add_special_tokens=False)[0]
|
| 781 |
+
word_token_str = text_encoder.t5_tokenizer.convert_ids_to_tokens([word_token])[0]
|
| 782 |
+
|
| 783 |
+
for i, token in enumerate(tokens):
|
| 784 |
+
if token == word_token_str or word.lower() in token.lower():
|
| 785 |
+
indices.append(i)
|
| 786 |
+
return indices
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
key1s = ['blue', 'golden', 'blonde', 'clear', 'horse']
|
| 790 |
+
key2s = ['red', 'floral', 'black', 'rainy', 'red']
|
| 791 |
+
|
| 792 |
+
# Find indices for "blue"
|
| 793 |
+
blue_indices = find_token_indices(tokens_floral, key1s[INDEX])[-1]
|
| 794 |
+
print(f"Indices for 'blue': {blue_indices}")
|
| 795 |
+
|
| 796 |
+
# Find indices for "red" (won't be found in this text)
|
| 797 |
+
red_indices = find_token_indices(tokens_golden, key2s[INDEX])[-1]
|
| 798 |
+
print(f"Indices for 'red': {red_indices}")
|
| 799 |
+
|
| 800 |
+
pipe = FluxPipeline.from_pretrained("./FLUX.1-dev", dtype=torch.bfloat16, text_encoder=None).to(torch.bfloat16)
|
| 801 |
+
pipe.to('cuda')
|
| 802 |
+
adapter_path = os.path.join(CHECKPOINT_PATH, "model.safetensors")
|
| 803 |
+
if os.path.exists(adapter_path):
|
| 804 |
+
adapter_state = load_file(adapter_path)
|
| 805 |
+
pipe.transformer.load_state_dict(adapter_state, strict=True)
|
| 806 |
+
print("Transformer loaded successfully!")
|
| 807 |
+
|
| 808 |
+
images = []
|
| 809 |
+
empty_pooled_clip = torch.load('empty_pooled_clip.pt', map_location='cpu').to('cuda').to(torch.bfloat16)
|
| 810 |
+
|
| 811 |
+
print("Generating image with concatenation...")
|
| 812 |
+
images = []
|
| 813 |
+
# for cur_prompt_embed in [floral_embeds, nopad_floral_embeds
|
| 814 |
+
# , inter_embed, golden_embeds, nopad_golden_embeds]:
|
| 815 |
+
|
| 816 |
+
# for (start_dim, end_dim) in [(0,4096), (1024,4096), (2048, 4096), (1024, 2048)]:
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
for emb in ['floral', 'golden']:
|
| 820 |
+
for temp in [2.5]:
|
| 821 |
+
for thr in [-1, 0.5, 0.75, 0.85, 0.95]:
|
| 822 |
+
for topk in [None]:
|
| 823 |
+
print('>>>> Temperature: ', temp, topk)
|
| 824 |
+
if 'floral' in emb:
|
| 825 |
+
inter_embed, _, _, _, new_aux_info = model(floral_embeds, shared_embeds, attn_mask,
|
| 826 |
+
is_training=False, temperature=temp,
|
| 827 |
+
threshold=thr, topk=topk)
|
| 828 |
+
else:
|
| 829 |
+
inter_embed, _, _, _, new_aux_info = model(golden_embeds, shared_embeds, golden_mask,
|
| 830 |
+
is_training=False, temperature=temp,
|
| 831 |
+
threshold=thr, topk=topk)
|
| 832 |
+
|
| 833 |
+
print(new_aux_info['influence'][:, blue_indices].shape, ' >>>> influence ',
|
| 834 |
+
new_aux_info['influence'][:, blue_indices])
|
| 835 |
+
print(new_aux_info['meaningful_influence'], ' >>>> meaningful influence ',
|
| 836 |
+
torch.mean(new_aux_info['meaningful_influence']))
|
| 837 |
+
|
| 838 |
+
# inter_embed = torch.clone(floral_embeds)
|
| 839 |
+
# inter_embed[:, blue_indices] = shared_embeds[:, blue_indices]
|
| 840 |
+
# inter_embed[:, blue_indices, start_dim:end_dim] = floral_embeds[:, blue_indices, start_dim:end_dim]
|
| 841 |
+
|
| 842 |
+
image = pipe(
|
| 843 |
+
prompt_embeds=inter_embed,
|
| 844 |
+
pooled_prompt_embeds=empty_pooled_clip,
|
| 845 |
+
height=512,
|
| 846 |
+
width=512,
|
| 847 |
+
guidance_scale=3.5,
|
| 848 |
+
num_inference_steps=20,
|
| 849 |
+
max_sequence_length=512,
|
| 850 |
+
generator=torch.Generator("cuda").manual_seed(1995),
|
| 851 |
+
).images[0]
|
| 852 |
+
images.append(image)
|
| 853 |
+
aligned_image = create_image_grid(images, cols=len(images) // 2)
|
| 854 |
+
os.makedirs('samples', exist_ok=True)
|
| 855 |
+
aligned_image.save("samples/image%s.jpg" % INDEX)
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
|
app.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
from aligner import ConceptAligner
|
| 6 |
+
from text_encoder import LoraT5Embedder
|
| 7 |
+
from pipeline import CustomFluxKontextPipeline
|
| 8 |
+
from diffusers import FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, AutoencoderKL
|
| 9 |
+
from peft import LoraConfig
|
| 10 |
+
import gradio as gr
|
| 11 |
+
|
| 12 |
+
# Configuration
|
| 13 |
+
MODEL_REPO = "Shaoan/ConceptAligner-Weights" # Your model repo
|
| 14 |
+
CHECKPOINT_DIR = "./checkpoint"
|
| 15 |
+
|
| 16 |
+
def download_checkpoint():
|
| 17 |
+
"""Download checkpoint files from HF model repo"""
|
| 18 |
+
print("Downloading checkpoint files...")
|
| 19 |
+
|
| 20 |
+
files = [
|
| 21 |
+
"model.safetensors",
|
| 22 |
+
"model_1.safetensors",
|
| 23 |
+
"model_2.safetensors"
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
| 27 |
+
|
| 28 |
+
for filename in files:
|
| 29 |
+
local_path = os.path.join(CHECKPOINT_DIR, filename)
|
| 30 |
+
if not os.path.exists(local_path):
|
| 31 |
+
print(f" Downloading {filename}...")
|
| 32 |
+
hf_hub_download(
|
| 33 |
+
repo_id=MODEL_REPO,
|
| 34 |
+
filename=filename,
|
| 35 |
+
local_dir=CHECKPOINT_DIR,
|
| 36 |
+
local_dir_use_symlinks=False
|
| 37 |
+
)
|
| 38 |
+
print(f" ✓ {filename} downloaded")
|
| 39 |
+
|
| 40 |
+
print("✓ All checkpoint files ready!")
|
| 41 |
+
|
| 42 |
+
class ConceptAlignerModel:
|
| 43 |
+
def __init__(self):
|
| 44 |
+
# Download checkpoint first
|
| 45 |
+
download_checkpoint()
|
| 46 |
+
|
| 47 |
+
self.checkpoint_path = CHECKPOINT_DIR
|
| 48 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 49 |
+
self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
| 50 |
+
|
| 51 |
+
self.previous_image = None
|
| 52 |
+
self.previous_prompt = None
|
| 53 |
+
|
| 54 |
+
print(f"\n{'='*60}")
|
| 55 |
+
print(f"Loading ConceptAligner Model")
|
| 56 |
+
print(f"Device: {self.device}")
|
| 57 |
+
print(f"{'='*60}")
|
| 58 |
+
|
| 59 |
+
self.setup_models()
|
| 60 |
+
|
| 61 |
+
def setup_models(self):
|
| 62 |
+
"""Load all models"""
|
| 63 |
+
# Load ConceptAligner
|
| 64 |
+
print(f" Loading ConceptAligner...")
|
| 65 |
+
self.model = ConceptAligner().to(self.device).to(self.dtype)
|
| 66 |
+
adapter_path = os.path.join(self.checkpoint_path, "model_1.safetensors")
|
| 67 |
+
adapter_state = load_file(adapter_path)
|
| 68 |
+
self.model.load_state_dict(adapter_state, strict=True)
|
| 69 |
+
print(f" ✓ Adapter loaded")
|
| 70 |
+
|
| 71 |
+
# Load T5 encoder
|
| 72 |
+
print(f" Loading T5 encoder...")
|
| 73 |
+
self.text_encoder = LoraT5Embedder(device=self.device).to(self.dtype)
|
| 74 |
+
adapter_path = os.path.join(self.checkpoint_path, "model_2.safetensors")
|
| 75 |
+
adapter_state = load_file(adapter_path)
|
| 76 |
+
if "t5_encoder.shared.weight" in adapter_state and "t5_encoder.encoder.embed_tokens.weight" not in adapter_state:
|
| 77 |
+
adapter_state["t5_encoder.encoder.embed_tokens.weight"] = adapter_state["t5_encoder.shared.weight"]
|
| 78 |
+
self.text_encoder.load_state_dict(adapter_state, strict=True)
|
| 79 |
+
print(f" ✓ T5 Adapter loaded")
|
| 80 |
+
|
| 81 |
+
# Load VAE
|
| 82 |
+
print(f" Loading VAE...")
|
| 83 |
+
vae = AutoencoderKL.from_pretrained(
|
| 84 |
+
'black-forest-labs/FLUX.1-dev',
|
| 85 |
+
subfolder="vae",
|
| 86 |
+
torch_dtype=self.dtype
|
| 87 |
+
).to(self.device)
|
| 88 |
+
|
| 89 |
+
# Load transformer
|
| 90 |
+
print(f" Loading transformer...")
|
| 91 |
+
transformer = FluxTransformer2DModel.from_pretrained(
|
| 92 |
+
'black-forest-labs/FLUX.1-dev',
|
| 93 |
+
subfolder="transformer",
|
| 94 |
+
torch_dtype=self.dtype
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
target_modules = [
|
| 98 |
+
"attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
|
| 99 |
+
"attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
|
| 100 |
+
"ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
|
| 101 |
+
"proj_mlp", "proj_out", "norm.linear", "norm1.linear"
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
transformer_lora_config = LoraConfig(
|
| 105 |
+
r=256,
|
| 106 |
+
lora_alpha=256,
|
| 107 |
+
lora_dropout=0.0,
|
| 108 |
+
init_lora_weights="gaussian",
|
| 109 |
+
target_modules=target_modules,
|
| 110 |
+
)
|
| 111 |
+
transformer.add_adapter(transformer_lora_config)
|
| 112 |
+
transformer.context_embedder.requires_grad_(True)
|
| 113 |
+
|
| 114 |
+
# Load fine-tuned transformer
|
| 115 |
+
transformer_path = os.path.join(self.checkpoint_path, "model.safetensors")
|
| 116 |
+
transformer_state = load_file(transformer_path)
|
| 117 |
+
transformer.load_state_dict(transformer_state, strict=True)
|
| 118 |
+
print(f" ✓ Fine-tuned transformer loaded")
|
| 119 |
+
|
| 120 |
+
transformer = transformer.to(self.device)
|
| 121 |
+
|
| 122 |
+
# Load or download empty pooled clip
|
| 123 |
+
empty_clip_path = "empty_pooled_clip.pt"
|
| 124 |
+
if not os.path.exists(empty_clip_path):
|
| 125 |
+
print(" Downloading empty_pooled_clip.pt...")
|
| 126 |
+
hf_hub_download(
|
| 127 |
+
repo_id=MODEL_REPO,
|
| 128 |
+
filename="empty_pooled_clip.pt",
|
| 129 |
+
local_dir=".",
|
| 130 |
+
local_dir_use_symlinks=False
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
self.empty_pooled_clip = torch.load(empty_clip_path, map_location=self.device).to(self.dtype)
|
| 134 |
+
|
| 135 |
+
# Create pipeline
|
| 136 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
| 137 |
+
'black-forest-labs/FLUX.1-dev', subfolder="scheduler"
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
self.pipe = CustomFluxKontextPipeline(
|
| 141 |
+
scheduler=noise_scheduler,
|
| 142 |
+
aligner=self.model.to(self.device).to(self.dtype),
|
| 143 |
+
transformer=transformer.to(self.device).to(self.dtype),
|
| 144 |
+
vae=vae.to(self.device).to(self.dtype),
|
| 145 |
+
text_embedder=self.text_encoder.to(self.device).to(self.dtype),
|
| 146 |
+
).to(self.device)
|
| 147 |
+
|
| 148 |
+
if torch.cuda.is_available():
|
| 149 |
+
allocated = torch.cuda.memory_allocated(0) / 1024**3
|
| 150 |
+
reserved = torch.cuda.memory_reserved(0) / 1024**3
|
| 151 |
+
print(f" ✓ Pipeline ready on {self.device}")
|
| 152 |
+
print(f" 📊 GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
|
| 153 |
+
else:
|
| 154 |
+
print(f" ✓ Pipeline ready on {self.device}")
|
| 155 |
+
|
| 156 |
+
@torch.no_grad()
|
| 157 |
+
def generate_image(
|
| 158 |
+
self,
|
| 159 |
+
prompt,
|
| 160 |
+
threshold=0.0,
|
| 161 |
+
topk=0,
|
| 162 |
+
height=512,
|
| 163 |
+
width=512,
|
| 164 |
+
guidance_scale=3.5,
|
| 165 |
+
true_cf_scale=1.0,
|
| 166 |
+
num_inference_steps=20,
|
| 167 |
+
seed=1995
|
| 168 |
+
):
|
| 169 |
+
"""Generate image and return previous + current for comparison"""
|
| 170 |
+
if not prompt.strip():
|
| 171 |
+
return self.previous_image, None, self.previous_prompt or ""
|
| 172 |
+
|
| 173 |
+
try:
|
| 174 |
+
generator = torch.Generator(device=self.device).manual_seed(int(seed))
|
| 175 |
+
|
| 176 |
+
current_image = self.pipe(
|
| 177 |
+
prompt=prompt,
|
| 178 |
+
guidance_scale=guidance_scale,
|
| 179 |
+
true_cfg_scale=true_cf_scale,
|
| 180 |
+
max_sequence_length=512,
|
| 181 |
+
num_inference_steps=num_inference_steps,
|
| 182 |
+
height=height,
|
| 183 |
+
width=width,
|
| 184 |
+
generator=generator,
|
| 185 |
+
).images[0]
|
| 186 |
+
|
| 187 |
+
prev_image = self.previous_image
|
| 188 |
+
prev_prompt = self.previous_prompt or "No previous generation"
|
| 189 |
+
|
| 190 |
+
self.previous_image = current_image
|
| 191 |
+
self.previous_prompt = prompt
|
| 192 |
+
|
| 193 |
+
return prev_image, current_image, prev_prompt
|
| 194 |
+
|
| 195 |
+
except Exception as e:
|
| 196 |
+
import traceback
|
| 197 |
+
error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
|
| 198 |
+
print(error_msg)
|
| 199 |
+
return self.previous_image, None, self.previous_prompt or ""
|
| 200 |
+
|
| 201 |
+
def reset_history(self):
|
| 202 |
+
"""Clear generation history"""
|
| 203 |
+
self.previous_image = None
|
| 204 |
+
self.previous_prompt = None
|
| 205 |
+
return None, None, "No previous generation"
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# Initialize model
|
| 209 |
+
print("Initializing ConceptAligner model...")
|
| 210 |
+
model = ConceptAlignerModel()
|
empty_pooled_clip.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:92acbe688a00c835deb9b645fe673e16af2ceef9cd749a8b838e67dea23d76b2
|
| 3 |
+
size 3183
|
pipeline.py
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import inspect
|
| 3 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import (
|
| 7 |
+
CLIPImageProcessor,
|
| 8 |
+
CLIPTextModel,
|
| 9 |
+
CLIPTokenizer,
|
| 10 |
+
CLIPVisionModelWithProjection,
|
| 11 |
+
T5EncoderModel,
|
| 12 |
+
T5TokenizerFast,
|
| 13 |
+
)
|
| 14 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 15 |
+
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
| 16 |
+
from diffusers.models import AutoencoderKL, FluxTransformer2DModel
|
| 17 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 18 |
+
from diffusers.utils import (
|
| 19 |
+
USE_PEFT_BACKEND,
|
| 20 |
+
deprecate,
|
| 21 |
+
is_torch_xla_available,
|
| 22 |
+
logging,
|
| 23 |
+
replace_example_docstring,
|
| 24 |
+
scale_lora_layers,
|
| 25 |
+
unscale_lora_layers,
|
| 26 |
+
)
|
| 27 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 28 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
| 29 |
+
from diffusers import FluxKontextPipeline
|
| 30 |
+
|
| 31 |
+
PREFERRED_KONTEXT_RESOLUTIONS = [
|
| 32 |
+
(672, 1568),
|
| 33 |
+
(688, 1504),
|
| 34 |
+
(720, 1456),
|
| 35 |
+
(752, 1392),
|
| 36 |
+
(800, 1328),
|
| 37 |
+
(832, 1248),
|
| 38 |
+
(880, 1184),
|
| 39 |
+
(944, 1104),
|
| 40 |
+
(1024, 1024),
|
| 41 |
+
(1104, 944),
|
| 42 |
+
(1184, 880),
|
| 43 |
+
(1248, 832),
|
| 44 |
+
(1328, 800),
|
| 45 |
+
(1392, 752),
|
| 46 |
+
(1456, 720),
|
| 47 |
+
(1504, 688),
|
| 48 |
+
(1568, 672),
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def calculate_shift(
|
| 53 |
+
image_seq_len,
|
| 54 |
+
base_seq_len: int = 256,
|
| 55 |
+
max_seq_len: int = 4096,
|
| 56 |
+
base_shift: float = 0.5,
|
| 57 |
+
max_shift: float = 1.15,
|
| 58 |
+
):
|
| 59 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 60 |
+
b = base_shift - m * base_seq_len
|
| 61 |
+
mu = image_seq_len * m + b
|
| 62 |
+
return mu
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 66 |
+
def retrieve_timesteps(
|
| 67 |
+
scheduler,
|
| 68 |
+
num_inference_steps: Optional[int] = None,
|
| 69 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 70 |
+
timesteps: Optional[List[int]] = None,
|
| 71 |
+
sigmas: Optional[List[float]] = None,
|
| 72 |
+
**kwargs,
|
| 73 |
+
):
|
| 74 |
+
r"""
|
| 75 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 76 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
scheduler (`SchedulerMixin`):
|
| 80 |
+
The scheduler to get timesteps from.
|
| 81 |
+
num_inference_steps (`int`):
|
| 82 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 83 |
+
must be `None`.
|
| 84 |
+
device (`str` or `torch.device`, *optional*):
|
| 85 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 86 |
+
timesteps (`List[int]`, *optional*):
|
| 87 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 88 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 89 |
+
sigmas (`List[float]`, *optional*):
|
| 90 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 91 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 95 |
+
second element is the number of inference steps.
|
| 96 |
+
"""
|
| 97 |
+
if timesteps is not None and sigmas is not None:
|
| 98 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 99 |
+
if timesteps is not None:
|
| 100 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 101 |
+
if not accepts_timesteps:
|
| 102 |
+
raise ValueError(
|
| 103 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 104 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 105 |
+
)
|
| 106 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 107 |
+
timesteps = scheduler.timesteps
|
| 108 |
+
num_inference_steps = len(timesteps)
|
| 109 |
+
elif sigmas is not None:
|
| 110 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 111 |
+
if not accept_sigmas:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 114 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 115 |
+
)
|
| 116 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 117 |
+
timesteps = scheduler.timesteps
|
| 118 |
+
num_inference_steps = len(timesteps)
|
| 119 |
+
else:
|
| 120 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 121 |
+
timesteps = scheduler.timesteps
|
| 122 |
+
return timesteps, num_inference_steps
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 126 |
+
def retrieve_latents(
|
| 127 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 128 |
+
):
|
| 129 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 130 |
+
return encoder_output.latent_dist.sample(generator)
|
| 131 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 132 |
+
return encoder_output.latent_dist.mode()
|
| 133 |
+
elif hasattr(encoder_output, "latents"):
|
| 134 |
+
return encoder_output.latents
|
| 135 |
+
else:
|
| 136 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
from diffusers import FluxKontextPipeline
|
| 140 |
+
from typing import Union, List, Optional
|
| 141 |
+
import torch
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class CustomFluxKontextPipeline(FluxKontextPipeline):
|
| 145 |
+
r"""
|
| 146 |
+
Custom Flux Kontext pipeline with a wrapper text embedder.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
model_cpu_offload_seq = "text_embedder->image_encoder->transformer->vae"
|
| 150 |
+
|
| 151 |
+
def __init__(
|
| 152 |
+
self,
|
| 153 |
+
scheduler,
|
| 154 |
+
vae,
|
| 155 |
+
text_embedder, # Your custom text embedder wrapper
|
| 156 |
+
transformer,
|
| 157 |
+
aligner,
|
| 158 |
+
image_encoder=None,
|
| 159 |
+
feature_extractor=None,
|
| 160 |
+
):
|
| 161 |
+
# Don't call super().__init__() since parent expects text_encoder parameters
|
| 162 |
+
# Instead, manually register modules
|
| 163 |
+
from diffusers import DiffusionPipeline
|
| 164 |
+
DiffusionPipeline.__init__(self)
|
| 165 |
+
|
| 166 |
+
self.register_modules(
|
| 167 |
+
vae=vae,
|
| 168 |
+
text_embedder=text_embedder,
|
| 169 |
+
transformer=transformer,
|
| 170 |
+
scheduler=scheduler,
|
| 171 |
+
aligner=aligner,
|
| 172 |
+
image_encoder=image_encoder,
|
| 173 |
+
feature_extractor=feature_extractor,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# Initialize the necessary attributes from parent
|
| 177 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 178 |
+
self.latent_channels = self.vae.config.latent_channels
|
| 179 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 180 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 181 |
+
self.default_sample_size = 128
|
| 182 |
+
|
| 183 |
+
def encode_prompt(
|
| 184 |
+
self,
|
| 185 |
+
prompt: Union[str, List[str]],
|
| 186 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 187 |
+
device: Optional[torch.device] = None,
|
| 188 |
+
num_images_per_prompt: int = 1,
|
| 189 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 190 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 191 |
+
max_sequence_length: int = 512,
|
| 192 |
+
lora_scale: Optional[float] = None,
|
| 193 |
+
temperature=None,
|
| 194 |
+
threshold=None,
|
| 195 |
+
):
|
| 196 |
+
device = device or self._execution_device
|
| 197 |
+
|
| 198 |
+
if prompt_embeds is None:
|
| 199 |
+
# Use your custom text embedder
|
| 200 |
+
qwen_embeds, clip_image_embeds, perturbed_qwen_embeds, replace_ids, t5_tokenizer, batch_encoding = self.text_embedder(prompt)
|
| 201 |
+
prompt_embeds, prompt_attention_mask, pooled_prompt_embeds, text_ids, _ = self.aligner(qwen_embeds,
|
| 202 |
+
)
|
| 203 |
+
prompt_embeds = prompt_embeds.to(device=device)
|
| 204 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(device=device)
|
| 205 |
+
text_ids = text_ids.to(device=device)
|
| 206 |
+
else:
|
| 207 |
+
# When embeddings are provided, create text_ids
|
| 208 |
+
dtype = self.transformer.dtype
|
| 209 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 210 |
+
|
| 211 |
+
# Duplicate for num_images_per_prompt
|
| 212 |
+
if num_images_per_prompt > 1:
|
| 213 |
+
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 214 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 215 |
+
text_ids = text_ids.repeat(num_images_per_prompt, 1)
|
| 216 |
+
|
| 217 |
+
return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds, text_ids
|
| 218 |
+
|
| 219 |
+
@torch.no_grad()
|
| 220 |
+
def __call__(
|
| 221 |
+
self,
|
| 222 |
+
image: Optional[PipelineImageInput] = None,
|
| 223 |
+
prompt: Union[str, List[str]] = None,
|
| 224 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 225 |
+
negative_prompt: Union[str, List[str]] = "",
|
| 226 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 227 |
+
true_cfg_scale: float = 1.0,
|
| 228 |
+
height: Optional[int] = None,
|
| 229 |
+
width: Optional[int] = None,
|
| 230 |
+
num_inference_steps: int = 28,
|
| 231 |
+
sigmas: Optional[List[float]] = None,
|
| 232 |
+
guidance_scale: float = 3.5,
|
| 233 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 234 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 235 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 236 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 237 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 238 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 239 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 240 |
+
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 241 |
+
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 242 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 243 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 244 |
+
output_type: Optional[str] = "pil",
|
| 245 |
+
return_dict: bool = True,
|
| 246 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 247 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 248 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 249 |
+
max_sequence_length: int = 512,
|
| 250 |
+
max_area: int = 1024 ** 2,
|
| 251 |
+
_auto_resize: bool = True,
|
| 252 |
+
temperature=None,
|
| 253 |
+
threshold=None,
|
| 254 |
+
):
|
| 255 |
+
r"""
|
| 256 |
+
Function invoked when calling the pipeline for generation.
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 260 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
| 261 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
| 262 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
| 263 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
| 264 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
| 265 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 266 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 267 |
+
instead.
|
| 268 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 269 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 270 |
+
will be used instead.
|
| 271 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 272 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 273 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 274 |
+
not greater than `1`).
|
| 275 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 276 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 277 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 278 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
| 279 |
+
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
| 280 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 281 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 282 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 283 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 284 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 285 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 286 |
+
expense of slower inference.
|
| 287 |
+
sigmas (`List[float]`, *optional*):
|
| 288 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 289 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 290 |
+
will be used.
|
| 291 |
+
guidance_scale (`float`, *optional*, defaults to 3.5):
|
| 292 |
+
Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
|
| 293 |
+
a model to generate images more aligned with prompt at the expense of lower image quality.
|
| 294 |
+
|
| 295 |
+
Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
|
| 296 |
+
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
|
| 297 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 298 |
+
The number of images to generate per prompt.
|
| 299 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 300 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 301 |
+
to make generation deterministic.
|
| 302 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 303 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 304 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 305 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 306 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 307 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 308 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 309 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 310 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 311 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 312 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*):
|
| 313 |
+
Optional image input to work with IP Adapters.
|
| 314 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 315 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 316 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 317 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 318 |
+
negative_ip_adapter_image:
|
| 319 |
+
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 320 |
+
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 321 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 322 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 323 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 324 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 325 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 326 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 327 |
+
argument.
|
| 328 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 329 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 330 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 331 |
+
input argument.
|
| 332 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 333 |
+
The output format of the generate image. Choose between
|
| 334 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 335 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 336 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
| 337 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 338 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 339 |
+
`self.processor` in
|
| 340 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 341 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 342 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 343 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 344 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 345 |
+
`callback_on_step_end_tensor_inputs`.
|
| 346 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 347 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 348 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 349 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 350 |
+
max_sequence_length (`int` defaults to 512):
|
| 351 |
+
Maximum sequence length to use with the `prompt`.
|
| 352 |
+
max_area (`int`, defaults to `1024 ** 2`):
|
| 353 |
+
The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
|
| 354 |
+
area while maintaining the aspect ratio.
|
| 355 |
+
|
| 356 |
+
Examples:
|
| 357 |
+
|
| 358 |
+
Returns:
|
| 359 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
| 360 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
| 361 |
+
images.
|
| 362 |
+
"""
|
| 363 |
+
|
| 364 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 365 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 366 |
+
|
| 367 |
+
original_height, original_width = height, width
|
| 368 |
+
aspect_ratio = width / height
|
| 369 |
+
|
| 370 |
+
"""
|
| 371 |
+
width = round((max_area * aspect_ratio) ** 0.5)
|
| 372 |
+
height = round((max_area / aspect_ratio) ** 0.5)
|
| 373 |
+
multiple_of = self.vae_scale_factor * 2
|
| 374 |
+
width = width // multiple_of * multiple_of
|
| 375 |
+
height = height // multiple_of * multiple_of
|
| 376 |
+
|
| 377 |
+
if height != original_height or width != original_width:
|
| 378 |
+
print(
|
| 379 |
+
f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
|
| 380 |
+
)
|
| 381 |
+
"""
|
| 382 |
+
|
| 383 |
+
# 1. Check inputs. Raise error if not correct
|
| 384 |
+
self.check_inputs(
|
| 385 |
+
prompt,
|
| 386 |
+
prompt_2,
|
| 387 |
+
height,
|
| 388 |
+
width,
|
| 389 |
+
negative_prompt=negative_prompt,
|
| 390 |
+
negative_prompt_2=negative_prompt_2,
|
| 391 |
+
prompt_embeds=prompt_embeds,
|
| 392 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 393 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 394 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 395 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 396 |
+
max_sequence_length=max_sequence_length,
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
self._guidance_scale = guidance_scale
|
| 400 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 401 |
+
self._current_timestep = None
|
| 402 |
+
self._interrupt = False
|
| 403 |
+
|
| 404 |
+
# 2. Define call parameters
|
| 405 |
+
if prompt is not None and isinstance(prompt, str):
|
| 406 |
+
batch_size = 1
|
| 407 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 408 |
+
batch_size = len(prompt)
|
| 409 |
+
else:
|
| 410 |
+
batch_size = prompt_embeds.shape[0]
|
| 411 |
+
|
| 412 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 413 |
+
|
| 414 |
+
lora_scale = (
|
| 415 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 416 |
+
)
|
| 417 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 418 |
+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
| 419 |
+
)
|
| 420 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 421 |
+
(
|
| 422 |
+
prompt_embeds,
|
| 423 |
+
prompt_attention_mask,
|
| 424 |
+
pooled_prompt_embeds,
|
| 425 |
+
text_ids,
|
| 426 |
+
) = self.encode_prompt(
|
| 427 |
+
prompt=prompt,
|
| 428 |
+
prompt_2=prompt_2,
|
| 429 |
+
prompt_embeds=prompt_embeds,
|
| 430 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 431 |
+
device=device,
|
| 432 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 433 |
+
max_sequence_length=max_sequence_length,
|
| 434 |
+
lora_scale=lora_scale,
|
| 435 |
+
temperature=temperature,
|
| 436 |
+
threshold=threshold,
|
| 437 |
+
)
|
| 438 |
+
(
|
| 439 |
+
negative_prompt_embeds,
|
| 440 |
+
negative_prompt_attention_mask,
|
| 441 |
+
negative_pooled_prompt_embeds,
|
| 442 |
+
negative_text_ids,
|
| 443 |
+
) = self.encode_prompt(
|
| 444 |
+
prompt=negative_prompt,
|
| 445 |
+
prompt_2=negative_prompt_2,
|
| 446 |
+
prompt_embeds=negative_prompt_embeds,
|
| 447 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 448 |
+
device=device,
|
| 449 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 450 |
+
max_sequence_length=max_sequence_length,
|
| 451 |
+
lora_scale=lora_scale,
|
| 452 |
+
temperature=temperature,
|
| 453 |
+
threshold=threshold,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
pooled_prompt_embeds = negative_pooled_prompt_embeds
|
| 457 |
+
|
| 458 |
+
# 3. Preprocess image
|
| 459 |
+
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
|
| 460 |
+
img = image[0] if isinstance(image, list) else image
|
| 461 |
+
"""
|
| 462 |
+
image_height, image_width = self.image_processor.get_default_height_width(img)
|
| 463 |
+
aspect_ratio = image_width / image_height
|
| 464 |
+
if _auto_resize:
|
| 465 |
+
# Kontext is trained on specific resolutions, using one of them is recommended
|
| 466 |
+
_, image_width, image_height = min(
|
| 467 |
+
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
|
| 468 |
+
)
|
| 469 |
+
image_width = image_width // multiple_of * multiple_of
|
| 470 |
+
image_height = image_height // multiple_of * multiple_of
|
| 471 |
+
"""
|
| 472 |
+
image_height, image_width = original_height, original_width
|
| 473 |
+
image = self.image_processor.resize(image, image_height, image_width)
|
| 474 |
+
image = self.image_processor.preprocess(image, image_height, image_width)
|
| 475 |
+
|
| 476 |
+
# 4. Prepare latent variables
|
| 477 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 478 |
+
latents, image_latents, latent_ids, image_ids = self.prepare_latents(
|
| 479 |
+
image,
|
| 480 |
+
batch_size * num_images_per_prompt,
|
| 481 |
+
num_channels_latents,
|
| 482 |
+
height,
|
| 483 |
+
width,
|
| 484 |
+
prompt_embeds.dtype,
|
| 485 |
+
device,
|
| 486 |
+
generator,
|
| 487 |
+
latents,
|
| 488 |
+
)
|
| 489 |
+
if image_ids is not None:
|
| 490 |
+
latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
|
| 491 |
+
|
| 492 |
+
# 5. Prepare timesteps
|
| 493 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 494 |
+
image_seq_len = latents.shape[1]
|
| 495 |
+
mu = calculate_shift(
|
| 496 |
+
image_seq_len,
|
| 497 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
| 498 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
| 499 |
+
self.scheduler.config.get("base_shift", 0.5),
|
| 500 |
+
self.scheduler.config.get("max_shift", 1.15),
|
| 501 |
+
)
|
| 502 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 503 |
+
self.scheduler,
|
| 504 |
+
num_inference_steps,
|
| 505 |
+
device,
|
| 506 |
+
sigmas=sigmas,
|
| 507 |
+
mu=mu,
|
| 508 |
+
)
|
| 509 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 510 |
+
self._num_timesteps = len(timesteps)
|
| 511 |
+
|
| 512 |
+
# handle guidance
|
| 513 |
+
if self.transformer.config.guidance_embeds:
|
| 514 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 515 |
+
guidance = guidance.expand(latents.shape[0])
|
| 516 |
+
else:
|
| 517 |
+
guidance = None
|
| 518 |
+
|
| 519 |
+
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
| 520 |
+
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
| 521 |
+
):
|
| 522 |
+
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 523 |
+
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 524 |
+
|
| 525 |
+
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
| 526 |
+
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
| 527 |
+
):
|
| 528 |
+
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 529 |
+
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
| 530 |
+
|
| 531 |
+
if self.joint_attention_kwargs is None:
|
| 532 |
+
self._joint_attention_kwargs = {}
|
| 533 |
+
|
| 534 |
+
image_embeds = None
|
| 535 |
+
negative_image_embeds = None
|
| 536 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 537 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 538 |
+
ip_adapter_image,
|
| 539 |
+
ip_adapter_image_embeds,
|
| 540 |
+
device,
|
| 541 |
+
batch_size * num_images_per_prompt,
|
| 542 |
+
)
|
| 543 |
+
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
| 544 |
+
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 545 |
+
negative_ip_adapter_image,
|
| 546 |
+
negative_ip_adapter_image_embeds,
|
| 547 |
+
device,
|
| 548 |
+
batch_size * num_images_per_prompt,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
# 6. Denoising loop
|
| 552 |
+
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
| 553 |
+
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
| 554 |
+
self.scheduler.set_begin_index(0)
|
| 555 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 556 |
+
for i, t in enumerate(timesteps):
|
| 557 |
+
if self.interrupt:
|
| 558 |
+
continue
|
| 559 |
+
|
| 560 |
+
self._current_timestep = t
|
| 561 |
+
if image_embeds is not None:
|
| 562 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
| 563 |
+
|
| 564 |
+
latent_model_input = latents
|
| 565 |
+
if image_latents is not None:
|
| 566 |
+
latent_model_input = torch.cat([latents, image_latents], dim=1)
|
| 567 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 568 |
+
|
| 569 |
+
noise_pred = self.transformer(
|
| 570 |
+
hidden_states=latent_model_input,
|
| 571 |
+
timestep=timestep / 1000,
|
| 572 |
+
guidance=guidance,
|
| 573 |
+
pooled_projections=pooled_prompt_embeds,
|
| 574 |
+
encoder_hidden_states=prompt_embeds,
|
| 575 |
+
txt_ids=text_ids,
|
| 576 |
+
img_ids=latent_ids,
|
| 577 |
+
joint_attention_kwargs={'attention_mask': prompt_attention_mask},
|
| 578 |
+
return_dict=False,
|
| 579 |
+
)[0]
|
| 580 |
+
noise_pred = noise_pred[:, : latents.size(1)]
|
| 581 |
+
|
| 582 |
+
if do_true_cfg:
|
| 583 |
+
if negative_image_embeds is not None:
|
| 584 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
| 585 |
+
neg_noise_pred = self.transformer(
|
| 586 |
+
hidden_states=latent_model_input,
|
| 587 |
+
timestep=timestep / 1000,
|
| 588 |
+
guidance=guidance,
|
| 589 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
| 590 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 591 |
+
txt_ids=negative_text_ids,
|
| 592 |
+
img_ids=latent_ids,
|
| 593 |
+
joint_attention_kwargs={'attention_mask': negative_prompt_attention_mask},
|
| 594 |
+
return_dict=False,
|
| 595 |
+
)[0]
|
| 596 |
+
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
|
| 597 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 598 |
+
|
| 599 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 600 |
+
latents_dtype = latents.dtype
|
| 601 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 602 |
+
|
| 603 |
+
if latents.dtype != latents_dtype:
|
| 604 |
+
if torch.backends.mps.is_available():
|
| 605 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 606 |
+
latents = latents.to(latents_dtype)
|
| 607 |
+
|
| 608 |
+
if callback_on_step_end is not None:
|
| 609 |
+
callback_kwargs = {}
|
| 610 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 611 |
+
callback_kwargs[k] = locals()[k]
|
| 612 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 613 |
+
|
| 614 |
+
latents = callback_outputs.pop("latents", latents)
|
| 615 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 616 |
+
|
| 617 |
+
# call the callback, if provided
|
| 618 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 619 |
+
progress_bar.update()
|
| 620 |
+
|
| 621 |
+
self._current_timestep = None
|
| 622 |
+
|
| 623 |
+
if output_type == "latent":
|
| 624 |
+
image = latents
|
| 625 |
+
else:
|
| 626 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 627 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 628 |
+
|
| 629 |
+
dtype = torch.bfloat16
|
| 630 |
+
image = self.vae.decode(latents.to(dtype), return_dict=False)[0]
|
| 631 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 632 |
+
|
| 633 |
+
# Offload all models
|
| 634 |
+
self.maybe_free_model_hooks()
|
| 635 |
+
|
| 636 |
+
if not return_dict:
|
| 637 |
+
return (image,)
|
| 638 |
+
|
| 639 |
+
return FluxPipelineOutput(images=image)
|
| 640 |
+
|
| 641 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
gradio>=4.0.0
|
| 3 |
+
diffusers>=0.27.0
|
| 4 |
+
transformers>=4.38.0
|
| 5 |
+
safetensors>=0.4.0
|
| 6 |
+
accelerate>=0.26.0
|
| 7 |
+
peft>=0.8.0
|
| 8 |
+
Pillow>=10.0.0
|
requirements.txt.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
gradio>=4.0.0
|
| 3 |
+
diffusers>=0.27.0
|
| 4 |
+
transformers>=4.38.0
|
| 5 |
+
safetensors>=0.4.0
|
| 6 |
+
accelerate>=0.26.0
|
| 7 |
+
peft>=0.8.0
|
| 8 |
+
Pillow>=10.0.0
|
text_encoder.py
ADDED
|
@@ -0,0 +1,1188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def tokenize_prompt(tokenizer, prompt, max_sequence_length):
|
| 8 |
+
text_inputs = tokenizer(
|
| 9 |
+
prompt,
|
| 10 |
+
padding="max_length",
|
| 11 |
+
max_length=max_sequence_length,
|
| 12 |
+
truncation=True,
|
| 13 |
+
return_length=False,
|
| 14 |
+
return_overflowing_tokens=False,
|
| 15 |
+
return_tensors="pt",
|
| 16 |
+
)
|
| 17 |
+
text_input_ids = text_inputs.input_ids
|
| 18 |
+
return text_input_ids
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _encode_prompt_with_t5(
|
| 22 |
+
text_encoder,
|
| 23 |
+
tokenizer,
|
| 24 |
+
max_sequence_length=512,
|
| 25 |
+
prompt=None,
|
| 26 |
+
num_images_per_prompt=1,
|
| 27 |
+
device=None,
|
| 28 |
+
text_input_ids=None,
|
| 29 |
+
):
|
| 30 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 31 |
+
batch_size = len(prompt)
|
| 32 |
+
|
| 33 |
+
if tokenizer is not None:
|
| 34 |
+
text_inputs = tokenizer(
|
| 35 |
+
prompt,
|
| 36 |
+
padding="max_length",
|
| 37 |
+
max_length=max_sequence_length,
|
| 38 |
+
truncation=True,
|
| 39 |
+
return_length=False,
|
| 40 |
+
return_overflowing_tokens=False,
|
| 41 |
+
return_tensors="pt",
|
| 42 |
+
)
|
| 43 |
+
text_input_ids = text_inputs.input_ids
|
| 44 |
+
else:
|
| 45 |
+
if text_input_ids is None:
|
| 46 |
+
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
| 47 |
+
|
| 48 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
| 49 |
+
|
| 50 |
+
if hasattr(text_encoder, "module"):
|
| 51 |
+
dtype = text_encoder.module.dtype
|
| 52 |
+
else:
|
| 53 |
+
dtype = text_encoder.dtype
|
| 54 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 55 |
+
|
| 56 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 57 |
+
|
| 58 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 59 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 60 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 61 |
+
|
| 62 |
+
return prompt_embeds
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _encode_prompt_with_clip(
|
| 66 |
+
text_encoder,
|
| 67 |
+
tokenizer,
|
| 68 |
+
prompt: str,
|
| 69 |
+
device=None,
|
| 70 |
+
text_input_ids=None,
|
| 71 |
+
num_images_per_prompt: int = 1,
|
| 72 |
+
):
|
| 73 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 74 |
+
batch_size = len(prompt)
|
| 75 |
+
|
| 76 |
+
if tokenizer is not None:
|
| 77 |
+
text_inputs = tokenizer(
|
| 78 |
+
prompt,
|
| 79 |
+
padding="max_length",
|
| 80 |
+
max_length=77,
|
| 81 |
+
truncation=True,
|
| 82 |
+
return_overflowing_tokens=False,
|
| 83 |
+
return_length=False,
|
| 84 |
+
return_tensors="pt",
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
text_input_ids = text_inputs.input_ids
|
| 88 |
+
else:
|
| 89 |
+
if text_input_ids is None:
|
| 90 |
+
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
| 91 |
+
|
| 92 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
| 93 |
+
|
| 94 |
+
if hasattr(text_encoder, "module"):
|
| 95 |
+
dtype = text_encoder.module.dtype
|
| 96 |
+
else:
|
| 97 |
+
dtype = text_encoder.dtype
|
| 98 |
+
# Use pooled output of CLIPTextModel
|
| 99 |
+
prompt_embeds = prompt_embeds.pooler_output
|
| 100 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 101 |
+
|
| 102 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 103 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 104 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 105 |
+
|
| 106 |
+
return prompt_embeds
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def encode_prompt(
|
| 110 |
+
text_encoders,
|
| 111 |
+
tokenizers,
|
| 112 |
+
prompt: str,
|
| 113 |
+
max_sequence_length,
|
| 114 |
+
device=None,
|
| 115 |
+
num_images_per_prompt: int = 1,
|
| 116 |
+
text_input_ids_list=None,
|
| 117 |
+
):
|
| 118 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 119 |
+
|
| 120 |
+
if hasattr(text_encoders[0], "module"):
|
| 121 |
+
dtype = text_encoders[0].module.dtype
|
| 122 |
+
else:
|
| 123 |
+
dtype = text_encoders[0].dtype
|
| 124 |
+
|
| 125 |
+
pooled_prompt_embeds = _encode_prompt_with_clip(
|
| 126 |
+
text_encoder=text_encoders[0],
|
| 127 |
+
tokenizer=tokenizers[0],
|
| 128 |
+
prompt=prompt,
|
| 129 |
+
device=device if device is not None else text_encoders[0].device,
|
| 130 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 131 |
+
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
prompt_embeds = _encode_prompt_with_t5(
|
| 135 |
+
text_encoder=text_encoders[1],
|
| 136 |
+
tokenizer=tokenizers[1],
|
| 137 |
+
max_sequence_length=max_sequence_length,
|
| 138 |
+
prompt=prompt,
|
| 139 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 140 |
+
device=device if device is not None else text_encoders[1].device,
|
| 141 |
+
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 145 |
+
|
| 146 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
from transformers import T5EncoderModel, T5Tokenizer, CLIPTokenizer, CLIPTextModel
|
| 150 |
+
class T5Embedder(torch.nn.Module):
|
| 151 |
+
def __init__(self, device, max_length=300):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.device = device
|
| 154 |
+
self.max_length = max_length
|
| 155 |
+
dtype = torch.bfloat16
|
| 156 |
+
self.dtype = dtype
|
| 157 |
+
t5_version = './t5-v1_1-xxl'
|
| 158 |
+
self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
|
| 159 |
+
self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device)
|
| 160 |
+
self.t5_encoder = self.t5_encoder.eval().requires_grad_(False)
|
| 161 |
+
self.num_shared = max_length
|
| 162 |
+
|
| 163 |
+
@torch.no_grad()
|
| 164 |
+
def forward(self, text):
|
| 165 |
+
if isinstance(text, str):
|
| 166 |
+
text = [text]
|
| 167 |
+
batch_encoding = self.t5_tokenizer(
|
| 168 |
+
text,
|
| 169 |
+
truncation=True,
|
| 170 |
+
max_length=self.max_length,
|
| 171 |
+
return_length=False,
|
| 172 |
+
return_overflowing_tokens=False,
|
| 173 |
+
padding="max_length",
|
| 174 |
+
return_tensors="pt",
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
prompt_embeds = self.t5_encoder(
|
| 178 |
+
input_ids=batch_encoding["input_ids"].to(self.device),
|
| 179 |
+
attention_mask=None,
|
| 180 |
+
output_hidden_states=False,
|
| 181 |
+
)['last_hidden_state']
|
| 182 |
+
prompt_attention_mask = batch_encoding['attention_mask'].to(self.device)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
new_text = [x.split('.')[0] for x in text]
|
| 186 |
+
batch_encoding = self.t5_tokenizer(
|
| 187 |
+
new_text,
|
| 188 |
+
truncation=True,
|
| 189 |
+
max_length=self.num_shared,
|
| 190 |
+
return_length=False,
|
| 191 |
+
return_overflowing_tokens=False,
|
| 192 |
+
padding="max_length",
|
| 193 |
+
return_tensors="pt",
|
| 194 |
+
)
|
| 195 |
+
shared_prompt_embeds = self.t5_encoder(
|
| 196 |
+
input_ids=batch_encoding["input_ids"].to(self.device),
|
| 197 |
+
attention_mask=None,
|
| 198 |
+
output_hidden_states=False,
|
| 199 |
+
)['last_hidden_state']
|
| 200 |
+
|
| 201 |
+
return prompt_embeds, shared_prompt_embeds, prompt_attention_mask
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
import random
|
| 207 |
+
|
| 208 |
+
from torch.utils.checkpoint import checkpoint
|
| 209 |
+
from peft import LoraConfig, set_peft_model_state_dict
|
| 210 |
+
class LoraT5EmbedderNoGradientCheck(torch.nn.Module):
|
| 211 |
+
def __init__(self, device, rank=64, max_length=300):
|
| 212 |
+
super().__init__()
|
| 213 |
+
self.device = device
|
| 214 |
+
self.max_length = max_length
|
| 215 |
+
dtype = torch.bfloat16
|
| 216 |
+
self.dtype = dtype
|
| 217 |
+
t5_version = './t5-v1_1-xxl'
|
| 218 |
+
self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
|
| 219 |
+
self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device).to(dtype)
|
| 220 |
+
self.t5_encoder.gradient_checkpointing_enable()
|
| 221 |
+
self.t5_encoder.config.gradient_checkpointing = True
|
| 222 |
+
self.t5_encoder.requires_grad_(False)
|
| 223 |
+
self.t5_encoder.eval()
|
| 224 |
+
# Add LoRA adapters to the T5 model
|
| 225 |
+
text_lora_config = LoraConfig(
|
| 226 |
+
r=rank,
|
| 227 |
+
lora_alpha=rank,
|
| 228 |
+
lora_dropout=0.0,
|
| 229 |
+
init_lora_weights="gaussian",
|
| 230 |
+
target_modules=["SelfAttention.q", "SelfAttention.k", "SelfAttention.v", "SelfAttention.o", "DenseReluDense.wi", "DenseReluDense.wo"],
|
| 231 |
+
)
|
| 232 |
+
self.t5_encoder.add_adapter(text_lora_config)
|
| 233 |
+
#self.t5_encoder.encoder.embed_tokens.weight.requires_grad = True
|
| 234 |
+
print(f"Gradient checkpointing enabled: {self.t5_encoder.is_gradient_checkpointing}")
|
| 235 |
+
|
| 236 |
+
image_encoder_path = 'openai/clip-vit-large-patch14'
|
| 237 |
+
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(device=device).to(torch.bfloat16)
|
| 238 |
+
self.image_encoder = self.image_encoder.eval().requires_grad_(False)
|
| 239 |
+
|
| 240 |
+
def compute_perturbation_loss(self, prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding):
|
| 241 |
+
"""
|
| 242 |
+
Compute group lasso for non-pad non-change tokens, L1 for change tokens,
|
| 243 |
+
and group sparsity for pad non-change tokens.
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
prompt_embeds: Original embeddings [batch_size, seq_len, hidden_dim]
|
| 247 |
+
perturbed_prompt_embeds: Perturbed embeddings [batch_size, seq_len, hidden_dim]
|
| 248 |
+
replaced_ids: List of replaced token indices for each sample in batch
|
| 249 |
+
batch_encoding: The tokenizer output containing input_ids
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
l2_loss: Group lasso loss for non-pad non-change tokens (scalar tensor)
|
| 253 |
+
l1_loss: L1 loss for change tokens (scalar tensor)
|
| 254 |
+
pad_group_loss: Group sparsity loss for pad non-change tokens (scalar tensor)
|
| 255 |
+
"""
|
| 256 |
+
batch_size = prompt_embeds.size(0)
|
| 257 |
+
pad_token_id = self.t5_tokenizer.pad_token_id
|
| 258 |
+
input_ids = batch_encoding["input_ids"]
|
| 259 |
+
|
| 260 |
+
l2_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
|
| 261 |
+
l1_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
|
| 262 |
+
pad_group_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
|
| 263 |
+
|
| 264 |
+
# Track valid samples for each loss type separately
|
| 265 |
+
l1_valid_samples = 0
|
| 266 |
+
l2_valid_samples = 0
|
| 267 |
+
pad_valid_samples = 0
|
| 268 |
+
|
| 269 |
+
for i in range(batch_size):
|
| 270 |
+
# Get the replaced index for this sample
|
| 271 |
+
replaced_idx = replaced_ids[i]
|
| 272 |
+
|
| 273 |
+
if replaced_idx is None:
|
| 274 |
+
# No replacement happened (all padding), skip
|
| 275 |
+
continue
|
| 276 |
+
|
| 277 |
+
# Find padding and non-padding token indices
|
| 278 |
+
pad_mask = input_ids[i] == pad_token_id
|
| 279 |
+
non_pad_mask = ~pad_mask
|
| 280 |
+
|
| 281 |
+
pad_indices = torch.where(pad_mask)[0]
|
| 282 |
+
non_pad_indices = torch.where(non_pad_mask)[0]
|
| 283 |
+
|
| 284 |
+
# Filter out the replaced index from non-padding indices (non-pad non-change)
|
| 285 |
+
non_selected_non_pad_indices = non_pad_indices[non_pad_indices != replaced_idx]
|
| 286 |
+
|
| 287 |
+
# Compute L1 loss on selected (replaced) index - CHANGE TOKEN
|
| 288 |
+
selected_diff = prompt_embeds[i, replaced_idx] - perturbed_prompt_embeds[i, replaced_idx]
|
| 289 |
+
l1_loss_total = l1_loss_total + torch.abs(selected_diff).mean()
|
| 290 |
+
l1_valid_samples += 1
|
| 291 |
+
|
| 292 |
+
# Compute group lasso (L2) loss on NON-PAD NON-CHANGE tokens
|
| 293 |
+
if len(non_selected_non_pad_indices) > 0:
|
| 294 |
+
non_selected_diff = prompt_embeds[i, non_selected_non_pad_indices] - perturbed_prompt_embeds[
|
| 295 |
+
i, non_selected_non_pad_indices]
|
| 296 |
+
l2_per_token = torch.sqrt((non_selected_diff ** 2).sum(dim=1))
|
| 297 |
+
l2_loss_total = l2_loss_total + l2_per_token.mean()
|
| 298 |
+
l2_valid_samples += 1
|
| 299 |
+
|
| 300 |
+
# Compute group sparsity loss on PAD NON-CHANGE tokens
|
| 301 |
+
if len(pad_indices) > 0:
|
| 302 |
+
pad_diff = prompt_embeds[i, pad_indices] - perturbed_prompt_embeds[i, pad_indices]
|
| 303 |
+
# Group sparsity: L2 norm per token (encourages entire token embeddings to be zero)
|
| 304 |
+
pad_group_per_token = torch.sqrt((pad_diff ** 2).sum(dim=1))
|
| 305 |
+
pad_group_loss_total = pad_group_loss_total + pad_group_per_token.mean()
|
| 306 |
+
pad_valid_samples += 1
|
| 307 |
+
|
| 308 |
+
# Average over valid samples for each loss type
|
| 309 |
+
l2_loss = l2_loss_total / l2_valid_samples if l2_valid_samples > 0 else torch.tensor(0.0,
|
| 310 |
+
device=prompt_embeds.device)
|
| 311 |
+
l1_loss = l1_loss_total / l1_valid_samples if l1_valid_samples > 0 else torch.tensor(0.0,
|
| 312 |
+
device=prompt_embeds.device)
|
| 313 |
+
pad_group_loss = pad_group_loss_total / pad_valid_samples if pad_valid_samples > 0 else torch.tensor(0.0,
|
| 314 |
+
device=prompt_embeds.device)
|
| 315 |
+
|
| 316 |
+
return l2_loss, l1_loss, pad_group_loss
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def forward(self, text, image=None):
|
| 322 |
+
if isinstance(text, str):
|
| 323 |
+
text = [text]
|
| 324 |
+
batch_encoding = self.t5_tokenizer(
|
| 325 |
+
text,
|
| 326 |
+
truncation=True,
|
| 327 |
+
max_length=self.max_length,
|
| 328 |
+
return_length=False,
|
| 329 |
+
return_overflowing_tokens=False,
|
| 330 |
+
padding="max_length",
|
| 331 |
+
return_tensors="pt",
|
| 332 |
+
)
|
| 333 |
+
prompt_embeds = self.t5_encoder(
|
| 334 |
+
input_ids=batch_encoding["input_ids"].to(self.device),
|
| 335 |
+
attention_mask=None,
|
| 336 |
+
output_hidden_states=False,
|
| 337 |
+
)['last_hidden_state']
|
| 338 |
+
|
| 339 |
+
# Get input_ids and create a copy to modify
|
| 340 |
+
input_ids = batch_encoding["input_ids"].clone()
|
| 341 |
+
batch_size = input_ids.size(0)
|
| 342 |
+
|
| 343 |
+
# Get the padding token id
|
| 344 |
+
pad_token_id = self.t5_tokenizer.pad_token_id
|
| 345 |
+
|
| 346 |
+
replaced_ids = []
|
| 347 |
+
# For each sample in the batch
|
| 348 |
+
for i in range(batch_size):
|
| 349 |
+
# Find indices of non-padding tokens
|
| 350 |
+
non_pad_mask = input_ids[i] != pad_token_id
|
| 351 |
+
non_pad_indices = torch.where(non_pad_mask)[0]
|
| 352 |
+
|
| 353 |
+
# If there are meaningful tokens, randomly select one to replace
|
| 354 |
+
if len(non_pad_indices) > 0:
|
| 355 |
+
# Randomly select an index from non-padding tokens
|
| 356 |
+
random_idx = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
|
| 357 |
+
# Replace with padding token
|
| 358 |
+
input_ids[i, random_idx] = pad_token_id
|
| 359 |
+
replaced_ids.append(random_idx.item())
|
| 360 |
+
else:
|
| 361 |
+
replaced_ids.append(None) # No replacement if all tokens are padding
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
perturbed_prompt_embeds = self.t5_encoder(
|
| 365 |
+
input_ids=input_ids.to(self.device),
|
| 366 |
+
attention_mask=None,
|
| 367 |
+
output_hidden_states=False,
|
| 368 |
+
)['last_hidden_state']
|
| 369 |
+
|
| 370 |
+
l2_loss, l1_loss, pad_loss = self.compute_perturbation_loss(
|
| 371 |
+
prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
with torch.no_grad():
|
| 375 |
+
if image is not None:
|
| 376 |
+
clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds
|
| 377 |
+
else:
|
| 378 |
+
clip_image_embeds = None
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
return prompt_embeds, l2_loss, l1_loss, pad_loss,clip_image_embeds
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
from peft import LoraConfig, set_peft_model_state_dict
|
| 385 |
+
import torch.utils.checkpoint as checkpoint
|
| 386 |
+
from transformers import CLIPVisionModelWithProjection
|
| 387 |
+
|
| 388 |
+
class LoraT5Embedder(torch.nn.Module):
|
| 389 |
+
def __init__(self, device, rank=128, max_length=300, use_gradient_checkpointing=True):
|
| 390 |
+
super().__init__()
|
| 391 |
+
self.device = device
|
| 392 |
+
self.max_length = max_length
|
| 393 |
+
self.use_gradient_checkpointing = use_gradient_checkpointing
|
| 394 |
+
dtype = torch.bfloat16
|
| 395 |
+
self.dtype = dtype
|
| 396 |
+
t5_version = './t5-v1_1-xxl'
|
| 397 |
+
self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
|
| 398 |
+
|
| 399 |
+
self.t5_encoder = T5EncoderModel.from_pretrained(
|
| 400 |
+
t5_version,
|
| 401 |
+
torch_dtype=dtype
|
| 402 |
+
).to(device=device).to(dtype)
|
| 403 |
+
|
| 404 |
+
self.t5_encoder.requires_grad_(False)
|
| 405 |
+
|
| 406 |
+
# Add LoRA adapters to the T5 model
|
| 407 |
+
text_lora_config = LoraConfig(
|
| 408 |
+
r=rank,
|
| 409 |
+
lora_alpha=rank,
|
| 410 |
+
lora_dropout=0.0,
|
| 411 |
+
init_lora_weights="gaussian",
|
| 412 |
+
target_modules=["q", "k", "v", "o", "wi", "wo"],
|
| 413 |
+
)
|
| 414 |
+
self.t5_encoder.add_adapter(text_lora_config)
|
| 415 |
+
self.t5_encoder.encoder.embed_tokens.weight.requires_grad_(True)
|
| 416 |
+
|
| 417 |
+
# Manually implement gradient checkpointing for T5 encoder blocks
|
| 418 |
+
if self.use_gradient_checkpointing:
|
| 419 |
+
self._enable_gradient_checkpointing()
|
| 420 |
+
|
| 421 |
+
print(f"Gradient checkpointing enabled: {self.use_gradient_checkpointing}")
|
| 422 |
+
|
| 423 |
+
image_encoder_path = './clip-vit-large-patch14'
|
| 424 |
+
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
| 425 |
+
image_encoder_path
|
| 426 |
+
).to(device=device).to(torch.bfloat16)
|
| 427 |
+
self.image_encoder = self.image_encoder.eval().requires_grad_(False)
|
| 428 |
+
|
| 429 |
+
def _enable_gradient_checkpointing(self):
|
| 430 |
+
"""
|
| 431 |
+
Manually wrap T5 encoder blocks with gradient checkpointing.
|
| 432 |
+
"""
|
| 433 |
+
|
| 434 |
+
def create_custom_forward(module):
|
| 435 |
+
def custom_forward(*inputs):
|
| 436 |
+
return module(*inputs)
|
| 437 |
+
|
| 438 |
+
return custom_forward
|
| 439 |
+
|
| 440 |
+
# Wrap each T5 block with checkpointing
|
| 441 |
+
for block in self.t5_encoder.encoder.block:
|
| 442 |
+
# Store original forward
|
| 443 |
+
block._original_forward = block.forward
|
| 444 |
+
|
| 445 |
+
# Create checkpointed forward
|
| 446 |
+
def make_checkpointed_forward(blk):
|
| 447 |
+
def checkpointed_forward(*args, **kwargs):
|
| 448 |
+
# Checkpoint requires a function that takes tensors as input
|
| 449 |
+
def forward_wrapper(*inputs):
|
| 450 |
+
# Reconstruct kwargs from inputs
|
| 451 |
+
hidden_states = inputs[0]
|
| 452 |
+
attention_mask = inputs[1] if len(inputs) > 1 else None
|
| 453 |
+
position_bias = inputs[2] if len(inputs) > 2 else None
|
| 454 |
+
|
| 455 |
+
return blk._original_forward(
|
| 456 |
+
hidden_states=hidden_states,
|
| 457 |
+
attention_mask=attention_mask,
|
| 458 |
+
position_bias=position_bias,
|
| 459 |
+
**{k: v for k, v in kwargs.items() if
|
| 460 |
+
k not in ['hidden_states', 'attention_mask', 'position_bias']}
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
# Prepare inputs for checkpointing
|
| 464 |
+
hidden_states = kwargs.get('hidden_states', args[0] if args else None)
|
| 465 |
+
attention_mask = kwargs.get('attention_mask', args[1] if len(args) > 1 else None)
|
| 466 |
+
position_bias = kwargs.get('position_bias', args[2] if len(args) > 2 else None)
|
| 467 |
+
|
| 468 |
+
# Use checkpoint
|
| 469 |
+
checkpoint_inputs = [hidden_states]
|
| 470 |
+
if attention_mask is not None:
|
| 471 |
+
checkpoint_inputs.append(attention_mask)
|
| 472 |
+
if position_bias is not None:
|
| 473 |
+
checkpoint_inputs.append(position_bias)
|
| 474 |
+
|
| 475 |
+
return checkpoint.checkpoint(
|
| 476 |
+
forward_wrapper,
|
| 477 |
+
*checkpoint_inputs,
|
| 478 |
+
use_reentrant=False
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
return checkpointed_forward
|
| 482 |
+
|
| 483 |
+
block.forward = make_checkpointed_forward(block)
|
| 484 |
+
|
| 485 |
+
def _encode_text(self, input_ids):
|
| 486 |
+
"""Helper function to encode text through T5."""
|
| 487 |
+
return self.t5_encoder(
|
| 488 |
+
input_ids=input_ids.to(self.device),
|
| 489 |
+
attention_mask=None,
|
| 490 |
+
output_hidden_states=False,
|
| 491 |
+
)['last_hidden_state']
|
| 492 |
+
|
| 493 |
+
def compute_perturbation_loss(self, prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding):
|
| 494 |
+
"""
|
| 495 |
+
Compute group lasso for non-pad non-change tokens, L1 for change tokens,
|
| 496 |
+
and group sparsity for pad non-change tokens.
|
| 497 |
+
|
| 498 |
+
Args:
|
| 499 |
+
prompt_embeds: Original embeddings [batch_size, seq_len, hidden_dim]
|
| 500 |
+
perturbed_prompt_embeds: Perturbed embeddings [batch_size, seq_len, hidden_dim]
|
| 501 |
+
replaced_ids: List of replaced token indices for each sample in batch
|
| 502 |
+
batch_encoding: The tokenizer output containing input_ids
|
| 503 |
+
|
| 504 |
+
Returns:
|
| 505 |
+
l2_loss: Group lasso loss for non-pad non-change tokens (scalar tensor)
|
| 506 |
+
l1_loss: L1 loss for change tokens (scalar tensor)
|
| 507 |
+
pad_group_loss: Group sparsity loss for pad non-change tokens (scalar tensor)
|
| 508 |
+
"""
|
| 509 |
+
batch_size = prompt_embeds.size(0)
|
| 510 |
+
pad_token_id = self.t5_tokenizer.pad_token_id
|
| 511 |
+
input_ids = batch_encoding["input_ids"]
|
| 512 |
+
|
| 513 |
+
l2_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
|
| 514 |
+
l1_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
|
| 515 |
+
pad_group_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
|
| 516 |
+
|
| 517 |
+
# Track valid samples for each loss type separately
|
| 518 |
+
l1_valid_samples = 0
|
| 519 |
+
l2_valid_samples = 0
|
| 520 |
+
pad_valid_samples = 0
|
| 521 |
+
|
| 522 |
+
for i in range(batch_size):
|
| 523 |
+
# Get the replaced index for this sample
|
| 524 |
+
replaced_idx = replaced_ids[i]
|
| 525 |
+
|
| 526 |
+
if replaced_idx is None:
|
| 527 |
+
# No replacement happened (all padding), skip
|
| 528 |
+
continue
|
| 529 |
+
|
| 530 |
+
# Find padding and non-padding token indices
|
| 531 |
+
pad_mask = input_ids[i] == pad_token_id
|
| 532 |
+
non_pad_mask = ~pad_mask
|
| 533 |
+
|
| 534 |
+
pad_indices = torch.where(pad_mask)[0]
|
| 535 |
+
non_pad_indices = torch.where(non_pad_mask)[0]
|
| 536 |
+
|
| 537 |
+
# Filter out the replaced index from non-padding indices (non-pad non-change)
|
| 538 |
+
non_selected_non_pad_indices = non_pad_indices[non_pad_indices != replaced_idx]
|
| 539 |
+
|
| 540 |
+
# Compute L1 loss on selected (replaced) index - CHANGE TOKEN
|
| 541 |
+
selected_diff = prompt_embeds[i, replaced_idx] - perturbed_prompt_embeds[i, replaced_idx]
|
| 542 |
+
l1_loss_total = l1_loss_total + torch.abs(selected_diff).mean()
|
| 543 |
+
l1_valid_samples += 1
|
| 544 |
+
|
| 545 |
+
# Compute group lasso (L2) loss on NON-PAD NON-CHANGE tokens
|
| 546 |
+
if len(non_selected_non_pad_indices) > 0:
|
| 547 |
+
non_selected_diff = prompt_embeds[i, non_selected_non_pad_indices] - perturbed_prompt_embeds[
|
| 548 |
+
i, non_selected_non_pad_indices]
|
| 549 |
+
l2_per_token = torch.sqrt((non_selected_diff ** 2).sum(dim=1))
|
| 550 |
+
l2_loss_total = l2_loss_total + l2_per_token.mean()
|
| 551 |
+
l2_valid_samples += 1
|
| 552 |
+
|
| 553 |
+
# Compute group sparsity loss on PAD NON-CHANGE tokens
|
| 554 |
+
if len(pad_indices) > 0:
|
| 555 |
+
pad_diff = prompt_embeds[i, pad_indices] - perturbed_prompt_embeds[i, pad_indices]
|
| 556 |
+
# Group sparsity: L2 norm per token (encourages entire token embeddings to be zero)
|
| 557 |
+
pad_group_per_token = torch.sqrt((pad_diff ** 2).sum(dim=1))
|
| 558 |
+
pad_group_loss_total = pad_group_loss_total + pad_group_per_token.mean()
|
| 559 |
+
pad_valid_samples += 1
|
| 560 |
+
|
| 561 |
+
# Average over valid samples for each loss type
|
| 562 |
+
l2_loss = l2_loss_total / l2_valid_samples if l2_valid_samples > 0 else torch.tensor(0.0,
|
| 563 |
+
device=prompt_embeds.device)
|
| 564 |
+
l1_loss = l1_loss_total / l1_valid_samples if l1_valid_samples > 0 else torch.tensor(0.0,
|
| 565 |
+
device=prompt_embeds.device)
|
| 566 |
+
pad_group_loss = pad_group_loss_total / pad_valid_samples if pad_valid_samples > 0 else torch.tensor(0.0,
|
| 567 |
+
device=prompt_embeds.device)
|
| 568 |
+
|
| 569 |
+
return l2_loss, l1_loss, pad_group_loss
|
| 570 |
+
|
| 571 |
+
def forward(self, text, image=None):
|
| 572 |
+
if isinstance(text, str):
|
| 573 |
+
text = [text]
|
| 574 |
+
batch_encoding = self.t5_tokenizer(
|
| 575 |
+
text,
|
| 576 |
+
truncation=True,
|
| 577 |
+
max_length=self.max_length,
|
| 578 |
+
return_length=False,
|
| 579 |
+
return_overflowing_tokens=False,
|
| 580 |
+
padding="max_length",
|
| 581 |
+
return_tensors="pt",
|
| 582 |
+
)
|
| 583 |
+
attn_mask = batch_encoding["attention_mask"].to(self.device)
|
| 584 |
+
|
| 585 |
+
# First encoding
|
| 586 |
+
prompt_embeds = self._encode_text(batch_encoding["input_ids"])
|
| 587 |
+
|
| 588 |
+
# Get input_ids and create a copy to modify
|
| 589 |
+
input_ids = batch_encoding["input_ids"].clone()
|
| 590 |
+
batch_size = input_ids.size(0)
|
| 591 |
+
|
| 592 |
+
# Get the padding token id
|
| 593 |
+
# get the id for the first sentinel token
|
| 594 |
+
mask_token = "<extra_id_0>"
|
| 595 |
+
mask_token_id = self.t5_tokenizer.convert_tokens_to_ids(mask_token)
|
| 596 |
+
pad_token_id = self.t5_tokenizer.pad_token_id
|
| 597 |
+
|
| 598 |
+
replaced_ids = []
|
| 599 |
+
# For each sample in the batch
|
| 600 |
+
for i in range(batch_size):
|
| 601 |
+
# Find indices of non-padding tokens
|
| 602 |
+
non_pad_mask = input_ids[i] != pad_token_id
|
| 603 |
+
non_pad_indices = torch.where(non_pad_mask)[0]
|
| 604 |
+
|
| 605 |
+
# If there are meaningful tokens, randomly select one to replace
|
| 606 |
+
if len(non_pad_indices) > 0:
|
| 607 |
+
# Randomly select an index from non-padding tokens
|
| 608 |
+
random_idx = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
|
| 609 |
+
random_idx2 = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
|
| 610 |
+
# Replace with padding token
|
| 611 |
+
input_ids[i, random_idx] = mask_token_id
|
| 612 |
+
replaced_ids.append(random_idx.item())
|
| 613 |
+
else:
|
| 614 |
+
replaced_ids.append(None) # No replacement if all tokens are padding
|
| 615 |
+
|
| 616 |
+
# Second encoding with perturbed input
|
| 617 |
+
perturbed_prompt_embeds = self._encode_text(input_ids)
|
| 618 |
+
|
| 619 |
+
"""
|
| 620 |
+
l2_loss, l1_loss, pad_loss = self.compute_perturbation_loss(
|
| 621 |
+
prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding
|
| 622 |
+
)
|
| 623 |
+
"""
|
| 624 |
+
|
| 625 |
+
with torch.no_grad():
|
| 626 |
+
if image is not None:
|
| 627 |
+
clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds
|
| 628 |
+
else:
|
| 629 |
+
clip_image_embeds = None
|
| 630 |
+
|
| 631 |
+
#return prompt_embeds, l2_loss, l1_loss, pad_loss, clip_image_embeds, attn_mask
|
| 632 |
+
return prompt_embeds, clip_image_embeds, perturbed_prompt_embeds, replaced_ids, self.t5_tokenizer, batch_encoding
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
| 639 |
+
class QwenEmbedder(nn.Module):
|
| 640 |
+
def __init__(self, device, max_length=512):
|
| 641 |
+
super().__init__()
|
| 642 |
+
self.device = device
|
| 643 |
+
self.max_length = max_length
|
| 644 |
+
dtype = torch.bfloat16
|
| 645 |
+
self.dtype = dtype
|
| 646 |
+
self.tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True)
|
| 647 |
+
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 648 |
+
"Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=dtype,
|
| 649 |
+
).to(device=device)
|
| 650 |
+
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
| 651 |
+
self.prompt_template_encode_start_idx = 34
|
| 652 |
+
self.tokenizer_max_length = max_length
|
| 653 |
+
|
| 654 |
+
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
| 655 |
+
bool_mask = mask.bool()
|
| 656 |
+
valid_lengths = bool_mask.sum(dim=1)
|
| 657 |
+
selected = hidden_states[bool_mask]
|
| 658 |
+
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
| 659 |
+
|
| 660 |
+
return split_result
|
| 661 |
+
|
| 662 |
+
def _get_qwen_prompt_embeds(
|
| 663 |
+
self,
|
| 664 |
+
prompt = None,
|
| 665 |
+
device = None,
|
| 666 |
+
dtype = None,
|
| 667 |
+
):
|
| 668 |
+
device = device or self._execution_device
|
| 669 |
+
dtype = dtype or self.text_encoder.dtype
|
| 670 |
+
|
| 671 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 672 |
+
|
| 673 |
+
template = self.prompt_template_encode
|
| 674 |
+
drop_idx = self.prompt_template_encode_start_idx
|
| 675 |
+
txt = [template.format(e) for e in prompt]
|
| 676 |
+
txt_tokens = self.tokenizer(
|
| 677 |
+
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
|
| 678 |
+
).to(device)
|
| 679 |
+
encoder_hidden_states = self.text_encoder(
|
| 680 |
+
input_ids=txt_tokens.input_ids,
|
| 681 |
+
attention_mask=txt_tokens.attention_mask,
|
| 682 |
+
output_hidden_states=True,
|
| 683 |
+
)
|
| 684 |
+
hidden_states = encoder_hidden_states.hidden_states[-1]
|
| 685 |
+
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
| 686 |
+
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
| 687 |
+
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
| 688 |
+
#max_seq_len = max([e.size(0) for e in split_hidden_states])
|
| 689 |
+
max_seq_len = self.max_length
|
| 690 |
+
prompt_embeds = torch.stack(
|
| 691 |
+
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
| 692 |
+
)
|
| 693 |
+
encoder_attention_mask = torch.stack(
|
| 694 |
+
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 698 |
+
|
| 699 |
+
return prompt_embeds, encoder_attention_mask
|
| 700 |
+
|
| 701 |
+
@torch.no_grad()
|
| 702 |
+
def forward(self, text):
|
| 703 |
+
prompt_embeds, attention_mask = self._get_qwen_prompt_embeds(
|
| 704 |
+
prompt=text,
|
| 705 |
+
device=self.device,
|
| 706 |
+
dtype=self.dtype,
|
| 707 |
+
)
|
| 708 |
+
return prompt_embeds, attention_mask
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
# pip install accelerate
|
| 713 |
+
|
| 714 |
+
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
| 715 |
+
from PIL import Image
|
| 716 |
+
import requests
|
| 717 |
+
import torch
|
| 718 |
+
import torch.nn as nn
|
| 719 |
+
Qwen25VL_7b_PREFIX_edit = '''Given an user editing prompt and an source image, only describe the editing area and how they should change in a detailed way.
|
| 720 |
+
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
|
| 721 |
+
'''
|
| 722 |
+
|
| 723 |
+
Qwen25VL_7b_PREFIX_t2i = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
|
| 724 |
+
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
|
| 725 |
+
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
|
| 726 |
+
Here are examples of how to transform or refine prompts:
|
| 727 |
+
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
|
| 728 |
+
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
|
| 729 |
+
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
|
| 730 |
+
User Prompt:'''
|
| 731 |
+
Qwen25VL_7b_PREFIX_image = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
|
| 732 |
+
model_id = "google/gemma-3-4b-it"
|
| 733 |
+
|
| 734 |
+
from transformers import AutoTokenizer, TrainingArguments, Gemma3ForCausalLM, AutoModel, Gemma3Model
|
| 735 |
+
from transformers import Dinov2Model, AutoImageProcessor
|
| 736 |
+
|
| 737 |
+
import torch
|
| 738 |
+
import torchvision.transforms as transforms
|
| 739 |
+
import torchvision.models as models
|
| 740 |
+
from PIL import Image
|
| 741 |
+
import numpy as np
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
class GemmaEmbedder(nn.Module):
|
| 746 |
+
def __init__(self, max_sequence_length=300, model_id='google/gemma-3-4b-it'):
|
| 747 |
+
super().__init__()
|
| 748 |
+
device = torch.cuda.current_device()
|
| 749 |
+
self.model = Gemma3Model.from_pretrained(model_id).to(device).to(torch.bfloat16)
|
| 750 |
+
#self.model = Gemma3ForConditionalGeneration.from_pretrained(model_id).to(device).to(torch.bfloat16)
|
| 751 |
+
self.processor = AutoProcessor.from_pretrained(model_id)
|
| 752 |
+
self.device = device
|
| 753 |
+
self.max_sequence_length = max_sequence_length
|
| 754 |
+
#self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token # Use eos token as pad token
|
| 755 |
+
self.processor.tokenizer.padding_side = "right"
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def get_features(self, hidden_states, input_ids):
|
| 760 |
+
hidden_states = hidden_states[0]
|
| 761 |
+
input_ids = input_ids[0].tolist()
|
| 762 |
+
|
| 763 |
+
pad_text_embeds = torch.zeros([self.max_sequence_length, 2560], dtype=torch.bfloat16, device=self.device)
|
| 764 |
+
pad_text_mask = torch.zeros([self.max_sequence_length], device=self.device)
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
def find_last(lst, value):
|
| 768 |
+
indices = [i for i, x in enumerate(lst) if x == value]
|
| 769 |
+
return indices[-1]
|
| 770 |
+
|
| 771 |
+
if 256000 in input_ids:
|
| 772 |
+
text_start = input_ids.index(256000)+2
|
| 773 |
+
else:
|
| 774 |
+
text_start = find_last(input_ids, 108)+1
|
| 775 |
+
|
| 776 |
+
text_end = len(input_ids)-6
|
| 777 |
+
bos_embed = hidden_states[:2]
|
| 778 |
+
text_embeds = hidden_states[text_start:text_end + 1]
|
| 779 |
+
text_embeds = torch.cat([bos_embed, text_embeds], dim=0)
|
| 780 |
+
|
| 781 |
+
pad_text_embeds[:len(text_embeds), :] = text_embeds[:self.max_sequence_length]
|
| 782 |
+
pad_text_mask[:len(text_embeds)] = 1.0
|
| 783 |
+
|
| 784 |
+
image_embeds = hidden_states[np.array(input_ids) == self.processor.tokenizer.image_token_id]
|
| 785 |
+
|
| 786 |
+
"""
|
| 787 |
+
print(input_ids)
|
| 788 |
+
print(input_ids[text_start:text_end + 1])
|
| 789 |
+
decoded = self.processor.decode(input_ids[text_start:text_end+1], skip_special_tokens=False)
|
| 790 |
+
print("Decoded text:", decoded, text_start, text_end, input_ids[text_start:text_end + 1], input_ids[1:2])
|
| 791 |
+
print("Text embeddings shape:", text_embeds.shape)
|
| 792 |
+
norm = RMSNorm(2560, eps=1e-6).to(self.device).to(torch.bfloat16)
|
| 793 |
+
print(text_embeds, ' >>> ext embeds')
|
| 794 |
+
print(norm(text_embeds), ' >>> normed embeds')
|
| 795 |
+
"""
|
| 796 |
+
|
| 797 |
+
return image_embeds, pad_text_embeds, pad_text_mask
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
@torch.no_grad()
|
| 802 |
+
def forward(self, caps, images=None):
|
| 803 |
+
text_embeds = []
|
| 804 |
+
text_masks = []
|
| 805 |
+
full_image_embeds = []
|
| 806 |
+
device = self.model.device
|
| 807 |
+
if images is None:
|
| 808 |
+
images = [None] * len(caps)
|
| 809 |
+
for cap,img in zip(caps, images):
|
| 810 |
+
if img is not None:
|
| 811 |
+
messages = [
|
| 812 |
+
{
|
| 813 |
+
"role": "system",
|
| 814 |
+
"content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_edit}]
|
| 815 |
+
},
|
| 816 |
+
{
|
| 817 |
+
"role": "user",
|
| 818 |
+
"content": [
|
| 819 |
+
{"type": "image", "image": img},
|
| 820 |
+
{"type": "text", "text": cap},
|
| 821 |
+
]
|
| 822 |
+
}
|
| 823 |
+
]
|
| 824 |
+
else:
|
| 825 |
+
messages = [
|
| 826 |
+
{
|
| 827 |
+
"role": "system",
|
| 828 |
+
"content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_t2i}]
|
| 829 |
+
},
|
| 830 |
+
{
|
| 831 |
+
"role": "user",
|
| 832 |
+
"content": [
|
| 833 |
+
{"type": "text", "text": cap},
|
| 834 |
+
]
|
| 835 |
+
}
|
| 836 |
+
]
|
| 837 |
+
|
| 838 |
+
inputs = self.processor.apply_chat_template(
|
| 839 |
+
messages, add_generation_prompt=True, tokenize=True,
|
| 840 |
+
return_dict=True, return_tensors="pt",
|
| 841 |
+
max_length = 640,
|
| 842 |
+
truncation = True,
|
| 843 |
+
).to(self.model.device, dtype=torch.bfloat16)
|
| 844 |
+
outputs = self.model(**inputs, output_hidden_states=True)
|
| 845 |
+
#sample_image_embeds = outputs.image_hidden_states
|
| 846 |
+
sample_text_embeds, sample_text_mask, sample_image_embeds = [], [], []
|
| 847 |
+
for hidden in [outputs.hidden_states[-1]]:
|
| 848 |
+
cur_image_embeds, cur_text_embeds, cur_text_mask = self.get_features(hidden, inputs["input_ids"])
|
| 849 |
+
sample_text_embeds.append(cur_text_embeds)
|
| 850 |
+
sample_text_mask.append(cur_text_mask)
|
| 851 |
+
sample_image_embeds.append(cur_image_embeds)
|
| 852 |
+
text_embeds.append(torch.cat(sample_text_embeds, dim=0))
|
| 853 |
+
text_masks.append(torch.cat(sample_text_mask, dim=0))
|
| 854 |
+
#full_image_embeds.append(sample_image_embeds)
|
| 855 |
+
full_image_embeds.append(torch.cat(sample_image_embeds, dim=0))
|
| 856 |
+
|
| 857 |
+
"""
|
| 858 |
+
input_len = inputs["input_ids"].shape[-1]
|
| 859 |
+
with torch.inference_mode():
|
| 860 |
+
generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
| 861 |
+
generation = generation[0][input_len:]
|
| 862 |
+
|
| 863 |
+
decoded = self.processor.decode(generation, skip_special_tokens=True)
|
| 864 |
+
print(cap, ' <>>> gemma ',decoded)
|
| 865 |
+
"""
|
| 866 |
+
|
| 867 |
+
text_embeds = torch.stack(text_embeds, dim=0)
|
| 868 |
+
text_masks = torch.stack(text_masks, dim=0)
|
| 869 |
+
full_image_embeds = torch.stack(full_image_embeds, dim=0)
|
| 870 |
+
return {
|
| 871 |
+
'text_embeds': text_embeds,
|
| 872 |
+
'text_masks': text_masks,
|
| 873 |
+
'image_embeds': full_image_embeds,
|
| 874 |
+
}
|
| 875 |
+
|
| 876 |
+
class GemmaTextEmbedder(nn.Module):
|
| 877 |
+
def __init__(self, device, max_sequence_length=300, model_id='./gemma-3-4b-it'):
|
| 878 |
+
super().__init__()
|
| 879 |
+
self.model = Gemma3Model.from_pretrained(model_id).to(device).to(torch.bfloat16)
|
| 880 |
+
#self.model = Gemma3ForConditionalGeneration.from_pretrained(model_id).to(device).to(torch.bfloat16)
|
| 881 |
+
self.processor = AutoProcessor.from_pretrained(model_id)
|
| 882 |
+
self.real_device = device
|
| 883 |
+
self.max_sequence_length = max_sequence_length
|
| 884 |
+
#self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token # Use eos token as pad token
|
| 885 |
+
self.processor.tokenizer.padding_side = "right"
|
| 886 |
+
|
| 887 |
+
@property
|
| 888 |
+
def dtype(self):
|
| 889 |
+
"""Return the dtype of the model parameters."""
|
| 890 |
+
return next(self.parameters()).dtype
|
| 891 |
+
|
| 892 |
+
@property
|
| 893 |
+
def device(self):
|
| 894 |
+
"""Return the device of the model parameters."""
|
| 895 |
+
return next(self.parameters()).device
|
| 896 |
+
|
| 897 |
+
def get_features(self, hidden_states, input_ids):
|
| 898 |
+
hidden_states = hidden_states[0]
|
| 899 |
+
input_ids = input_ids[0].tolist()
|
| 900 |
+
|
| 901 |
+
pad_text_embeds = torch.zeros([self.max_sequence_length, 2560], dtype=torch.bfloat16, device=self.device)
|
| 902 |
+
pad_text_mask = torch.zeros([self.max_sequence_length], device=self.device)
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
def find_last(lst, value):
|
| 906 |
+
indices = [i for i, x in enumerate(lst) if x == value]
|
| 907 |
+
return indices[-1]
|
| 908 |
+
|
| 909 |
+
if 256000 in input_ids:
|
| 910 |
+
text_start = input_ids.index(256000)+2
|
| 911 |
+
else:
|
| 912 |
+
text_start = find_last(input_ids, 108)+1
|
| 913 |
+
|
| 914 |
+
text_end = len(input_ids)-6
|
| 915 |
+
bos_embed = hidden_states[:2]
|
| 916 |
+
text_embeds = hidden_states[text_start:text_end + 1]
|
| 917 |
+
text_embeds = torch.cat([bos_embed, text_embeds], dim=0)
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
pad_text_embeds[:len(text_embeds), :] = text_embeds[:self.max_sequence_length]
|
| 921 |
+
pad_text_mask[:len(text_embeds)] = 1.0
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
pad_text_embeds[len(text_embeds):, :] = 0.0
|
| 925 |
+
pad_text_mask[len(text_embeds):] = 0.0
|
| 926 |
+
|
| 927 |
+
"""
|
| 928 |
+
print(input_ids)
|
| 929 |
+
print(input_ids[text_start:text_end + 1])
|
| 930 |
+
decoded = self.processor.decode(input_ids[text_start:text_end+1], skip_special_tokens=False)
|
| 931 |
+
print("Decoded text:", decoded, text_start, text_end, input_ids[text_start:text_end + 1], input_ids[1:2])
|
| 932 |
+
print("Text embeddings shape:", text_embeds.shape)
|
| 933 |
+
print(text_embeds, ' >>> ext embeds')
|
| 934 |
+
"""
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
return pad_text_embeds, pad_text_mask
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
@torch.no_grad()
|
| 942 |
+
def forward(self, caps, images=None):
|
| 943 |
+
text_embeds = []
|
| 944 |
+
text_masks = []
|
| 945 |
+
full_image_embeds = []
|
| 946 |
+
device = self.model.device
|
| 947 |
+
if isinstance(caps, str):
|
| 948 |
+
caps = [caps]
|
| 949 |
+
if images is None:
|
| 950 |
+
images = [None] * len(caps)
|
| 951 |
+
for cap,img in zip(caps, images):
|
| 952 |
+
if img is not None:
|
| 953 |
+
messages = [
|
| 954 |
+
{
|
| 955 |
+
"role": "system",
|
| 956 |
+
"content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_edit}]
|
| 957 |
+
},
|
| 958 |
+
{
|
| 959 |
+
"role": "user",
|
| 960 |
+
"content": [
|
| 961 |
+
{"type": "image", "image": img},
|
| 962 |
+
{"type": "text", "text": cap},
|
| 963 |
+
]
|
| 964 |
+
}
|
| 965 |
+
]
|
| 966 |
+
else:
|
| 967 |
+
messages = [
|
| 968 |
+
{
|
| 969 |
+
"role": "system",
|
| 970 |
+
"content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_t2i}]
|
| 971 |
+
},
|
| 972 |
+
{
|
| 973 |
+
"role": "user",
|
| 974 |
+
"content": [
|
| 975 |
+
{"type": "text", "text": cap},
|
| 976 |
+
]
|
| 977 |
+
}
|
| 978 |
+
]
|
| 979 |
+
|
| 980 |
+
inputs = self.processor.apply_chat_template(
|
| 981 |
+
messages, add_generation_prompt=True, tokenize=True,
|
| 982 |
+
return_dict=True, return_tensors="pt",
|
| 983 |
+
max_length = 640,
|
| 984 |
+
truncation = True,
|
| 985 |
+
).to(self.model.device, dtype=torch.bfloat16)
|
| 986 |
+
outputs = self.model(**inputs, output_hidden_states=True)
|
| 987 |
+
#sample_image_embeds = outputs.image_hidden_states
|
| 988 |
+
sample_text_embeds, sample_text_mask, sample_image_embeds = [], [], []
|
| 989 |
+
for hidden in [outputs.hidden_states[-1]]:
|
| 990 |
+
cur_text_embeds, cur_text_mask = self.get_features(hidden, inputs["input_ids"])
|
| 991 |
+
sample_text_embeds.append(cur_text_embeds)
|
| 992 |
+
sample_text_mask.append(cur_text_mask)
|
| 993 |
+
text_embeds.append(torch.cat(sample_text_embeds, dim=0))
|
| 994 |
+
text_masks.append(torch.cat(sample_text_mask, dim=0))
|
| 995 |
+
|
| 996 |
+
"""
|
| 997 |
+
input_len = inputs["input_ids"].shape[-1]
|
| 998 |
+
with torch.inference_mode():
|
| 999 |
+
generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
| 1000 |
+
generation = generation[0][input_len:]
|
| 1001 |
+
|
| 1002 |
+
decoded = self.processor.decode(generation, skip_special_tokens=True)
|
| 1003 |
+
print(cap, ' <>>> gemma ',decoded)
|
| 1004 |
+
"""
|
| 1005 |
+
|
| 1006 |
+
text_embeds = torch.stack(text_embeds, dim=0)
|
| 1007 |
+
text_masks = torch.stack(text_masks, dim=0)
|
| 1008 |
+
return text_embeds, text_masks.to(text_embeds.dtype)
|
| 1009 |
+
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
from transformers import AutoModel, AutoTokenizer
|
| 1013 |
+
from transformers import SiglipVisionModel, AutoProcessor
|
| 1014 |
+
class Gemma2Embedder(nn.Module):
|
| 1015 |
+
def __init__(self, max_length=300):
|
| 1016 |
+
super().__init__()
|
| 1017 |
+
self.text_encoder = AutoModel.from_pretrained(
|
| 1018 |
+
"google/gemma-2-2b",
|
| 1019 |
+
torch_dtype=torch.bfloat16,
|
| 1020 |
+
).to(torch.cuda.current_device()).to(torch.bfloat16).eval()
|
| 1021 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 1022 |
+
"google/gemma-2-2b",
|
| 1023 |
+
)
|
| 1024 |
+
self.tokenizer.padding_side = "right"
|
| 1025 |
+
self.max_length = max_length
|
| 1026 |
+
self.system_prompt = "You are an assistant designed to edit images faithfully based on user prompts. <Prompt Start> "
|
| 1027 |
+
system_ids = self.tokenizer(
|
| 1028 |
+
self.system_prompt,
|
| 1029 |
+
return_tensors="pt",
|
| 1030 |
+
add_special_tokens=True,
|
| 1031 |
+
max_length=self.max_length,
|
| 1032 |
+
padding="max_length",
|
| 1033 |
+
truncation=True,
|
| 1034 |
+
).input_ids.flatten().view(-1).numpy().tolist()
|
| 1035 |
+
self.len_system_prompt = system_ids.index(self.tokenizer.pad_token_id)-1
|
| 1036 |
+
self.weight_dtype = torch.bfloat16
|
| 1037 |
+
|
| 1038 |
+
@torch.no_grad()
|
| 1039 |
+
def forward(self, caption):
|
| 1040 |
+
if isinstance(caption, str):
|
| 1041 |
+
caption = [caption]
|
| 1042 |
+
caption = [self.system_prompt + c for c in caption]
|
| 1043 |
+
text_inputs = self.tokenizer(
|
| 1044 |
+
caption,
|
| 1045 |
+
return_tensors="pt",
|
| 1046 |
+
add_special_tokens=True,
|
| 1047 |
+
max_length=self.max_length+self.len_system_prompt,
|
| 1048 |
+
padding="max_length",
|
| 1049 |
+
truncation=True,
|
| 1050 |
+
)
|
| 1051 |
+
text_input_ids = text_inputs.input_ids
|
| 1052 |
+
attention_mask = text_inputs.attention_mask
|
| 1053 |
+
text_input_ids = text_input_ids.to(self.text_encoder.device)
|
| 1054 |
+
attention_mask = attention_mask.to(self.text_encoder.device)
|
| 1055 |
+
embeds = self.text_encoder(text_input_ids, attention_mask=attention_mask,
|
| 1056 |
+
output_hidden_states=True
|
| 1057 |
+
).hidden_states[-2]
|
| 1058 |
+
embeds = embeds[:, self.len_system_prompt:, :]
|
| 1059 |
+
attention_mask = attention_mask[:, self.len_system_prompt:]
|
| 1060 |
+
|
| 1061 |
+
|
| 1062 |
+
return {
|
| 1063 |
+
'text_embeds': embeds,
|
| 1064 |
+
'text_masks': attention_mask,
|
| 1065 |
+
}
|
| 1066 |
+
|
| 1067 |
+
|
| 1068 |
+
class T5TextEmbedder(nn.Module):
|
| 1069 |
+
def __init__(self, device, pretrained_path="google/flan-t5-xxl", max_length=300):
|
| 1070 |
+
super().__init__()
|
| 1071 |
+
self.model = T5EncoderModel.from_pretrained(pretrained_path).to(device=device).to(torch.bfloat16)
|
| 1072 |
+
self.tokenizer = T5Tokenizer.from_pretrained(pretrained_path)
|
| 1073 |
+
self.max_length = max_length
|
| 1074 |
+
self.model.eval()
|
| 1075 |
+
self.model.requires_grad_(False)
|
| 1076 |
+
|
| 1077 |
+
|
| 1078 |
+
@property
|
| 1079 |
+
def dtype(self):
|
| 1080 |
+
"""Return the dtype of the model parameters."""
|
| 1081 |
+
return next(self.parameters()).dtype
|
| 1082 |
+
|
| 1083 |
+
@property
|
| 1084 |
+
def device(self):
|
| 1085 |
+
"""Return the device of the model parameters."""
|
| 1086 |
+
return next(self.parameters()).device
|
| 1087 |
+
|
| 1088 |
+
def forward(
|
| 1089 |
+
self, caption
|
| 1090 |
+
):
|
| 1091 |
+
max_length = self.max_length
|
| 1092 |
+
|
| 1093 |
+
text_inputs = self.tokenizer(
|
| 1094 |
+
caption,
|
| 1095 |
+
return_tensors="pt",
|
| 1096 |
+
add_special_tokens=True,
|
| 1097 |
+
max_length=max_length,
|
| 1098 |
+
padding="max_length",
|
| 1099 |
+
truncation=True,
|
| 1100 |
+
)
|
| 1101 |
+
text_input_ids = text_inputs.input_ids
|
| 1102 |
+
attention_mask = text_inputs.attention_mask
|
| 1103 |
+
text_input_ids = text_input_ids.to(self.model.device)
|
| 1104 |
+
attention_mask = attention_mask.to(self.model.device)
|
| 1105 |
+
outputs = self.model(text_input_ids, attention_mask=attention_mask)
|
| 1106 |
+
embeddings = outputs.last_hidden_state
|
| 1107 |
+
return embeddings, attention_mask.to(embeddings.dtype)
|
| 1108 |
+
|
| 1109 |
+
|
| 1110 |
+
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
+
if __name__ == '__main__':
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
|
| 1118 |
+
from datasets import load_dataset
|
| 1119 |
+
dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]')
|
| 1120 |
+
item = dataset[0:4]
|
| 1121 |
+
another_item = dataset[0:4]
|
| 1122 |
+
from diffusers.models.normalization import RMSNorm
|
| 1123 |
+
image_encoder = CLIPImageEncoder(device="cuda:0")
|
| 1124 |
+
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
| 1125 |
+
image_embeds = image_encoder(clip_processor(images=item['image'], return_tensors="pt").pixel_values.to("cuda:0").to(torch.bfloat16))
|
| 1126 |
+
print(image_embeds.shape, ' >>>> image embeds')
|
| 1127 |
+
|
| 1128 |
+
|
| 1129 |
+
#model = GemmaTextEmbedder(device="cuda:0")
|
| 1130 |
+
model = LoraT5Embedder(device="cuda:0")
|
| 1131 |
+
prompt_embeds, l2_loss, l1_loss, pad_loss, clip_image_embeds, attn_mask = model(
|
| 1132 |
+
[
|
| 1133 |
+
"""A heartwarming 3D rendered scene of
|
| 1134 |
+
an elderly farmer and a tiny orange
|
| 1135 |
+
kitten. The farmer, with a gentle smile,
|
| 1136 |
+
walks alongside the kitten in a lush,
|
| 1137 |
+
green garden filled with thriving plants,
|
| 1138 |
+
showcasing a fruitful harvest. The
|
| 1139 |
+
intricate details of the overalls and the
|
| 1140 |
+
farmer's worn, weathered face tell a
|
| 1141 |
+
story of years spent tending to the land, the farmer is wearing a blue shirt""",
|
| 1142 |
+
],
|
| 1143 |
+
image=clip_processor(images=item['image'], return_tensors="pt").pixel_values.to("cuda:0").to(torch.bfloat16
|
| 1144 |
+
))
|
| 1145 |
+
print(l2_loss, ' >>> l2 loss ', l1_loss, ' >>> l1 loss ', pad_loss, ' >>> pad loss ')
|
| 1146 |
+
print(clip_image_embeds.shape, ' >>> clip image embeds ')
|
| 1147 |
+
|
| 1148 |
+
#print(gemma_dict['text_embeds'],)
|
| 1149 |
+
#print(gemma_dict['image_embeds'], ' >>> image embeds')
|
| 1150 |
+
|
| 1151 |
+
|
| 1152 |
+
|
| 1153 |
+
"""
|
| 1154 |
+
from dataset import create_loader
|
| 1155 |
+
from PIL import Image as PILImage
|
| 1156 |
+
from PIL import Image as PILImage
|
| 1157 |
+
import PIL
|
| 1158 |
+
import numpy as np
|
| 1159 |
+
import torch.nn.functional as F
|
| 1160 |
+
|
| 1161 |
+
loader = create_loader('edit', batch_size=16, shuffle=False)
|
| 1162 |
+
batch = next(iter(loader))
|
| 1163 |
+
source = batch['source_images']
|
| 1164 |
+
source_pils = [PIL.Image.fromarray(((x.permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype(np.uint8)) for x in source]
|
| 1165 |
+
target = batch['target_images']
|
| 1166 |
+
target_pils = [PIL.Image.fromarray(((x.permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype(np.uint8)) for x in target]
|
| 1167 |
+
from torchvision.utils import save_image
|
| 1168 |
+
|
| 1169 |
+
print(batch['captions'])
|
| 1170 |
+
|
| 1171 |
+
|
| 1172 |
+
images = []
|
| 1173 |
+
for (x, y) in zip(batch['source_images'], batch['target_images']):
|
| 1174 |
+
images.append(x)
|
| 1175 |
+
images.append(y)
|
| 1176 |
+
save_image((torch.stack(images) + 1) / 2, 'example_pairs.jpg', nrow=8)
|
| 1177 |
+
|
| 1178 |
+
gemma_dict = model(batch['captions'], source_pils, target_pils)
|
| 1179 |
+
image_embeds = gemma_dict['image_embeds']
|
| 1180 |
+
target_image_embeds = gemma_dict['target_image_embeds']
|
| 1181 |
+
|
| 1182 |
+
print("Image embeds shape:", image_embeds.shape)
|
| 1183 |
+
print("Target image embeds shape:", target_image_embeds.shape)
|
| 1184 |
+
from qwen import compute_and_save_similarity_grid
|
| 1185 |
+
compute_and_save_similarity_grid(image_embeds, target_image_embeds, "gemma_similarity_grid.jpg")
|
| 1186 |
+
"""
|
| 1187 |
+
|
| 1188 |
+
|