# pixagram-backup / resampler_compatible.py
"""
Torch 2.0 Optimized Resampler - Maintains InstantID Weight Compatibility
==========================================================================
Key principle: Keep EXACT same architecture as original for weight loading,
but optimize with torch 2.0 features for better performance.
Changes from base:
- Torch 2.0 scaled_dot_product_attention (faster, less memory)
- Better numerical stability
- NO architecture changes (same layers, heads, dims)
Author: Pixagram Team
License: MIT
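
Example (sketch; CPU/float32 shown so it runs anywhere):
    resampler = create_compatible_resampler(num_queries=4, embedding_dim=512,
                                            output_dim=2048, device="cpu",
                                            dtype=torch.float32)
    tokens = resampler(torch.randn(1, 1, 512))  # -> (1, 4, 2048)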
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def FeedForward(dim, mult=4):
"""Standard feed-forward network."""
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)


def reshape_tensor(x, heads):
"""Reshape for multi-head attention."""
bs, length, width = x.shape
x = x.view(bs, length, heads, -1)
x = x.transpose(1, 2)
x = x.reshape(bs, heads, length, -1)
return x
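

# Shape walkthrough for reshape_tensor (illustrative values): with bs=2,
# length=5, width=1024, heads=16, the tensor moves
#   (2, 5, 1024) --view--> (2, 5, 16, 64) --transpose--> (2, 16, 5, 64),
# i.e. the (B, H, L, D) layout expected by the attention code below.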


class PerceiverAttentionTorch2(nn.Module):
"""
Perceiver attention with torch 2.0 optimizations.
Architecture IDENTICAL to base for weight compatibility.
"""
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
# Check torch 2.0 availability
self.use_torch2 = hasattr(F, "scaled_dot_product_attention")
if self.use_torch2:
print(" [TORCH2] Using optimized scaled_dot_product_attention")

    def forward(self, x, latents):
"""
Forward with torch 2.0 optimization when available.
Falls back to manual attention for torch < 2.0.
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
        # Use torch 2.0 optimized attention if available
        if self.use_torch2:
            # q, k, v are already in (B, H, L, D) layout. SDPA's default scale
            # is 1/sqrt(head_dim), which equals self.scale here, so it is left
            # implicit: the explicit `scale=` kwarg only exists in torch >= 2.1
            # and would raise a TypeError on torch 2.0.
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
            )
else:
# Fallback to manual attention (torch 1.x)
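            # Note: dim_head**-0.25 is applied to BOTH q and k, so their product
            # carries the standard dim_head**-0.5 softmax scaling while keeping
            # intermediate magnitudes small (friendlier to float16).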
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
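

# Sanity sketch (illustrative, not part of the module): both attention paths
# implement the same math, so toggling the torch 2.0 flag off by hand should
# reproduce the SDPA output to float32 tolerance:
#     attn = PerceiverAttentionTorch2(dim=1024, dim_head=64, heads=16)
#     x, lat = torch.randn(1, 3, 1024), torch.randn(1, 4, 1024)
#     y1 = attn(x, lat); attn.use_torch2 = False; y2 = attn(x, lat)
#     assert torch.allclose(y1, y2, atol=1e-5)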


class ResamplerCompatible(nn.Module):
    """
    Resampler with the EXACT same architecture as the InstantID pretrained
    weights. Optimized for torch 2.0 but maintains full weight compatibility.

    DO NOT change:
    - dim (1024 default)
    - depth (8 layers)
    - dim_head (64)
    - heads (16)
    - num_queries (8 or 4)
    These must match the pretrained weights!
    """
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
):
super().__init__()
# Learnable query tokens - SAME initialization as original
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
# Use torch 2.0 optimized attention
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PerceiverAttentionTorch2(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
])
)
print(f"[RESAMPLER] Compatible architecture initialized:")
print(f" - Layers: {depth} (matches pretrained)")
print(f" - Heads: {heads} (matches pretrained)")
print(f" - Dim: {dim} (matches pretrained)")
print(f" - Queries: {num_queries}")
print(f" - Torch 2.0 optimizations: {hasattr(F, 'scaled_dot_product_attention')}")

    def forward(self, x):
"""Standard forward pass."""
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)


def create_compatible_resampler(
    num_queries: int = 4,
    embedding_dim: int = 512,
    output_dim: int = 2048,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
) -> ResamplerCompatible:
"""
Create Resampler with architecture compatible with InstantID weights.
Args:
num_queries: 4 for IP-Adapter, 8 for original (use 4 for InstantID)
embedding_dim: 512 for InsightFace, 768 for CLIP
output_dim: 2048 for SDXL cross-attention
device: Device
dtype: Data type
"""
# For InstantID with InsightFace embeddings
resampler = ResamplerCompatible(
dim=1024, # MUST match pretrained
depth=8, # MUST match pretrained
dim_head=64, # MUST match pretrained
heads=16, # MUST match pretrained
num_queries=num_queries,
embedding_dim=embedding_dim,
output_dim=output_dim,
ff_mult=4
)
return resampler.to(device, dtype=dtype)
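

# A minimal weight-loading sketch. It assumes the common IP-Adapter/InstantID
# checkpoint layout, in which the resampler's state dict is stored under an
# "image_proj" key; verify the key and file name against your own checkpoint.
def load_pretrained_resampler(
    resampler: ResamplerCompatible,
    checkpoint_path: str,
) -> ResamplerCompatible:
    """Load pretrained resampler weights (sketch; assumes IP-Adapter layout)."""
    state = torch.load(checkpoint_path, map_location="cpu")
    # Fall back to a flat state dict when there is no "image_proj" sub-dict.
    resampler.load_state_dict(state.get("image_proj", state))
    return resampler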


# Backward-compatibility alias
Resampler = ResamplerCompatible


if __name__ == "__main__":
print("Testing Compatible Resampler with Torch 2.0 optimizations...")
    # Run the self-test on CPU unless CUDA is available: the factory defaults
    # to device="cuda"/float16, which would fail on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    resampler = create_compatible_resampler(
        num_queries=4,
        embedding_dim=512,
        output_dim=2048,
        device=device,
        dtype=dtype,
    )
    # Test forward pass (input must match the model's device and dtype)
    test_input = torch.randn(2, 1, 512, device=device, dtype=dtype)
    print(f"\nTest input shape: {test_input.shape}")
    with torch.no_grad():
        output = resampler(test_input)
print(f"Output shape: {output.shape}")
print(f"Expected: [2, 4, 2048]")
assert output.shape == (2, 4, 2048), "Shape mismatch!"
print("\n[OK] Compatible Resampler test passed!")
# Check torch 2.0
if hasattr(F, "scaled_dot_product_attention"):
print("[OK] Using torch 2.0 optimizations")
else:
print("[INFO] Torch 2.0 not available, using fallback")