Spaces:
Runtime error
Runtime error
| """ | |
| Torch 2.0 Optimized Resampler - Maintains InstantID Weight Compatibility | |
| ========================================================================== | |
| Key principle: Keep EXACT same architecture as original for weight loading, | |
| but optimize with torch 2.0 features for better performance. | |
| Changes from base: | |
| - Torch 2.0 scaled_dot_product_attention (faster, less memory) | |
| - Better numerical stability | |
| - NO architecture changes (same layers, heads, dims) | |
| Author: Pixagram Team | |
| License: MIT | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def FeedForward(dim, mult=4):
    """Build the pre-norm MLP used after each attention layer.

    The layers are, in order: LayerNorm(dim), a bias-free Linear expanding
    ``dim`` to ``int(dim * mult)``, GELU, and a bias-free Linear projecting
    back to ``dim``. Children are registered under indices 0..3, matching
    the base implementation's state-dict keys.

    Args:
        dim: Input/output feature width.
        mult: Expansion factor for the hidden width (truncated via ``int``).

    Returns:
        An ``nn.Sequential`` of the four layers above.
    """
    hidden = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*stages)
def reshape_tensor(x, heads):
    """Split the last axis of a 3-D tensor into ``heads`` attention heads.

    Args:
        x: Tensor unpacked as (batch, length, width); ``width`` must be
            divisible by ``heads`` for the ``view`` to succeed.
        heads: Number of heads.

    Returns:
        Tensor of shape (batch, heads, length, width // heads).
    """
    batch, seq_len, _ = x.shape
    split = x.view(batch, seq_len, heads, -1)
    return split.transpose(1, 2).reshape(batch, heads, seq_len, -1)
class PerceiverAttentionTorch2(nn.Module):
    """
    Perceiver cross-attention from latent queries to (inputs + latents).

    Submodule names and shapes (``norm1``, ``norm2``, ``to_q``, ``to_kv``,
    ``to_out``) are identical to the base implementation, so existing state
    dicts load unchanged. Only the attention kernel differs: the fused
    ``F.scaled_dot_product_attention`` is used when this torch build has it,
    otherwise an explicit matmul/softmax fallback.

    Args:
        dim: Feature width of both ``x`` and ``latents``.
        dim_head: Width of each attention head.
        heads: Number of attention heads.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # Use the fused kernel whenever this torch build provides it (torch >= 2.0).
        self.use_torch2 = hasattr(F, "scaled_dot_product_attention")
        if self.use_torch2:
            print(" [TORCH2] Using optimized scaled_dot_product_attention")

    def forward(self, x, latents):
        """
        Attend from ``latents`` to the concatenation of ``x`` and ``latents``.

        Args:
            x: Input features, concatenated with ``latents`` along dim -2;
                presumably (batch, n_inputs, dim) — TODO confirm with callers.
            latents: Latent queries, unpacked below as (batch, n_latents, dim).

        Returns:
            Tensor of shape (batch, n_latents, dim).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, n_latents, _ = latents.shape

        q = self.to_q(latents)
        # Keys/values see both the inputs and the latents themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        # (B, L, H*D) -> (B, H, L, D)
        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        if self.use_torch2:
            # No explicit ``scale``: the ``scale`` keyword only exists from
            # torch 2.1 (TypeError on 2.0), and SDPA's default of
            # 1/sqrt(q.size(-1)) already equals dim_head**-0.5 == self.scale.
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
            )
        else:
            # Manual attention for torch builds without SDPA. The overall
            # 1/sqrt(dim_head) factor is split as dim_head**-0.25 on q and on k,
            # presumably to keep the product small in fp16.
            scale = 1 / math.sqrt(math.sqrt(self.dim_head))
            weight = (q * scale) @ (k * scale).transpose(-2, -1)
            # Softmax in float32 for stability, then cast back.
            weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
            out = weight @ v

        # (B, H, L, D) -> (B, L, H*D)
        out = out.permute(0, 2, 1, 3).reshape(b, n_latents, -1)
        return self.to_out(out)
class ResamplerCompatible(nn.Module):
    """
    Perceiver resampler mapping input embeddings to ``num_queries`` tokens.

    A set of learned latent queries is refined by ``depth`` (attention,
    feed-forward) pairs, each applied with a residual connection. The result
    goes through an output projection and a final LayerNorm. Submodule names
    and creation order mirror the base Resampler, so pretrained state dicts
    load unchanged.

    Every size argument fixes a parameter shape. To load a checkpoint, all of
    them must equal the values it was trained with.

    NOTE(review): the defaults below (dim=1024, depth=8, heads=16,
    num_queries=8) are the base class defaults and not necessarily the
    InstantID checkpoint's configuration. The upstream InstantID pipeline is
    believed to use dim=1280, depth=4, heads=20, num_queries=16. Verify this
    against the checkpoint's state-dict shapes.

    Args:
        dim: Internal width of latents and attention blocks.
        depth: Number of (attention, feed-forward) layers.
        dim_head: Width of each attention head.
        heads: Number of attention heads.
        num_queries: Number of learned latent query tokens (= output tokens).
        embedding_dim: Width of the incoming embeddings.
        output_dim: Width of each output token.
        ff_mult: Hidden-width multiplier of the feed-forward blocks.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()

        # Learned query tokens, initialised as N(0, 1) / sqrt(dim).
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        # One [attention, feed-forward] pair per layer; the nesting yields
        # state-dict keys of the form ``layers.<i>.0.*`` / ``layers.<i>.1.*``.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttentionTorch2(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

        summary = (
            "[RESAMPLER] Compatible architecture initialized:",
            f" - Layers: {depth} (matches pretrained)",
            f" - Heads: {heads} (matches pretrained)",
            f" - Dim: {dim} (matches pretrained)",
            f" - Queries: {num_queries}",
            f" - Torch 2.0 optimizations: {hasattr(F, 'scaled_dot_product_attention')}",
        )
        for line in summary:
            print(line)

    def forward(self, x):
        """
        Resample ``x`` into a fixed number of tokens.

        Args:
            x: Embeddings whose last dim is ``embedding_dim``; dim 0 is the batch.

        Returns:
            Tensor of shape (batch, num_queries, output_dim).
        """
        context = self.proj_in(x)
        hidden = self.latents.repeat(x.size(0), 1, 1)
        for attention, feed_forward in self.layers:
            hidden = attention(context, hidden) + hidden
            hidden = feed_forward(hidden) + hidden
        return self.norm_out(self.proj_out(hidden))
def create_compatible_resampler(
    num_queries: int = 4,
    embedding_dim: int = 512,
    output_dim: int = 2048,
    device: str = "cuda",
    dtype: torch.dtype = torch.float16,
    *,
    dim: int = 1024,
    depth: int = 8,
    dim_head: int = 64,
    heads: int = 16,
    ff_mult: int = 4,
) -> ResamplerCompatible:
    """
    Build a ``ResamplerCompatible`` and move it to ``device`` / ``dtype``.

    The architecture sizes (``dim``, ``depth``, ``dim_head``, ``heads``,
    ``ff_mult``, ``num_queries``) determine parameter shapes. They must
    equal the configuration of whatever checkpoint will be loaded into the
    returned module. They are keyword-only with the previous hard-coded
    values as defaults, so existing calls behave exactly as before.

    NOTE(review): the upstream InstantID pipeline is believed to build its
    resampler with dim=1280, depth=4, dim_head=64, heads=20, num_queries=16.
    If so, loading its checkpoint needs those values rather than the defaults
    here. Verify against the checkpoint's state-dict shapes.

    Args:
        num_queries: Number of output tokens (learned latent queries).
        embedding_dim: Width of the input embeddings (e.g. 512 for
            InsightFace, 768 for CLIP).
        output_dim: Width of each output token (e.g. 2048 for SDXL
            cross-attention).
        device: Target device passed to ``Module.to``.
        dtype: Target floating-point dtype passed to ``Module.to``.
        dim: Internal latent width.
        depth: Number of (attention, feed-forward) layers.
        dim_head: Width of each attention head.
        heads: Number of attention heads.
        ff_mult: Feed-forward hidden-width multiplier.

    Returns:
        The constructed resampler on ``device`` in ``dtype``.
    """
    resampler = ResamplerCompatible(
        dim=dim,
        depth=depth,
        dim_head=dim_head,
        heads=heads,
        num_queries=num_queries,
        embedding_dim=embedding_dim,
        output_dim=output_dim,
        ff_mult=ff_mult,
    )
    return resampler.to(device, dtype=dtype)
# Backward-compatible alias so code importing ``Resampler`` gets this class.
Resampler = ResamplerCompatible


if __name__ == "__main__":
    print("Testing Compatible Resampler with Torch 2.0 optimizations...")

    # Use GPU/fp16 when CUDA exists, otherwise CPU/fp32. The factory's default
    # device="cuda" cannot work on CPU-only hosts, and some fp16 kernels
    # (e.g. LayerNorm) may be unsupported on CPU in some torch versions.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    else:
        device, dtype = "cpu", torch.float32

    resampler = create_compatible_resampler(
        num_queries=4,
        embedding_dim=512,
        output_dim=2048,
        device=device,
        dtype=dtype,
    )

    # The input must live on the same device and dtype as the model weights;
    # a default CPU/fp32 tensor fails against CUDA/fp16 parameters.
    test_input = torch.randn(2, 1, 512, device=device, dtype=dtype)
    print(f"\nTest input shape: {test_input.shape}")

    with torch.no_grad():
        output = resampler(test_input)

    print(f"Output shape: {output.shape}")
    print("Expected: [2, 4, 2048]")
    assert output.shape == (2, 4, 2048), "Shape mismatch!"
    print("\n[OK] Compatible Resampler test passed!")

    # Report which attention path is available in this torch build.
    if hasattr(F, "scaled_dot_product_attention"):
        print("[OK] Using torch 2.0 optimizations")
    else:
        print("[INFO] Torch 2.0 not available, using fallback")