Upload 9 files
Browse files- .gitattributes +2 -0
- Infinity/infinity/models/basic.py +793 -0
- Infinity/infinity/models/infinity.py +817 -0
- Infinity/infinity_vae_d32_reg.pth +3 -0
- README.md +162 -3
- flan-t5-xl-encoder-Q8_0.gguf +3 -0
- generate_image_2b_q8_gguf.py +559 -0
- gradio_webui.py +342 -0
- infinity_2b_reg_Q8_0.gguf +3 -0
- infinity_gguf_utils.py +477 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
flan-t5-xl-encoder-Q8_0.gguf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
infinity_2b_reg_Q8_0.gguf filter=lfs diff=lfs merge=lfs -text
|
Infinity/infinity/models/basic.py
ADDED
|
@@ -0,0 +1,793 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Definitions of blocks of VAR transformer model.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
from functools import partial
|
| 8 |
+
from typing import Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import numpy as np
|
| 14 |
+
from timm.models.layers import DropPath, drop_path
|
| 15 |
+
from torch.utils.checkpoint import checkpoint
|
| 16 |
+
|
| 17 |
+
# Attention backend selection with fallback hierarchy:
|
| 18 |
+
# 1. SageAttention (optional, 2-5x faster than FlashAttention)
|
| 19 |
+
# 2. FlashAttention (optional, still faster than PyTorch)
|
| 20 |
+
# 3. PyTorch scaled_dot_product_attention (always available)
|
| 21 |
+
|
| 22 |
+
SAGE_ATTN_AVAILABLE = False
|
| 23 |
+
FLASH_ATTN_AVAILABLE = False
|
| 24 |
+
sageattn = None
|
| 25 |
+
sageattn_varlen = None
|
| 26 |
+
flash_attn_func = None
|
| 27 |
+
flash_attn_varlen_kvpacked_func = None
|
| 28 |
+
|
| 29 |
+
# Try to import SageAttention (optional, fastest option)
|
| 30 |
+
try:
|
| 31 |
+
from sageattention import sageattn, sageattn_varlen
|
| 32 |
+
SAGE_ATTN_AVAILABLE = True
|
| 33 |
+
print("[INFO] SageAttention detected - will use for 2-5x speedup over FlashAttention")
|
| 34 |
+
except ImportError:
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
# Try to import FlashAttention (optional, fallback if SageAttention not available)
|
| 38 |
+
try:
|
| 39 |
+
from flash_attn import flash_attn_func # q, k, or v: BLHc, ret: BLHc
|
| 40 |
+
from flash_attn import flash_attn_varlen_kvpacked_func # qkv: N3Hc, ret: NHc
|
| 41 |
+
FLASH_ATTN_AVAILABLE = True
|
| 42 |
+
if not SAGE_ATTN_AVAILABLE:
|
| 43 |
+
print("[INFO] FlashAttention detected - will use for optimized attention")
|
| 44 |
+
except ImportError:
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
# Print final status
|
| 48 |
+
if not SAGE_ATTN_AVAILABLE and not FLASH_ATTN_AVAILABLE:
|
| 49 |
+
print("[INFO] Using PyTorch scaled_dot_product_attention (no SageAttention or FlashAttention detected)")
|
| 50 |
+
print(" Install SageAttention for 2-5x speedup: pip install sageattention>=2.2.0 --no-build-isolation")
|
| 51 |
+
|
| 52 |
+
from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc
|
| 53 |
+
|
| 54 |
+
# Import GGUF utilities for on-the-fly dequantization
|
| 55 |
+
try:
|
| 56 |
+
import sys
|
| 57 |
+
import os
|
| 58 |
+
# Add parent directory to path to find infinity_gguf_utils
|
| 59 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 60 |
+
parent_dirs = [
|
| 61 |
+
os.path.join(current_dir, '../../..'), # From Infinity/infinity/models to root
|
| 62 |
+
os.path.join(current_dir, '../../../..'), # One more level up if needed
|
| 63 |
+
]
|
| 64 |
+
for parent_dir in parent_dirs:
|
| 65 |
+
if parent_dir not in sys.path:
|
| 66 |
+
sys.path.insert(0, parent_dir)
|
| 67 |
+
from infinity_gguf_utils import dequantize_gguf_tensor, GGUFParameter
|
| 68 |
+
GGUF_AVAILABLE = True
|
| 69 |
+
except ImportError:
|
| 70 |
+
GGUF_AVAILABLE = False
|
| 71 |
+
GGUFParameter = None
|
| 72 |
+
|
| 73 |
+
def get_weight_for_linear(linear_layer, target_dtype=None):
    """Return a dense weight tensor for `linear_layer`, ready for use in F.linear.

    GGUF-quantized weights are dequantized on the fly; ordinary (e.g. fp16)
    weights are cast to `target_dtype` only when it differs from their
    current dtype.

    Args:
        linear_layer: nn.Linear (or GGUF-backed linear) exposing `.weight`.
        target_dtype: optional dtype the returned tensor should have.

    Returns:
        A dense weight tensor suitable for F.linear.
    """
    w = linear_layer.weight
    if GGUF_AVAILABLE and isinstance(w, GGUFParameter):
        # Quantized storage: expand to a dense tensor in the requested dtype.
        return dequantize_gguf_tensor(w, target_dtype=target_dtype)
    # Non-quantized weight: cast only when a different dtype was requested.
    needs_cast = target_dtype is not None and w.dtype != target_dtype
    return w.to(dtype=target_dtype) if needs_cast else w
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Import flash_attn's fused ops
|
| 95 |
+
try:
|
| 96 |
+
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
| 97 |
+
from flash_attn.ops.rms_norm import dropout_add_rms_norm
|
| 98 |
+
from flash_attn.ops.rms_norm import rms_norm as rms_norm_impl
|
| 99 |
+
from flash_attn.ops.fused_dense import fused_mlp_func
|
| 100 |
+
flash_fused_op_installed = True
|
| 101 |
+
except ImportError:
|
| 102 |
+
dropout_add_layer_norm = dropout_add_rms_norm = fused_mlp_func = None
|
| 103 |
+
flash_fused_op_installed = False
|
| 104 |
+
|
| 105 |
+
def rms_norm_impl(x, weight, epsilon):
    # Pure-PyTorch RMSNorm fallback (used when flash_attn's fused rms_norm
    # kernel is not installed): scale by the reciprocal root-mean-square over
    # the last dimension, then apply the per-channel gain.
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
    normed = x * torch.rsqrt(mean_sq + epsilon)
    return normed * weight
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0):
    """Precompute 2D rotary-embedding (RoPE) cos/sin caches for every scale schedule.

    Returns a dict keyed by ``str(tuple(scale_schedule))`` (with the temporal
    component forced to 1), each value shaped (2, 1, 1, 1, seq_len, dim//2)
    where index 0 along the first axis holds cos and index 1 holds sin.

    :param dim: per-head channel dim; half the frequencies encode the y (height)
        axis and half the x (width) axis
    :param dynamic_resolution_h_w: nested dict keyed by aspect ratio then a
        pixel-count tag (e.g. '1M'), each entry holding a 'scales' list of
        (t, h, w) triples — assumed structure, confirm against caller
    :param rope2d_normalized_by_hw: 0 = direct top-left crop of the frequency
        grid, 1 = bilinear downsample to each scale, 2 = nearest-index ("star"
        style) subsampling
    :param pad_to_multiplier: pad the concatenated sequence length up to a
        multiple of this value (zeros appended)
    """
    # split the dimension into half, one for x and one for y
    half_dim = dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim)) # namely theta, 1 / (10000^(i/half_dim)), i=0,2,..., half_dim-2
    t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq)
    t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq)
    t_height = t_height / scaling_factor
    freqs_height = torch.outer(t_height, inv_freq) # (max_height, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2), namely y*theta
    t_width = t_width / scaling_factor
    freqs_width = torch.outer(t_width, inv_freq) # (max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2), namely x*theta
    # Full-resolution frequency grid: height frequencies in the first half of
    # the channel dim, width frequencies in the second half.
    freqs_grid_map = torch.concat([
        freqs_height[:, None, :].expand(-1, max_width, -1), # (max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2)
        freqs_width[None, :, :].expand(max_height, -1, -1), # (max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d) / 2)
    ], dim=-1) # (max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d))
    freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)
    # (2, max_height, max_width, dim / (1 for 1d, 2 for 2d, 3 for 3d))

    rope2d_freqs_grid = {}
    for h_div_w in dynamic_resolution_h_w:
        # Use the 1M schedule's final scale to fix the aspect ratio of the
        # usable region of the frequency grid.
        scale_schedule = dynamic_resolution_h_w[h_div_w]['1M']['scales']
        _, ph, pw = scale_schedule[-1]
        max_edge_length = freqs_grid_map.shape[1]
        if ph >= pw:
            uph, upw = max_edge_length, int(max_edge_length / ph * pw)
        else:
            uph, upw = int(max_edge_length / pw * ph), max_edge_length
        rope_cache_list = []
        for (_, ph, pw) in scale_schedule:
            ph_mul_pw = ph * pw
            if rope2d_normalized_by_hw == 1: # downsample
                rope_cache = F.interpolate(freqs_grid_map[:, :uph, :upw, :].permute([0,3,1,2]), size=(ph, pw), mode='bilinear', align_corners=True)
                rope_cache = rope_cache.permute([0,2,3,1]) # (2, ph, pw, half_head_dim)
            elif rope2d_normalized_by_hw == 2: # star style: pick nearest grid indices
                _, uph, upw = scale_schedule[-1]
                indices = torch.stack([
                    (torch.arange(ph) * (uph / ph)).reshape(ph, 1).expand(ph, pw),
                    (torch.arange(pw) * (upw / pw)).reshape(1, pw).expand(ph, pw),
                ], dim=-1).round().int() # (ph, pw, 2)
                indices = indices.reshape(-1, 2) # (ph*pw, 2)
                rope_cache = freqs_grid_map[:, indices[:,0], indices[:,1], :] # (2, ph*pw, half_head_dim)
                rope_cache = rope_cache.reshape(2, ph, pw, -1)
            elif rope2d_normalized_by_hw == 0:
                rope_cache = freqs_grid_map[:, :ph, :pw, :] # (2, ph, pw, half_head_dim)
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list.append(rope_cache.reshape(2, ph_mul_pw, -1))
        cat_rope_cache = torch.cat(rope_cache_list, 1) # (2, seq_len, half_head_dim)
        if cat_rope_cache.shape[1] % pad_to_multiplier:
            # Zero-pad so the cached sequence length is a multiple of pad_to_multiplier.
            pad = torch.zeros(2, pad_to_multiplier - cat_rope_cache.shape[1] % pad_to_multiplier, half_dim)
            cat_rope_cache = torch.cat([cat_rope_cache, pad], dim=1)
        cat_rope_cache = cat_rope_cache[:,None,None,None] # (2, 1, 1, 1, seq_len, half_dim)
        # The same cache is shared by every pixel-count tag of this aspect ratio.
        for pn in dynamic_resolution_h_w[h_div_w]:
            scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['scales']
            tmp_scale_schedule = [(1, h, w) for _, h, w in scale_schedule]
            rope2d_freqs_grid[str(tuple(tmp_scale_schedule))] = cat_rope_cache
    return rope2d_freqs_grid
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, pad_to_multiplier, rope2d_normalized_by_hw, scale_ind):
    """Apply 2D rotary position embedding to q and k using precomputed cos/sin caches.

    Looks up the cache produced by ``precompute_rope2d_freqs_grid`` under
    ``str(tuple(scale_schedule))`` and rotates the (q, k) pair over the slice
    of positions belonging to scale ``scale_ind``.

    NOTE(review): mutates ``rope2d_freqs_grid`` in place by moving the cached
    tensor onto q/k's device on first use.

    :param q, k: (batch_size, heads, seq_len, head_dim) — layout implied by the
        final unbind; confirm against caller
    :param scale_ind: index of the current scale in ``scale_schedule``; used to
        offset into the concatenated position cache
    :return: rotated (q, k), same shapes as the inputs
    """
    qk = torch.stack((q, k), dim=0) #(2, batch_size, heads, seq_len, head_dim)
    device_type = qk.device.type
    # autocast is unsupported on mps; fall back to the cpu autocast context there.
    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    # Rotation is done in full precision: disable autocast for this region.
    with torch.autocast(device_type=device_type, enabled=False):
        seq_len = qk.shape[3]
        start = 0
        if scale_ind >= 1:
            assert len(scale_schedule[0]) == 3
            # Offset = total number of tokens in all earlier scales (t*h*w each).
            start = np.sum([item[0] * item[1] * item[2] for item in scale_schedule[:scale_ind]])
        rope2d_freqs_grid[str(tuple(scale_schedule))] = rope2d_freqs_grid[str(tuple(scale_schedule))].to(qk.device)
        assert start+seq_len <= rope2d_freqs_grid[str(tuple(scale_schedule))].shape[4]
        rope_cache = rope2d_freqs_grid[str(tuple(scale_schedule))][:, :, :, :, start:start+seq_len] # rope_cache shape: [2, 1, 1, 1, seq_len, half_head_dim]
        qk = qk.reshape(*qk.shape[:-1], -1, 2) #(2, batch_size, heads, seq_len, half_head_dim, 2)
        # Complex rotation on (even, odd) channel pairs: rope_cache[0] is cos,
        # rope_cache[1] is sin.
        qk = torch.stack([
            rope_cache[0] * qk[...,0] - rope_cache[1] * qk[...,1],
            rope_cache[1] * qk[...,0] + rope_cache[0] * qk[...,1],
        ], dim=-1) # (2, batch_size, heads, seq_len, half_head_dim, 2), here stack + reshape should not be concate
        qk = qk.reshape(*qk.shape[:-2], -1) #(2, batch_size, heads, seq_len, head_dim)
        q, k = qk.unbind(dim=0) # (batch_size, heads, seq_len, head_dim)
    return q, k
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class FastRMSNorm(nn.Module):
    """RMSNorm over the last dimension, computed in fp32 and cast back to the input dtype.

    The actual normalization is delegated to `rms_norm_impl` (flash_attn's fused
    kernel when installed, otherwise the pure-PyTorch fallback).
    """

    def __init__(self, C, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.C = C
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        ones = torch.ones(C)
        if elementwise_affine:
            # Learnable per-channel gain.
            self.weight = nn.Parameter(ones)
        else:
            # Fixed unit gain, registered as a buffer so it follows .to()/.cuda() moves.
            self.register_buffer('weight', ones)

    def forward(self, x):
        # Normalize in fp32 for numerical stability, then restore the caller's dtype.
        return rms_norm_impl(x.float(), self.weight, epsilon=self.eps).to(x.dtype)

    def extra_repr(self) -> str:
        return f'C={self.C}, eps={self.eps:g}, elementwise_affine={self.elementwise_affine}'
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def get_dropout_layer(p):
    """Return an in-place Dropout for p > 0, otherwise a no-op Identity."""
    if p > 0:
        return nn.Dropout(p, inplace=True)
    return nn.Identity()
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class FFN(nn.Module):
    """Standard transformer feed-forward block: Linear -> GELU(tanh) -> Linear -> dropout.

    When `fused_mlp` is set (and flash_attn's fused kernel is installed), the
    whole block runs as a single fused kernel instead of the eager path.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_mlp=False):
        super().__init__()
        # Keep a handle to the fused kernel only when fused execution was requested.
        self.fused_mlp_func = fused_mlp_func if fused_mlp else None
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = get_dropout_layer(drop)
        self.heuristic = -1  # fused-kernel algorithm selector; -1 lets the kernel choose

    def forward(self, x):
        if self.fused_mlp_func is None:
            # Eager path: fc1 -> GELU -> fc2 -> dropout.
            hidden = self.act(self.fc1(x))
            return self.drop(self.fc2(hidden))
        # Fused path: one kernel computes fc1 + GELU + fc2 in a single pass.
        fused_out = self.fused_mlp_func(
            x=x,
            weight1=self.fc1.weight,
            weight2=self.fc2.weight,
            bias1=self.fc1.bias,
            bias2=self.fc2.bias,
            activation='gelu_approx',
            save_pre_act=self.training,
            return_residual=False,
            checkpoint_lvl=0,
            heuristic=self.heuristic,
            process_group=None,
        )
        return self.drop(fused_out)

    def extra_repr(self) -> str:
        return f'fused_mlp={self.fused_mlp_func is not None}'
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class FFNSwiGLU(nn.Module):
    """SwiGLU feed-forward block: fc2(silu(fcg(x)) * fc1(x)); all projections are bias-free."""

    def __init__(self, in_features, hidden_features, out_features=None, drop=0., fused_mlp=False):
        super().__init__()
        self.fused_mlp_func = None  # the fused MLP kernel is never used for SwiGLU
        # SwiGLU convention: shrink the hidden width by 2/3, rounded to a multiple of 256.
        hidden_features = round(2 * hidden_features / 3 / 256) * 256
        out_features = out_features if out_features else in_features
        self.fcg = nn.Linear(in_features, hidden_features, bias=False)   # gate projection
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)   # value projection
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)  # output projection
        self.drop = get_dropout_layer(drop)

    def forward(self, x):
        gate = F.silu(self.fcg(x), inplace=True)
        # In-place multiply is safe: `gate` is a fresh tensor owned by this call.
        return self.drop(self.fc2(gate.mul_(self.fc1(x))))

    def extra_repr(self) -> str:
        return f'fused_mlp={self.fused_mlp_func is not None}'
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class SelfAttention(nn.Module):
    """Multi-head self-attention with optional cosine attention, 2D RoPE, kv-caching,
    and a runtime backend hierarchy: SageAttention -> FlashAttention -> PyTorch SDPA."""

    def __init__(
        self, embed_dim=768, num_heads=12,
        proj_drop=0., tau=1, cos_attn=False, customized_flash_attn=True, use_flex_attn=False,
        batch_size=2, pad_to_multiplier=1, rope2d_normalized_by_hw=0,
    ):
        """
        :param embed_dim: model's width
        :param num_heads: num heads of multi-head attention
        :param proj_drop: always 0 for testing
        :param tau: always 1
        :param cos_attn: always True: during attention, q and k will be L2-normalized and scaled by a head-wise learnable parameter self.scale_mul_1H11
        :param customized_flash_attn: use the BLHc (flash-style) tensor layout and the fast backends
        """
        super().__init__()
        assert embed_dim % num_heads == 0
        self.using_flash = customized_flash_attn

        self.num_heads, self.head_dim = num_heads, embed_dim // num_heads
        self.tau, self.cos_attn = tau, cos_attn
        if self.cos_attn:
            # Cosine attention: unit scale here; the learnable per-head multiplier
            # (initialized to log(4), capped at log(100)) replaces 1/sqrt(d).
            self.scale = 1
            size = (1, 1, self.num_heads, 1) if self.using_flash else (1, self.num_heads, 1, 1)
            # size: 11H1 or 1H11
            self.scale_mul_1H11 = nn.Parameter(torch.full(size=size, fill_value=4.0).log(), requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            self.scale = 1 / math.sqrt(self.head_dim) / self.tau

        # Single projection for q, k, v; k's bias is fixed at zero (buffer below).
        self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = get_dropout_layer(proj_drop)

        self.caching = False # kv caching: only used during inference
        self.cached_k = None # kv caching: only used during inference
        self.cached_v = None # kv caching: only used during inference

        self.batch_size = batch_size
        self.use_flex_attn = use_flex_attn
        self.pad_to_multiplier = pad_to_multiplier

        self.rope2d_normalized_by_hw = rope2d_normalized_by_hw


    def kv_caching(self, enable: bool): # kv caching: only used during inference
        # Enabling or disabling always clears any previously cached k/v.
        self.caching = enable
        self.cached_k = None
        self.cached_v = None

    # NOTE: attn_bias_or_two_vector is None during inference
    def forward(self, x, attn_bias_or_two_vector: Union[torch.Tensor, Tuple[torch.IntTensor, torch.IntTensor]], attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, scale_ind=0):
        """
        :param (fp32) x: shaped (B or batch_size, L or seq_length, C or hidden_dim); if seq-parallel is used, the `L` dim would be shared
        :param (fp32) attn_bias_or_two_vector:
            if not using_flash:
                a block-wise, lower-triangle matrix, like:
                [[[[0, -, -, -, -, -, -, -, -, -, -, -, -, -],
                [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
                [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
                [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
                [0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]]
                where 0 means visible and - means invisible (-inf)
            else:
                a tuple of two 1-dim int vector (VAR_visible_kvlen, VAR_invisible_qlen)
        :return: shaped (B or batch_size, L or seq_length, C or hidden_dim); if seq-parallel is used, the `L` dim would be shared
        """
        # x: fp32
        B, L, C = x.shape

        # qkv: amp, bf16; GGUF-quantized qkv weights are dequantized on the fly.
        qkv = F.linear(input=x, weight=get_weight_for_linear(self.mat_qkv, target_dtype=x.dtype), bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim) # BL3Hc
        if self.using_flash: q, k, v = qkv.unbind(dim=2); L_dim = 1 # q or k or v: all are shaped in (B:batch_size, L:seq_len, H:heads, c:head_dim)
        else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); L_dim = 2 # q or k or v: all are shaped in (B:batch_size, H:heads, L:seq_len, c:head_dim)

        if self.cos_attn: # always True
            scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp() # 11H1 (flash), or 1H11 (not flash)
            q = F.normalize(q, dim=-1, eps=1e-12).mul(scale_mul).contiguous() # fp32
            k = F.normalize(k, dim=-1, eps=1e-12).contiguous() # fp32
            v = v.contiguous() # bf16
        else: # be contiguous, to make kernel happy
            q = q.contiguous() # bf16
            k = k.contiguous() # bf16
            v = v.contiguous() # bf16
        if rope2d_freqs_grid is not None:
            q, k = apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, self.pad_to_multiplier, self.rope2d_normalized_by_hw, scale_ind) #, freqs_cis=freqs_cis)
        if self.caching: # kv caching: only used during inference
            # Append the new k/v to the cache along the sequence dim (L_dim).
            if self.cached_k is None: self.cached_k = k; self.cached_v = v
            else: k = self.cached_k = torch.cat((self.cached_k, k), dim=L_dim); v = self.cached_v = torch.cat((self.cached_v, v), dim=L_dim)

        if self.using_flash:
            # Try SageAttention first (if available and during inference)
            if SAGE_ATTN_AVAILABLE and attn_bias_or_two_vector is None:
                try:
                    # SageAttention: expects (B, num_heads, seq_len, head_dim) layout (HND format)
                    # Our q, k, v are already in (B, L, H, c) format, need to transpose to (B, H, L, c)
                    q_sage = q.transpose(1, 2) # (B, H, L, c)
                    k_sage = k.transpose(1, 2) # (B, H, L, c)
                    v_sage = v.transpose(1, 2) # (B, H, L, c)

                    # Convert to fp16 or bf16 if needed (SageAttention requires fp16/bf16)
                    target_dtype = torch.bfloat16 if v.dtype == torch.float32 else v.dtype
                    q_sage = q_sage.to(target_dtype)
                    k_sage = k_sage.to(target_dtype)
                    v_sage = v_sage.to(target_dtype)

                    # Use SageAttention for inference
                    oup = sageattn(q_sage, k_sage, v_sage, tensor_layout="HND", is_causal=False)
                    oup = oup.transpose(1, 2).reshape(B, L, C) # (B, H, L, c) -> (B, L, H, c) -> (B, L, C)
                    if target_dtype != v.dtype:
                        oup = oup.to(v.dtype)
                except Exception as e:
                    print(f"[WARNING] SageAttention failed ({str(e)[:100]}), falling back to FlashAttention/PyTorch")
                    # Fall through to FlashAttention or PyTorch
                    if FLASH_ATTN_AVAILABLE:
                        kw = dict() if attn_bias_or_two_vector is None else dict(VAR_visible_kvlen=attn_bias_or_two_vector[0], VAR_invisible_qlen=attn_bias_or_two_vector[1])
                        oup = flash_attn_func(q.to(v.dtype), k.to(v.dtype), v, dropout_p=0, softmax_scale=self.scale, **kw).view(B, L, C)
                    else:
                        q_torch = q.transpose(1, 2)
                        k_torch = k.transpose(1, 2)
                        v_torch = v.transpose(1, 2)
                        oup = slow_attn(query=q_torch, key=k_torch, value=v_torch, scale=self.scale, dropout_p=0).transpose(1, 2).reshape(B, L, C)

            # Fall back to FlashAttention if SageAttention not used
            elif FLASH_ATTN_AVAILABLE:
                if attn_bias_or_two_vector is not None: # training
                    kw = dict(VAR_visible_kvlen=attn_bias_or_two_vector[0], VAR_invisible_qlen=attn_bias_or_two_vector[1])
                else: # inference (autoregressive sampling)
                    kw = dict()
                oup = flash_attn_func(q.to(v.dtype), k.to(v.dtype), v, dropout_p=0, softmax_scale=self.scale, **kw).view(B, L, C)

            # Final fallback to PyTorch SDPA
            else:
                q_torch = q.transpose(1, 2) # (B, H, L, c)
                k_torch = k.transpose(1, 2)
                v_torch = v.transpose(1, 2)
                oup = slow_attn(query=q_torch, key=k_torch, value=v_torch, scale=self.scale, dropout_p=0).transpose(1, 2).reshape(B, L, C)
        else:
            # if self.cos_attn: q, k are in fp32; v is in bf16
            # else: q, k, v are in bf16
            if self.use_flex_attn and attn_fn is not None:
                oup = attn_fn(q, k, v, scale=self.scale).transpose(1, 2).reshape(B, L, C)
            else:
                oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias_or_two_vector, dropout_p=0).transpose(1, 2).reshape(B, L, C)
        # oup: bf16

        return self.proj_drop(self.proj(oup))

    def extra_repr(self) -> str:
        tail = ''
        return f'using_flash={self.using_flash}, tau={self.tau}, cos_attn={self.cos_attn}{tail}'
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class CrossAttention(nn.Module):
|
| 431 |
+
    def __init__(
        self, for_attn_pool=False, embed_dim=768, kv_dim=4096, num_heads=12,
        proj_drop=0., cos_attn=False, use_flash_attn=True,
    ):
        """
        :param for_attn_pool: only used in VAR.text_proj_for_sos; when True, the
            query is a single learned vector instead of a projection of the input
        :param embed_dim: Q's dim
        :param kv_dim: K's and V's dim
        :param num_heads: num heads of multi-head attention
        :param proj_drop: proj drop out
        :param cos_attn: during attention, q and k will be L2-normalized and scaled by a head-wise learnable parameter self.scale_mul_1H11
        """
        # cos_attn is force-disabled below, so the `if self.cos_attn` branch is dead here.
        cos_attn = False # TODO: never use cos attn in cross attention with T5 kv
        super().__init__()
        self.for_attn_pool = for_attn_pool
        self.embed_dim = embed_dim
        self.kv_dim = kv_dim
        assert embed_dim % num_heads == 0
        self.num_heads, self.head_dim = num_heads, embed_dim // num_heads # =64
        self.cos_attn = cos_attn
        self.use_flash_attn = use_flash_attn
        if self.cos_attn:
            self.scale = 1
            self.scale_mul_1H1 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            self.scale = 1 / math.sqrt(self.head_dim)

        if for_attn_pool:
            # Attention pooling: a single learned query shared across the batch.
            q = torch.empty(1, self.num_heads, self.head_dim)
            nn.init.trunc_normal_(q, mean=0, std=math.sqrt(1 / embed_dim / 3))
            self.mat_q = nn.Parameter(q)
        else:
            self.mat_q = nn.Linear(embed_dim, embed_dim, bias=True)
        # k and v share one projection from the (e.g. T5) kv space; k's bias is
        # fixed at zero via the buffer below, v gets a learnable bias.
        self.mat_kv = nn.Linear(kv_dim, embed_dim*2, bias=False)
        self.v_bias = nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = get_dropout_layer(proj_drop)
|
| 471 |
+
|
| 472 |
+
def forward(self, q, ca_kv):
    """
    Cross-attend queries against packed variable-length key/value sequences.

    :param q: shaped as (batch, seq_len, Q_dim)
    :param ca_kv: contains several vectors, each of which is shaped as (len_i, KV_dim). We have [len_1xKV_dim, len_2xKV_dim, len_3xKV_dim, ...] and lens == [len_1, len_2, len_3, ...]
        - kv_compact: shaped as (sum(lens), KV_dim)
        - cu_seqlens_k: cumulated sum of lens
        - max_seqlen_k: int, max(lens)
    NOTE: seq_len (num of Qs) can reach 10k; but len_i (num of KVs) must <= 256

    :return: shaped as (batch, seq_len, Q_dim)

    Backend selection: tries SageAttention, then FlashAttention, and finally a plain
    PyTorch padded-attention fallback. On the first total backend failure,
    ``self.use_flash_attn`` is flipped to False so later calls skip straight to the
    PyTorch path (a per-instance side effect).
    """
    kv_compact, cu_seqlens_k, max_seqlen_k = ca_kv
    N = kv_compact.shape[0]

    # Single matmul projects packed text features to both K and V; K gets a frozen
    # zero bias, V the learnable self.v_bias.
    kv_compact = F.linear(kv_compact, weight=get_weight_for_linear(self.mat_kv, target_dtype=kv_compact.dtype), bias=torch.cat((self.zero_k_bias, self.v_bias))).view(N, 2, self.num_heads, self.head_dim) # NC => N2Hc
    # attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens

    if not self.for_attn_pool:
        B, Lq = q.shape[:2]
        q_compact = self.mat_q(q).view(-1, self.num_heads, self.head_dim)
    else:
        # Attention-pool mode: one learnable query per sample; the `q` argument is ignored.
        B = cu_seqlens_k.shape[0] - 1
        Lq = 1
        # Dequantize mat_q if it's a GGUFParameter
        mat_q_data = self.mat_q
        if GGUF_AVAILABLE and isinstance(mat_q_data, GGUFParameter):
            mat_q_data = dequantize_gguf_tensor(mat_q_data, target_dtype=kv_compact.dtype)
        q_compact = mat_q_data.repeat(B, 1, 1).to(dtype=kv_compact.dtype)

    if self.cos_attn: # always False (forced off in __init__ for T5 cross-attention)
        scale_mul = self.scale_mul_1H1.clamp_max(self.max_scale_mul).exp()
        k, v = kv_compact.unbind(dim=1)
        q_compact = F.normalize(q_compact, dim=-1).mul(scale_mul)
        k = F.normalize(k, dim=-1)
        kv_compact = torch.stack((k, v), dim=1)

    q_compact = q_compact.contiguous()
    kv_compact = kv_compact.contiguous()

    # Try optimized attention backends with graceful fallback
    if self.use_flash_attn:
        # Queries are fixed-length Lq per sample, so their cumulative offsets are a simple arange.
        cu_seqlens_q = torch.arange(0, Lq * (B+1), Lq, dtype=torch.int32, device=q_compact.device)
        oup = None

        # Try SageAttention first (fastest option)
        if SAGE_ATTN_AVAILABLE:
            try:
                # SageAttention varlen: expects separate k, v tensors
                # kv_compact is (N, 2, num_heads, head_dim), split into k and v
                k_compact, v_compact = kv_compact.unbind(dim=1) # Each is (N, num_heads, head_dim)

                # Convert to fp16/bf16 if needed
                target_dtype = torch.bfloat16 if q_compact.dtype == torch.float32 else q_compact.dtype
                q_sage = q_compact.to(target_dtype)
                k_sage = k_compact.to(target_dtype)
                v_sage = v_compact.to(target_dtype)

                # Use sageattn_varlen for variable length sequences
                oup = sageattn_varlen(
                    q=q_sage,
                    k=k_sage,
                    v=v_sage,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=Lq,
                    max_seqlen_k=max_seqlen_k,
                    is_causal=False,
                    sm_scale=self.scale,
                    smooth_k=True
                ).reshape(B, Lq, -1)

                if target_dtype != q_compact.dtype:
                    # Restore fp32 when we down-cast just for the kernel.
                    oup = oup.float()

            except Exception as e:
                print(f"[WARNING] SageAttention failed ({str(e)[:100]}), falling back to FlashAttention/PyTorch")
                oup = None

        # Fall back to FlashAttention if SageAttention failed or not available
        if oup is None and FLASH_ATTN_AVAILABLE:
            try:
                if q_compact.dtype == torch.float32:
                    # FlashAttention only supports fp16/bf16; round-trip through bf16.
                    oup = flash_attn_varlen_kvpacked_func(q=q_compact.to(dtype=torch.bfloat16), kv=kv_compact.to(dtype=torch.bfloat16), cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
                    oup = oup.float()
                else:
                    oup = flash_attn_varlen_kvpacked_func(q=q_compact, kv=kv_compact, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
            except Exception as e:
                print(f"[WARNING] FlashAttention failed ({str(e)[:100]}), falling back to PyTorch attention")
                oup = None

        # If both SageAttention and FlashAttention failed, fall back to PyTorch
        if oup is None:
            self.use_flash_attn = False # Disable optimized attention for future calls

    # Fallback to PyTorch scaled_dot_product_attention
    if not self.use_flash_attn:
        # Unpack k and v from kv_compact: (N, 2, num_heads, head_dim)
        k, v = kv_compact.unbind(dim=1) # k, v: (N, num_heads, head_dim)

        # Reconstruct per-batch k and v tensors based on cu_seqlens_k
        k_batched = []
        v_batched = []
        for i in range(B):
            start = cu_seqlens_k[i].item()
            end = cu_seqlens_k[i+1].item()
            k_batched.append(k[start:end]) # (seq_len_i, num_heads, head_dim)
            v_batched.append(v[start:end])

        # Pad to max_seqlen_k for batching
        k_padded = torch.stack([
            F.pad(k_i, (0, 0, 0, 0, 0, max_seqlen_k - k_i.shape[0])) if k_i.shape[0] < max_seqlen_k else k_i
            for k_i in k_batched
        ]) # (B, max_seqlen_k, num_heads, head_dim)
        v_padded = torch.stack([
            F.pad(v_i, (0, 0, 0, 0, 0, max_seqlen_k - v_i.shape[0])) if v_i.shape[0] < max_seqlen_k else v_i
            for v_i in v_batched
        ]) # (B, max_seqlen_k, num_heads, head_dim)

        # Reshape q_compact: (B*Lq, num_heads, head_dim) -> (B, Lq, num_heads, head_dim)
        q_batched = q_compact.view(B, Lq, self.num_heads, self.head_dim)

        # Transpose for attention: (B, num_heads, seq_len, head_dim)
        q_attn = q_batched.transpose(1, 2) # (B, num_heads, Lq, head_dim)
        k_attn = k_padded.transpose(1, 2) # (B, num_heads, max_seqlen_k, head_dim)
        v_attn = v_padded.transpose(1, 2) # (B, num_heads, max_seqlen_k, head_dim)

        # Create attention mask to mask out padding
        attn_mask = torch.zeros(B, 1, Lq, max_seqlen_k, dtype=torch.bool, device=q_compact.device)
        for i in range(B):
            seq_len = cu_seqlens_k[i+1].item() - cu_seqlens_k[i].item()
            if seq_len < max_seqlen_k:
                attn_mask[i, :, :, seq_len:] = True # Mask padding positions

        # Apply attention
        oup = slow_attn(
            query=q_attn,
            key=k_attn,
            value=v_attn,
            attn_mask=~attn_mask, # True = not masked, False = masked (inverted for PyTorch)
            scale=self.scale,
            dropout_p=0.0
        ) # (B, num_heads, Lq, head_dim)

        # Reshape back: (B, num_heads, Lq, head_dim) -> (B, Lq, embed_dim)
        oup = oup.transpose(1, 2).reshape(B, Lq, -1)

    return self.proj_drop(self.proj(oup))
|
| 620 |
+
|
| 621 |
+
def extra_repr(self) -> str:
    """Summarize the cross-attention configuration for module repr printing."""
    fields = (self.embed_dim, self.kv_dim, self.cos_attn)
    return 'Cq={}, Ckv={}, cos_attn={}'.format(*fields)
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
class SelfAttnBlock(nn.Module):
    """AdaLN-modulated transformer block: self-attention + FFN (no cross-attention).

    Two fixes relative to the previous revision:
      1. ``__init__`` passed ``attn_fn=attn_fn`` to ``SelfAttention`` but ``attn_fn``
         was not a parameter of this constructor — a guaranteed NameError at build
         time. The kwarg is removed.
      2. ``forward`` branched on ``self.fused_ada_norm``, an attribute that is never
         assigned anywhere in this class; the fused kernel is stored as
         ``self.fused_norm_func`` (same convention as ``CrossAttnBlock``).
    """

    def __init__(
        self, embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial,
        num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False,
        swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False,
    ):
        super(SelfAttnBlock, self).__init__()
        self.C, self.D = embed_dim, cond_dim  # C: embed dim, D: condition dim
        self.drop_path_rate = drop_path
        # Stochastic depth; identity when drop_path == 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # BUGFIX: dropped `attn_fn=attn_fn` — `attn_fn` was undefined in this scope.
        self.attn = SelfAttention(
            embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn,
        )
        self.using_swiglu = swiglu
        # FFN hidden size rounded to a multiple of 256 for kernel efficiency.
        self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp)

        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
        self.fused_norm_func = fused_norm_func
        self.norm_eps = norm_layer.keywords.get('eps', 1e-6)

        self.shared_aln = shared_aln
        if self.shared_aln:
            # Shared AdaLN: one learnable (1, 1, 6, C) table added to the per-sample condition.
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)

    # NOTE: attn_bias_or_two_vector is None during inference
    def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector): # todo: minGPT and vqgan also uses pre-norm, just like this, while MaskGiT uses post-norm
        """Apply AdaLN-gated self-attention and FFN residual branches to x.

        ``ca_kv`` is accepted for signature compatibility with CrossAttnBlock but unused.
        """
        with torch.cuda.amp.autocast(enabled=False):  # AdaLN math kept in fp32
            if self.shared_aln: # always True; (1, 1, 6, C) + (B, 1, 6, C)
                gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
            else:
                gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)

            # BUGFIX: was `self.fused_ada_norm` (never set) — use `self.fused_norm_func`.
            if self.fused_norm_func is None:
                x = x + self.drop_path(self.attn( self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1), attn_bias_or_two_vector=attn_bias_or_two_vector ).mul_(gamma1))
                x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
            else:
                x = x + self.drop_path(self.attn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1), attn_bias_or_two_vector=attn_bias_or_two_vector).mul_(gamma1))
                x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
            return x

    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}'
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
class CrossAttnBlock(nn.Module):
    """Transformer block: AdaLN-modulated self-attention, text cross-attention, and FFN.

    Residual structure per forward pass:
        x += drop_path(gamma1 * SA(AdaLN(x)))
        x += ca_gamma * CA(ca_norm(x), text_kv)
        x += drop_path(gamma2 * FFN(AdaLN(x)))
    """

    def __init__(
        self,
        embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial,
        num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False,
        swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False,
        use_flex_attn=False, batch_size=2, pad_to_multiplier=1, apply_rope2d=False, rope2d_normalized_by_hw=False,
    ):
        super(CrossAttnBlock, self).__init__()
        self.C, self.D = embed_dim, cond_dim  # C: embed dim, D: condition dim
        self.drop_path_rate = drop_path
        # Stochastic depth on the SA/FFN residual branches; identity when drop_path == 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sa = SelfAttention(
            embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn,
            use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
        )
        self.ca = CrossAttention(embed_dim=embed_dim, kv_dim=kv_dim, num_heads=num_heads, proj_drop=drop, cos_attn=cos_attn)
        self.using_swiglu = swiglu
        # FFN hidden size rounded to a multiple of 256 for kernel efficiency.
        self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp)

        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
        self.fused_norm_func = fused_norm_func
        self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
        # Cross-attention uses its own affine norm (not AdaLN-modulated).
        self.ca_norm = norm_layer(embed_dim, elementwise_affine=True)

        self.shared_aln = shared_aln
        if self.shared_aln: # always True
            # Shared AdaLN: learnable (1, 1, 6, C) table added to the per-sample condition.
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)

        if cross_attn_layer_scale >= 0:
            # LayerScale on the cross-attention branch; negative value disables it (plain 1).
            self.ca_gamma = nn.Parameter(cross_attn_layer_scale * torch.ones(embed_dim), requires_grad=True)
        else:
            self.ca_gamma = 1

        # When True (and training), gradient-checkpoint only the self-attention call.
        self.checkpointing_sa_only = checkpointing_sa_only

    # NOTE: attn_bias_or_two_vector is None during inference
    def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, scale_ind=0): # todo: minGPT and vqgan also uses pre-norm, just like this, while MaskGiT uses post-norm
        with torch.cuda.amp.autocast(enabled=False): # disable half precision
            if self.shared_aln: # always True; (1, 1, 6, C) + (B, 1, 6, C)
                gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
            else:
                gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)

            if self.fused_norm_func is None:
                x_sa = self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1)
                if self.checkpointing_sa_only and self.training:
                    x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
                else:
                    # NOTE(review): unlike the fused branch below, scale_ind is not
                    # forwarded to self.sa here — confirm whether that is intentional.
                    x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid)
                x = x + self.drop_path(x_sa.mul_(gamma1))
                x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
                x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
            else:
                x_sa = self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1)
                if self.checkpointing_sa_only and self.training:
                    x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
                else:
                    x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, scale_ind=scale_ind)
                x = x + self.drop_path(x_sa.mul_(gamma1))
                x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
                x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
            return x

    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}, ca_gamma={"<learnable>" if isinstance(self.ca_gamma, nn.Parameter) else self.ca_gamma}'
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class AdaLNBeforeHead(nn.Module):
    """Adaptive layer norm applied right before the output head.

    Normalizes ``x_BLC`` without affine parameters, then modulates it with a
    scale and shift predicted from the condition: ``LN(x) * (1 + scale) + shift``.
    """

    def __init__(self, C, D, act: bool, norm_layer: partial, fused_norm_func=None):  # C: embed_dim, D: cond_dim
        super().__init__()
        self.C, self.D = C, D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.fused_norm_func = fused_norm_func
        self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
        proj = nn.Linear(D, 2 * C)  # predicts concatenated (scale, shift)
        if act:
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), proj)
        else:
            self.ada_lin = nn.Sequential(proj)

    def forward(self, x_BLC: torch.Tensor, cond_BD: Optional[torch.Tensor]):
        modulation = self.ada_lin(cond_BD).view(-1, 1, 2, self.C)
        scale, shift = modulation.unbind(2)
        if self.fused_norm_func is not None:
            # Fused kernel performs the norm + modulation in one pass.
            return self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x_BLC, scale=scale, shift=shift)
        return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
def main():
    """CPU smoke test: run CrossAttention over variable-length KV sequences and backprop.

    Builds inputs matching CrossAttention.forward's documented contract:
    q is (B, Lq, Cq) and kv_compact is (sum(Li), Ckv).
    """
    dev = 'cpu' # 'cuda' if torch.cuda.is_available() else 'cpu'
    rng = torch.Generator(device=dev)
    rng.manual_seed(0)
    B, H, cq, ckv = 4, 8, 64, 96
    Cq = H * cq    # query embedding dim
    Ckv = H * ckv  # key/value (text) embedding dim

    Li = [5, 4, 7, 6]  # per-sample KV lengths (e.g. text token counts)
    Lq = 10            # queries per sample

    # BUGFIX: the old fixture built q as (B, Lq, H, cq) and kv_compact as
    # (sum(Li), 2, H, ckv); CrossAttention.forward applies nn.Linear(Cq, Cq) to q's
    # last dim and nn.Linear(Ckv, 2*Cq) to kv_compact's last dim, so both shapes
    # failed at the input projections. Use the documented packed layouts instead.
    q = torch.randn(B, Lq, Cq, generator=rng, device=dev)
    kv_compact = torch.randn(sum(Li), Ckv, generator=rng, device=dev)

    seqlen_k = torch.tensor(Li, dtype=torch.int32, device=dev)
    # BUGFIX: was `torch.torch.int32` (doubled attribute access).
    cu_seqlens_k = F.pad(torch.cumsum(seqlen_k, dim=0, dtype=torch.int32), (1, 0))

    ca = CrossAttention(for_attn_pool=False, embed_dim=Cq, kv_dim=Ckv, num_heads=H)
    # Mean-reduce and backprop to exercise the full autograd path.
    ca(q, (kv_compact, cu_seqlens_k, max(Li))).mean().backward()
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
if __name__ == '__main__':
|
| 793 |
+
main()
|
Infinity/infinity/models/infinity.py
ADDED
|
@@ -0,0 +1,817 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Definition of Infinity transformer model.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
import time
|
| 8 |
+
from contextlib import nullcontext
|
| 9 |
+
from functools import partial
|
| 10 |
+
from typing import List, Optional, Tuple, Union, Dict, Any
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from timm.models import register_model
|
| 16 |
+
from torch.utils.checkpoint import checkpoint
|
| 17 |
+
from PIL import Image
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
import infinity.utils.dist as dist
|
| 21 |
+
from infinity.utils.dist import for_visualize
|
| 22 |
+
from infinity.models.basic import flash_attn_func, flash_fused_op_installed, AdaLNBeforeHead, CrossAttnBlock, SelfAttnBlock, CrossAttention, FastRMSNorm, precompute_rope2d_freqs_grid
|
| 23 |
+
from infinity.utils import misc
|
| 24 |
+
from infinity.models.flex_attn import FlexAttn
|
| 25 |
+
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
|
| 26 |
+
|
| 27 |
+
# Optional fused AdaLN kernels; fall back to the unfused path when the extension
# is unavailable (missing build, incompatible environment, ...).
try:
    from infinity.models.fused_op import fused_ada_layer_norm, fused_ada_rms_norm
except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
    fused_ada_layer_norm, fused_ada_rms_norm = None, None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class MultiInpIdentity(nn.Module):
    """Pass-through module: returns its first input and ignores any extra arguments.

    Useful as a drop-in replacement for layers whose call sites pass additional
    positional/keyword arguments (nn.Identity would reject them).
    """

    def forward(self, x, *unused_args, **unused_kwargs):
        return x
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TextAttentivePool(nn.Module):
    """Pool variable-length text features into one D-dim vector via attention pooling.

    A single learnable query (inside CrossAttention with for_attn_pool=True)
    attends over the packed Ct5-dim text tokens; the length-1 output is squeezed
    to (B, D).
    """

    def __init__(self, Ct5: int, D: int):
        super().__init__()
        self.Ct5, self.D = Ct5, D
        # Wider models use a smaller per-head dim.
        self.head_dim = 64 if D > 4096 else 128
        self.num_heads = Ct5 // self.head_dim
        self.ca = CrossAttention(for_attn_pool=True, embed_dim=self.D, kv_dim=Ct5, num_heads=self.num_heads)

    def forward(self, ca_kv):
        pooled = self.ca(None, ca_kv)  # (B, 1, D); query arg unused in pool mode
        return pooled.squeeze(1)       # (B, D)
|
| 51 |
+
|
| 52 |
+
class SharedAdaLin(nn.Linear):
    """Linear layer that maps a condition vector to six shared AdaLN modulation slots."""

    def forward(self, cond_BD):
        # Output features are 6*C: (gamma1, gamma2, scale1, scale2, shift1, shift2).
        C = self.weight.shape[0] // 6
        # Import get_weight_for_linear from basic.py (handles GGUF-quantized weights).
        from infinity.models.basic import get_weight_for_linear
        w = get_weight_for_linear(self, target_dtype=cond_BD.dtype)
        projected = F.linear(cond_BD, w, self.bias)
        return projected.reshape(-1, 1, 6, C) # B16C
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class MultipleLayers(nn.Module):
    """Run a contiguous chunk of blocks taken from `ls` as one sequential unit.

    Wraps ls[index : index + num_blocks_in_a_chunk] in an nn.ModuleList so a deep
    stack can be split into chunks (e.g. for per-chunk gradient checkpointing).
    """

    def __init__(self, ls, num_blocks_in_a_chunk, index):
        super().__init__()
        self.module = nn.ModuleList(ls[i] for i in range(index, index + num_blocks_in_a_chunk))

    def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, checkpointing_full_block=False, rope2d_freqs_grid=None):
        h = x
        for blk in self.module:
            if checkpointing_full_block:
                # Recompute activations in backward to trade compute for memory.
                h = torch.utils.checkpoint.checkpoint(blk, h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
            else:
                h = blk(h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid)
        return h
|
| 76 |
+
|
| 77 |
+
class Infinity(nn.Module):
|
| 78 |
+
def __init__(
|
| 79 |
+
self, vae_local,
|
| 80 |
+
text_channels=0, text_maxlen=0, # text-cond generation
|
| 81 |
+
selecting_idx=None, # class-cond generation
|
| 82 |
+
embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture
|
| 83 |
+
drop_rate=0., drop_path_rate=0., # drop out and drop path
|
| 84 |
+
norm_eps=1e-6, rms_norm=False, # norm layer
|
| 85 |
+
shared_aln=False, head_aln=True, # adaptive norm
|
| 86 |
+
cond_drop_rate=0.1, # for classifier-free guidance
|
| 87 |
+
rand_uncond=False,
|
| 88 |
+
cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
|
| 89 |
+
raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
|
| 90 |
+
head_depth=1,
|
| 91 |
+
top_p=0.0, top_k=0.0,
|
| 92 |
+
customized_flash_attn=False, fused_mlp=False, fused_norm=False,
|
| 93 |
+
block_chunks=1,
|
| 94 |
+
checkpointing=None,
|
| 95 |
+
pad_to_multiplier=0,
|
| 96 |
+
use_flex_attn=False,
|
| 97 |
+
batch_size=2,
|
| 98 |
+
add_lvl_embeding_only_first_block=1,
|
| 99 |
+
use_bit_label=1,
|
| 100 |
+
rope2d_each_sa_layer=0,
|
| 101 |
+
rope2d_normalized_by_hw=0,
|
| 102 |
+
pn=None,
|
| 103 |
+
train_h_div_w_list=None,
|
| 104 |
+
video_frames=1,
|
| 105 |
+
always_training_scales=20,
|
| 106 |
+
apply_spatial_patchify = 0,
|
| 107 |
+
inference_mode=False,
|
| 108 |
+
):
|
| 109 |
+
# set hyperparameters
|
| 110 |
+
self.C = embed_dim
|
| 111 |
+
self.inference_mode = inference_mode
|
| 112 |
+
self.apply_spatial_patchify = apply_spatial_patchify
|
| 113 |
+
if self.apply_spatial_patchify:
|
| 114 |
+
self.d_vae = vae_local.embed_dim * 4
|
| 115 |
+
else:
|
| 116 |
+
self.d_vae = vae_local.embed_dim
|
| 117 |
+
self.use_bit_label = use_bit_label
|
| 118 |
+
self.codebook_dim = self.d_vae
|
| 119 |
+
self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size
|
| 120 |
+
self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None
|
| 121 |
+
self.Ct5 = text_channels
|
| 122 |
+
self.depth = depth
|
| 123 |
+
self.num_heads = num_heads
|
| 124 |
+
self.batch_size = batch_size
|
| 125 |
+
self.mlp_ratio = mlp_ratio
|
| 126 |
+
self.cond_drop_rate = cond_drop_rate
|
| 127 |
+
self.norm_eps = norm_eps
|
| 128 |
+
self.prog_si = -1
|
| 129 |
+
self.pn = pn
|
| 130 |
+
self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates
|
| 131 |
+
self.video_frames = video_frames
|
| 132 |
+
self.always_training_scales = always_training_scales
|
| 133 |
+
|
| 134 |
+
assert add_lvl_embeding_only_first_block in [0,1]
|
| 135 |
+
self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block
|
| 136 |
+
assert rope2d_each_sa_layer in [0,1]
|
| 137 |
+
self.rope2d_each_sa_layer = rope2d_each_sa_layer
|
| 138 |
+
self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
|
| 139 |
+
print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, \
|
| 140 |
+
self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}')
|
| 141 |
+
head_up_method = ''
|
| 142 |
+
word_patch_size = 1 if head_up_method in {'', 'no'} else 2
|
| 143 |
+
if word_patch_size > 1:
|
| 144 |
+
assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}'
|
| 145 |
+
|
| 146 |
+
self.checkpointing = checkpointing
|
| 147 |
+
self.pad_to_multiplier = max(1, pad_to_multiplier)
|
| 148 |
+
|
| 149 |
+
customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
|
| 150 |
+
self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
|
| 151 |
+
if customized_flash_attn and not customized_kernel_installed:
|
| 152 |
+
import inspect, warnings
|
| 153 |
+
file_path = inspect.getsourcefile(flash_attn_func)
|
| 154 |
+
line_number = inspect.getsourcelines(flash_attn_func)[1]
|
| 155 |
+
info = (
|
| 156 |
+
f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n'
|
| 157 |
+
f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n'
|
| 158 |
+
f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n'
|
| 159 |
+
)
|
| 160 |
+
warnings.warn(info, ImportWarning)
|
| 161 |
+
print(info, flush=True)
|
| 162 |
+
|
| 163 |
+
self.raw_scale_schedule = raw_scale_schedule # 'raw' means before any patchifying
|
| 164 |
+
self.first_l = 1
|
| 165 |
+
# solve top-p top-k sampling hyperparameters
|
| 166 |
+
self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k))
|
| 167 |
+
if self.top_p < 1e-5: self.top_p = 0
|
| 168 |
+
if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0
|
| 169 |
+
|
| 170 |
+
t = torch.zeros(dist.get_world_size(), device=dist.get_device())
|
| 171 |
+
t[dist.get_rank()] = float(flash_fused_op_installed)
|
| 172 |
+
dist.barrier()
|
| 173 |
+
dist.allreduce(t)
|
| 174 |
+
assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}'
|
| 175 |
+
|
| 176 |
+
super().__init__()
|
| 177 |
+
self.rng = torch.Generator(device=dist.get_device())
|
| 178 |
+
self.maybe_record_function = nullcontext
|
| 179 |
+
self.text_maxlen = text_maxlen
|
| 180 |
+
self.t2i = text_channels != 0
|
| 181 |
+
|
| 182 |
+
# [inp & position embedding]
|
| 183 |
+
init_std = math.sqrt(1 / self.C / 3)
|
| 184 |
+
self.norm0_cond = nn.Identity()
|
| 185 |
+
if self.t2i:
|
| 186 |
+
self.selecting_idx = None
|
| 187 |
+
self.num_classes = 0
|
| 188 |
+
self.D = self.C
|
| 189 |
+
|
| 190 |
+
cfg_uncond = torch.empty(self.text_maxlen, self.Ct5)
|
| 191 |
+
rng = torch.Generator(device='cpu')
|
| 192 |
+
rng.manual_seed(0)
|
| 193 |
+
torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng)
|
| 194 |
+
cfg_uncond /= self.Ct5 ** 0.5
|
| 195 |
+
if rand_uncond:
|
| 196 |
+
self.register_buffer('cfg_uncond', cfg_uncond)
|
| 197 |
+
else:
|
| 198 |
+
self.cfg_uncond = nn.Parameter(cfg_uncond)
|
| 199 |
+
|
| 200 |
+
self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps)
|
| 201 |
+
self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D)
|
| 202 |
+
self.text_proj_for_ca = nn.Sequential(
|
| 203 |
+
nn.Linear(self.Ct5, self.D),
|
| 204 |
+
nn.GELU(approximate='tanh'),
|
| 205 |
+
nn.Linear(self.D, self.D),
|
| 206 |
+
)
|
| 207 |
+
else: # class-label cond
|
| 208 |
+
if selecting_idx is None:
|
| 209 |
+
num_classes = 1000
|
| 210 |
+
print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======')
|
| 211 |
+
selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device())
|
| 212 |
+
self.selecting_idx = selecting_idx
|
| 213 |
+
self.num_classes = selecting_idx.shape[-1]
|
| 214 |
+
self.D = self.C
|
| 215 |
+
self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
|
| 216 |
+
nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
|
| 217 |
+
|
| 218 |
+
self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
|
| 219 |
+
nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
|
| 220 |
+
if self.rope2d_each_sa_layer:
|
| 221 |
+
rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw)
|
| 222 |
+
self.rope2d_freqs_grid = rope2d_freqs_grid
|
| 223 |
+
else:
|
| 224 |
+
raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented')
|
| 225 |
+
self.lvl_embed = nn.Embedding(15, self.C)
|
| 226 |
+
nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
|
| 227 |
+
|
| 228 |
+
# [input layers] input norm && input embedding
|
| 229 |
+
norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps)
|
| 230 |
+
self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity()
|
| 231 |
+
self.word_embed = nn.Linear(self.d_vae, self.C)
|
| 232 |
+
|
| 233 |
+
# [shared adaptive layernorm mapping network]
|
| 234 |
+
self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()
|
| 235 |
+
|
| 236 |
+
# fused norm
|
| 237 |
+
if fused_norm:
|
| 238 |
+
fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
|
| 239 |
+
if fused_norm_func is not None: # pre-compile
|
| 240 |
+
B = 2
|
| 241 |
+
x = torch.randn(B, 1, self.C).requires_grad_(True)
|
| 242 |
+
scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
|
| 243 |
+
shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
|
| 244 |
+
# fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward()
|
| 245 |
+
del B, x, scale, shift
|
| 246 |
+
else:
|
| 247 |
+
fused_norm_func = None
|
| 248 |
+
|
| 249 |
+
# [backbone and head]
|
| 250 |
+
self.use_flex_attn = use_flex_attn
|
| 251 |
+
self.attn_fn_compile_dict = {}
|
| 252 |
+
self.batch_size = batch_size
|
| 253 |
+
if self.use_flex_attn:
|
| 254 |
+
self.attn_fn_compile_dict = self.compile_flex_attn()
|
| 255 |
+
|
| 256 |
+
self.drop_path_rate = drop_path_rate
|
| 257 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # dpr means drop path rate (linearly increasing)
|
| 258 |
+
self.unregistered_blocks = []
|
| 259 |
+
for block_idx in range(depth):
|
| 260 |
+
block = (CrossAttnBlock if self.t2i else SelfAttnBlock)(
|
| 261 |
+
embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
|
| 262 |
+
num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
|
| 263 |
+
swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
|
| 264 |
+
checkpointing_sa_only=self.checkpointing == 'self-attn',
|
| 265 |
+
use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
|
| 266 |
+
)
|
| 267 |
+
self.unregistered_blocks.append(block)
|
| 268 |
+
|
| 269 |
+
# [head]
|
| 270 |
+
V = self.V
|
| 271 |
+
if head_aln:
|
| 272 |
+
self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func)
|
| 273 |
+
self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
|
| 274 |
+
else:
|
| 275 |
+
self.head_nm = MultiInpIdentity()
|
| 276 |
+
self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
|
| 277 |
+
|
| 278 |
+
self.num_block_chunks = block_chunks or 1
|
| 279 |
+
self.num_blocks_in_a_chunk = depth // block_chunks
|
| 280 |
+
print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}")
|
| 281 |
+
assert self.num_blocks_in_a_chunk * block_chunks == depth
|
| 282 |
+
if self.num_block_chunks == 1:
|
| 283 |
+
self.blocks = nn.ModuleList(self.unregistered_blocks)
|
| 284 |
+
else:
|
| 285 |
+
self.block_chunks = nn.ModuleList()
|
| 286 |
+
for i in range(self.num_block_chunks):
|
| 287 |
+
self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))
|
| 288 |
+
print(
|
| 289 |
+
f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n'
|
| 290 |
+
f' [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n'
|
| 291 |
+
f' [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
|
| 292 |
+
end='\n\n', flush=True
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
    def compile_flex_attn(self):
        """Pre-compile one FlexAttn kernel per (aspect-ratio template, #scales) pair.

        Returns a dict mapping a tuple of (t, h, w) patch-count triples (a scale
        schedule, or a prefix of one) to a ready-to-call FlexAttn instance.  At
        inference time a kernel is built for every prefix length of the full
        schedule (KV-cache style mask); at training time only for the configured
        number of always-training scales.
        """
        attn_fn_compile_dict = {}
        for h_div_w in self.train_h_div_w_list:
            # Snap the requested aspect ratio to the closest known template.
            h_div_w_template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
            full_scale_schedule = dynamic_resolution_h_w[h_div_w_template][self.pn]['scales']
            if self.inference_mode:
                # One kernel per schedule prefix: scale 1, scales 1-2, ... all scales.
                apply_flex_attn_scales = list(range(1, 1+len(full_scale_schedule)))
                mask_type = "infinity_infer_mask_with_kv_cache"
                auto_padding = True
            else:
                mask_type = 'var'
                auto_padding = False
                apply_flex_attn_scales = [min(self.always_training_scales, len(full_scale_schedule))]
            for scales_num in apply_flex_attn_scales:
                print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======')
                scale_schedule = full_scale_schedule[:scales_num]
                # Clamp the temporal patch count to the video length (t stays 1 for images).
                scale_schedule = [ (min(t, self.video_frames//4+1), h, w) for (t,h, w) in scale_schedule]
                patchs_nums_tuple = tuple(scale_schedule)
                SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple)
                # Round the total sequence length up to the next multiple of pad_to_multiplier.
                aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
                attn_fn = FlexAttn(block_scales = patchs_nums_tuple,
                                    mask_type = mask_type,
                                    B = self.batch_size,
                                    H = self.num_heads,
                                    L = aligned_L,
                                    auto_padding=auto_padding)
                attn_fn_compile_dict[patchs_nums_tuple] = attn_fn

                if self.video_frames > 1: # append image attn_fn when self.video_frames > 1 (namely videos)
                    # Same schedule with t forced to 1, so single frames can be attended too.
                    scale_schedule = [ (1, h, w) for (t,h, w) in scale_schedule]
                    patchs_nums_tuple = tuple(scale_schedule)
                    SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple)
                    aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
                    attn_fn = FlexAttn(block_scales = patchs_nums_tuple,
                                        mask_type = mask_type,
                                        B = self.batch_size,
                                        H = self.num_heads,
                                        L = aligned_L)
                    attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
        return attn_fn_compile_dict
|
| 336 |
+
|
| 337 |
+
def _apply_module_with_dtype_handling(self, module, x):
|
| 338 |
+
"""
|
| 339 |
+
Apply a module (Linear, Sequential, etc.) with F16 weight dtype handling.
|
| 340 |
+
"""
|
| 341 |
+
from infinity.models.basic import get_weight_for_linear
|
| 342 |
+
|
| 343 |
+
if isinstance(module, nn.Linear):
|
| 344 |
+
# Handle Linear layer with dtype conversion
|
| 345 |
+
weight = get_weight_for_linear(module, target_dtype=x.dtype)
|
| 346 |
+
return F.linear(x, weight, module.bias)
|
| 347 |
+
elif isinstance(module, nn.Sequential):
|
| 348 |
+
# Recursively apply each layer in the sequential
|
| 349 |
+
for layer in module:
|
| 350 |
+
x = self._apply_module_with_dtype_handling(layer, x)
|
| 351 |
+
return x
|
| 352 |
+
else:
|
| 353 |
+
# For other modules (GELU, LayerNorm, etc.), apply directly
|
| 354 |
+
return module(x)
|
| 355 |
+
|
| 356 |
+
    def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
        """
        Project hidden states to vocabulary logits, in full float32 precision.

        :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim)
        :param cond_BD: condition for the adaptive head norm, shaped (B or batch_size, D or cond_dim)
        :return: logits, shaped (B or batch_size, L or seq_len, V or vocabulary_size)
        """
        # Autocast is disabled on purpose: the head norm and projection run on
        # float32 inputs regardless of any surrounding AMP context.
        with torch.amp.autocast('cuda', enabled=False):
            x = self.head_nm(h.float(), cond_BD.float())
            # _apply_module_with_dtype_handling casts any F16/quantized head
            # weights to x's dtype before the final linear projection.
            return self._apply_module_with_dtype_handling(self.head, x)
|
| 366 |
+
|
| 367 |
+
def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
|
| 368 |
+
bs, seq_len, c = feature.shape
|
| 369 |
+
patch_t, patch_h, patch_w = scale_schedule[scale_ind]
|
| 370 |
+
t_mul_h_mul_w = patch_t * patch_h * patch_w
|
| 371 |
+
assert t_mul_h_mul_w + need_to_pad == seq_len
|
| 372 |
+
feature[:, :t_mul_h_mul_w] += self.lvl_embed(scale_ind*torch.ones((bs, t_mul_h_mul_w),dtype=torch.int).to(feature.device))
|
| 373 |
+
return feature
|
| 374 |
+
|
| 375 |
+
    def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
        """Add the per-scale level embedding to every scale segment of x_BLC.

        x_BLC concatenates the tokens of all scales along the sequence dim,
        optionally followed by ``need_to_pad`` padding tokens; each segment
        gets lvl_embed(scale_ind) added, the padding tail is passed through.
        """
        ptr = 0
        x_BLC_list = []
        for scale_ind, patch_t_h_w in enumerate(scale_schedule):
            scale_seq_len = np.array(patch_t_h_w).prod()
            x_BLC_this_scale = x_BLC[:,ptr:ptr+scale_seq_len] # shape: [bs, patch_h*patch_w, c]
            ptr += scale_seq_len
            x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, scale_ind, scale_schedule)
            x_BLC_list.append(x_BLC_this_scale)
        # After walking every scale, only the declared padding may remain.
        assert x_BLC.shape[1] == (ptr + need_to_pad), f'{x_BLC.shape[1]} != {ptr} + {need_to_pad}'
        # Re-attach the (untouched) padding tail, if any.
        x_BLC_list.append(x_BLC[:,ptr:])
        x_BLC = torch.cat(x_BLC_list, dim=1)
        return x_BLC
|
| 388 |
+
|
| 389 |
+
    def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
        cfg_infer=False,
        **kwargs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:  # returns logits_BLV
        """
        Training forward pass (teacher forcing); dispatches to
        autoregressive_infer_cfg when cfg_infer is True.

        label_B_or_BLT: label_B or (kv_compact, lens, cu_seqlens_k, max_seqlen_k)
            packed text-encoder states, per-sample lengths, cumulative lengths
            and max length (varlen attention layout).
        x_BLC_wo_prefix: teacher-forcing inputs for all scales except the first
            token, shaped (B, L-1, d_vae); cast to float32 below.
        scale_schedule: list of (t, h, w) patch counts per scale.
        :return: logits BLV, V is vocab_size
        """
        if cfg_infer:
            return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)

        x_BLC_wo_prefix = x_BLC_wo_prefix.float()  # input should be float32
        B = x_BLC_wo_prefix.shape[0]

        # [1. get input sequence x_BLC]
        with torch.amp.autocast('cuda', enabled=False):
            kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
            # drop cond: randomly replace some samples' text features with the
            # learned unconditional embedding (classifier-free guidance training).
            total = 0
            for le in lens:
                if random.random() < self.cond_drop_rate:
                    kv_compact[total:total+le] = self.cfg_uncond[:le]
                total += le
            # Zero-valued touch of cfg_uncond keeps it on the autograd graph even
            # when no sample was dropped this step (needed for DDP/ZeRO).
            must_on_graph = self.cfg_uncond[0, 0] * 0
            kv_compact = self.text_norm(kv_compact).contiguous()
            sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous()   # cond_BD should be float32
            kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
            kv_compact[0, 0] += must_on_graph
            ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k

        cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous()  # gss: gamma, scale, shift; cond_BD_or_gss should be float32

        # Prefix the sequence with the pooled-text SOS token (+ learned position).
        sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
        x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)

        # [1.1. pad the seqlen dim]
        l_end = x_BLC.shape[1]
        need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0

        # Three attention backends: customized flash kernel (visible/invisible
        # length vectors), flex attention (precompiled mask), or a dense bias.
        if self.customized_flash_attn:
            Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
            Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
            attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)
            # todo: solve need_to_pad here
        elif self.use_flex_attn:
            if need_to_pad:
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
            attn_bias_or_two_vector = None
        else:
            # Block-causal mask: a token may attend to all tokens of its own
            # scale and of earlier scales (d holds each token's scale index).
            d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
            dT = d.transpose(1, 2)    # dT: 11L
            attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
            attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous()   # attn_bias: 11LL
            if need_to_pad:
                attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
                # Let pad queries attend to token 0 so their rows are not all -inf.
                attn_bias[0, 0, l_end:, 0] = 0
                x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
            attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)

        if self.use_flex_attn:
            attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
        else:
            attn_fn = None

        # [2. block loop]
        SelfAttnBlock.forward, CrossAttnBlock.forward  # no-op reference (IDE navigation aid)
        checkpointing_full_block = self.checkpointing == 'full-block' and self.training
        if self.num_block_chunks == 1:
            for i, b in enumerate(self.blocks):
                # Level embedding is added either once (before block 0) or before
                # every block, depending on add_lvl_embeding_only_first_block.
                if self.add_lvl_embeding_only_first_block and i == 0:
                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
                if not self.add_lvl_embeding_only_first_block:
                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
                if checkpointing_full_block:
                    x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
                else:
                    x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
        else:
            for i, chunk in enumerate(self.block_chunks): # this path
                if self.add_lvl_embeding_only_first_block and i == 0:
                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
                if not self.add_lvl_embeding_only_first_block:
                    x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
                x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)

        # [3. unpad the seqlen dim, and then get logits]
        return self.get_logits(x_BLC[:, :l_end], cond_BD)    # return logits BLV, V is vocab_size
|
| 477 |
+
|
| 478 |
+
    @torch.no_grad()
    def autoregressive_infer_cfg(
        self,
        vae=None,
        scale_schedule=None,
        label_B_or_BLT=None,
        B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
        g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,  # NOTE(review): mutable default args; safe only because they are never mutated here
        returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
        cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
        vae_type=0, softmax_merge_topk=-1, ret_img=False,
        trunk_scale=1000,
        gt_leak=0, gt_ls_Bl=None,
        inference_mode=False,
        save_img_path=None,
        sampling_per_bits=1,
    ):   # returns List[idx_Bl]
        """Scale-by-scale autoregressive sampling with classifier-free guidance.

        Runs the transformer once per scale, samples token (or bit) indices,
        feeds the VAE-upsampled residual codes back in as the next-scale input,
        and finally decodes an image.  Returns (ret, idx_Bl_list, img_or_[]).
        """
        if g_seed is None: rng = None
        else: self.rng.manual_seed(g_seed); rng = self.rng
        assert len(cfg_list) >= len(scale_schedule)
        assert len(tau_list) >= len(scale_schedule)

        # scale_schedule is used by infinity, vae_scale_schedule is used by vae if there exists a spatial patchify,
        # we need to convert scale_schedule to vae_scale_schedule by multiply 2 to h and w
        if self.apply_spatial_patchify:
            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
        else:
            vae_scale_schedule = scale_schedule

        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
        if any(np.array(cfg_list) != 1):
            # CFG enabled: duplicate the batch -- first half conditional, second
            # half unconditional (or the explicit negative prompt).
            bs = 2*B
            if not negative_label_B_or_BLT:
                kv_compact_un = kv_compact.clone()
                total = 0
                for le in lens:
                    kv_compact_un[total:total+le] = (self.cfg_uncond)[:le]
                    total += le
                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
            else:
                kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
                max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
        else:
            bs = B

        kv_compact = self.text_norm(kv_compact)
        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096]
        kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096]
        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
        last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)

        with torch.amp.autocast('cuda', enabled=False):
            cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
        accu_BChw, cur_L, ret = None, 0, []  # current length, list of reconstructed images
        idx_Bl_list, idx_Bld_list = [], []

        # Enable KV caching on every attention layer for the duration of sampling.
        if inference_mode:
            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
        else:
            assert self.num_block_chunks > 1
            for block_chunk_ in self.block_chunks:
                for module in block_chunk_.module.module:
                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)

        # Resolve where CFG mixing happens: on logits, on probs, or after
        # specific (negatively indexed) transformer layers.
        abs_cfg_insertion_layers = []
        add_cfg_on_logits, add_cfg_on_probs = False, False
        leng = len(self.unregistered_blocks)
        for item in cfg_insertion_layer:
            if item == 0: # add cfg on logits
                add_cfg_on_logits = True
            elif item == 1: # add cfg on probs
                add_cfg_on_probs = True # todo in the future, we may want to add cfg on logits and probs
            elif item < 0: # determine to add cfg at item-th layer's output
                assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}'
                abs_cfg_insertion_layers.append(leng+item)
            else:
                raise ValueError(f'cfg_insertion_layer: {item} is not valid')

        num_stages_minus_1 = len(scale_schedule)-1
        summed_codes = 0
        for si, pn in enumerate(scale_schedule):   # si: i-th segment
            cfg = cfg_list[si]
            if si >= trunk_scale:
                break
            cur_L += np.array(pn).prod()

            need_to_pad = 0
            attn_fn = None
            if self.use_flex_attn:
                # need_to_pad = (self.pad_to_multiplier - cur_L % self.pad_to_multiplier) % self.pad_to_multiplier
                # if need_to_pad:
                #     last_stage = F.pad(last_stage, (0, 0, 0, need_to_pad))
                attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None)

            # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
            layer_idx = 0
            for block_idx, b in enumerate(self.block_chunks):
                # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int
                if self.add_lvl_embeding_only_first_block and block_idx == 0:
                    last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
                if not self.add_lvl_embeding_only_first_block:
                    last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)

                for m in b.module:
                    last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si)
                    if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
                        # print(f'add cfg={cfg} on {layer_idx}-th layer output')
                        # Mix cond/uncond halves, then duplicate so the batch
                        # layout stays (cond, uncond) for the remaining layers.
                        last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:]
                        last_stage = torch.cat((last_stage, last_stage), 0)
                    layer_idx += 1

            if (cfg != 1) and add_cfg_on_logits:
                # print(f'add cfg on add_cfg_on_logits')
                logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si])
                logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:]
            else:
                logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si])

            if self.use_bit_label:
                # Binary-sphere quantizer: each code dim is a 2-way classification.
                tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
                logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
                idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
                idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
            else:
                idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
            if vae_type != 0:
                assert returns_vemb
                if si < gt_leak:
                    # Teacher-forcing leak: use ground-truth indices for early scales.
                    idx_Bld = gt_ls_Bl[si]
                else:
                    assert pn[0] == 1
                    idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d]
                    if self.apply_spatial_patchify: # unpatchify operation
                        idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w]
                        idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w]
                        idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d]
                    idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d]

                idx_Bld_list.append(idx_Bld)
                codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
                if si != num_stages_minus_1:
                    # Accumulate residual codes at full resolution, then downsample
                    # the running sum to the next scale as the next-step input.
                    summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
                    last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_up) # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
                    last_stage = last_stage.squeeze(-3) # [B, d, h, w] or [B, d, 2h, 2w]
                    if self.apply_spatial_patchify: # patchify operation
                        last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) # [B, 4d, h, w]
                    last_stage = last_stage.reshape(*last_stage.shape[:2], -1) # [B, d, h*w] or [B, 4d, h*w]
                    last_stage = torch.permute(last_stage, [0,2,1]) # [B, h*w, d] or [B, h*w, 4d]
                else:
                    summed_codes += codes
            else:
                if si < gt_leak:
                    idx_Bl = gt_ls_Bl[si]
                h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float()   # BlC

                # h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1])
                h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2])
                ret.append(h_BChw if returns_vemb != 0 else idx_Bl)
                idx_Bl_list.append(idx_Bl)
                if si != num_stages_minus_1:
                    accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule)

            if si != num_stages_minus_1:
                # Embed the next-scale input and duplicate it for the CFG batch.
                last_stage = self.word_embed(self.norm0_ve(last_stage))
                last_stage = last_stage.repeat(bs//B, 1, 1)

        # Disable KV caching again (also frees the caches).
        if inference_mode:
            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False)
        else:
            assert self.num_block_chunks > 1
            for block_chunk_ in self.block_chunks:
                for module in block_chunk_.module.module:
                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False)

        if not ret_img:
            return ret, idx_Bl_list, []

        if vae_type != 0:
            img = vae.decode(summed_codes.squeeze(-3))
        else:
            img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True)

        # Map from [-1, 1] to uint8 and flip RGB->BGR (channel flip on dim 3).
        img = (img + 1) / 2
        img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
        return ret, idx_Bl_list, img
|
| 666 |
+
|
| 667 |
+
    @for_visualize
    def vis_key_params(self, ep):
        # Visualization hook invoked per epoch; intentionally a no-op here.
        return
|
| 670 |
+
|
| 671 |
+
    def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False):
        """Load a checkpoint, reconciling text-length and buffer mismatches.

        - cfg_uncond: if the checkpoint's unconditional text embedding has a
          different max text length than this model, pad it with the model's own
          rows (checkpoint shorter) or truncate it (checkpoint longer).
        - Positional/masking buffers are always taken from the current model,
          not the checkpoint, since they depend on the configured resolution.
        """
        for k in state_dict:
            if 'cfg_uncond' in k:
                old, new = state_dict[k], self.cfg_uncond.data
                min_tlen = min(old.shape[0], new.shape[0])
                if min_tlen == old.shape[0]:
                    # Checkpoint is shorter (or equal): keep it and append this
                    # model's remaining rows so shapes match.
                    state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:]))
                else:
                    state_dict[k] = old[:min_tlen]

        for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'):
            # Drop the checkpoint's version and substitute the live buffer (if present).
            state_dict.pop(buf_name, None)
            if hasattr(self, buf_name):
                state_dict[buf_name] = getattr(self, buf_name)

        return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
|
| 687 |
+
|
| 688 |
+
    def special_init(
        self,
        aln_init: float,
        aln_gamma_init: float,
        scale_head: float,
        scale_proj: int,
    ):
        """In-place re-scaling of freshly initialized weights.

        :param aln_init: multiplier for adaptive-LN scale/shift weights.
        :param aln_gamma_init: multiplier for adaptive-LN gamma weights.
        :param scale_head: multiplier for the output head (skipped if negative).
        :param scale_proj: 1 selects depth-uniform 1/sqrt(2*depth) scaling of the
            attention/FFN projections (other values leave them untouched).
        """
        # init head's norm
        if isinstance(self.head_nm, AdaLNBeforeHead):
            self.head_nm.ada_lin[-1].weight.data.mul_(aln_init)    # there's no gamma for head
            if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
                self.head_nm.ada_lin[-1].bias.data.zero_()

        # init head's proj
        if scale_head >= 0:
            if isinstance(self.head, nn.Linear):
                self.head.weight.data.mul_(scale_head)
                self.head.bias.data.zero_()
            elif isinstance(self.head, nn.Sequential):
                # Only the final Linear of a multi-layer head is scaled.
                self.head[-1].weight.data.mul_(scale_head)
                self.head[-1].bias.data.zero_()

        depth = len(self.unregistered_blocks)
        for block_idx, sab in enumerate(self.unregistered_blocks):
            sab: Union[SelfAttnBlock, CrossAttnBlock]
            # init proj
            scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx))
            if scale_proj == 1:
                # t2i blocks have separate self- (sa) and cross- (ca) attention.
                if self.t2i:
                    sab.sa.proj.weight.data.mul_(scale)
                    sab.ca.proj.weight.data.mul_(scale)
                else:
                    sab.attn.proj.weight.data.mul_(scale)
                sab.ffn.fc2.weight.data.mul_(scale)
            # if sab.using_swiglu:
            #     nn.init.ones_(sab.ffn.fcg.bias)
            #     nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)

            # init ada_lin
            if hasattr(sab, 'ada_lin'):
                lin = sab.ada_lin[-1]
                lin.weight.data[:2*self.C].mul_(aln_gamma_init)    # init gamma
                lin.weight.data[2*self.C:].mul_(aln_init)    # init scale and shift
                if hasattr(lin, 'bias') and lin.bias is not None:
                    lin.bias.data.zero_()
            elif hasattr(sab, 'ada_gss'):
                sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init)    # init gamma
                sab.ada_gss.data[:, :, 2:, :].mul_(aln_init)    # init scale and shift
|
| 736 |
+
|
| 737 |
+
def extra_repr(self):
    """Extra information appended to this module's repr()."""
    return 'drop_path_rate={}'.format(self.drop_path_rate)
|
| 739 |
+
|
| 740 |
+
def get_layer_id_and_scale_exp(self, para_name: str):
    # Hook for layer-wise LR-decay schemes: map a parameter name to its
    # (layer id, scale exponent). Deliberately unimplemented for this model;
    # optimizers that rely on it must not call it here.
    raise NotImplementedError
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
def sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor:  # return idx, shaped (B, l)
    """Filter `logits_BlV` with top-k / top-p IN PLACE, then sample indices.

    Removed entries are set to -inf before a softmax + multinomial draw.
    A negative `num_samples` requests sampling without replacement; the
    returned index tensor is shaped (B, l, |num_samples|).
    """
    B, l, V = logits_BlV.shape
    if top_k > 0:
        k = min(top_k, V)
        # Threshold at the k-th largest logit per position; everything below it goes to -inf.
        kth_value = logits_BlV.topk(k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
        logits_BlV.masked_fill_(logits_BlV < kth_value, -torch.inf)
    if top_p > 0:
        # Sort ascending so the low-probability tail (cumulative mass <= 1 - top_p) comes first.
        asc_logits, asc_idx = logits_BlV.sort(dim=-1, descending=False)
        tail_mask = asc_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
        tail_mask[..., -1:] = False  # never drop the single most likely token
        logits_BlV.masked_fill_(tail_mask.scatter(asc_idx.ndim - 1, asc_idx, tail_mask), -torch.inf)
    # multinomial only accepts 2D input, so flatten (B, l) and reshape afterwards
    with_replacement = num_samples >= 0
    n_draws = abs(num_samples)
    flat_probs = logits_BlV.softmax(dim=-1).view(-1, V)
    drawn = torch.multinomial(flat_probs, num_samples=n_draws, replacement=with_replacement, generator=rng)
    return drawn.view(B, l, n_draws)
|
| 759 |
+
|
| 760 |
+
def sampling_with_top_k_top_p_also_inplace_modifying_probs_(probs_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor:  # return idx, shaped (B, l)
    """Filter a probability tensor with top-k / top-p IN PLACE, then sample.

    Removed entries are zeroed; the survivors are renormalized before the
    multinomial draw. A negative `num_samples` requests sampling without
    replacement; the result is shaped (B, l, |num_samples|).
    """
    B, l, V = probs_BlV.shape
    if top_k > 0:
        k = min(top_k, V)
        # Zero out everything strictly below the k-th largest probability per position.
        kth_value = probs_BlV.topk(k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
        probs_BlV.masked_fill_(probs_BlV < kth_value, 0)
    if top_p > 0:
        # NOTE(review): softmax is applied to the (already-probability) values,
        # mirroring the logits variant — kept as-is to preserve behavior.
        asc_probs, asc_idx = probs_BlV.sort(dim=-1, descending=False)
        tail_mask = asc_probs.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
        tail_mask[..., -1:] = False  # always keep the most likely token
        probs_BlV.masked_fill_(tail_mask.scatter(asc_idx.ndim - 1, asc_idx, tail_mask), 0)
    # Renormalize survivors, then flatten to 2D because multinomial requires it.
    normed = probs_BlV / probs_BlV.sum(dim=-1, keepdim=True)
    with_replacement = num_samples >= 0
    n_draws = abs(num_samples)
    drawn = torch.multinomial(normed.view(-1, V), num_samples=n_draws, replacement=with_replacement, generator=rng)
    return drawn.view(B, l, n_draws)
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def get_params_num(d, w, mlp):
    """Estimate the parameter count of an Infinity variant as a 'X.XXB' string.

    Args:
        d: transformer depth (number of blocks).
        w: embedding width.
        mlp: MLP expansion ratio (hidden width is rounded to a multiple of 256).
    """
    hidden = round(mlp * w / 256) * 256  # MLP hidden width, 256-aligned
    total = d * (8 * w * w + 2 * w * hidden)  # per block: self+cross attention, MLP
    total += 6 * w * w          # shared AdaLN
    total += 4096 * w           # prediction head
    total += 32 * w             # word embedding
    t5_dim = 4096
    total += 4 * t5_dim * w     # T5 attention pooling
    total += t5_dim * w + w * w # T5 MLP projection
    return f'{total / 1e9:.2f}B'
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
# kwargs that timm's create_model() injects but the Infinity constructor does not
# accept; the factory functions below strip them before calling Infinity(...).
TIMM_KEYS = {'img_size', 'pretrained', 'pretrained_cfg', 'pretrained_cfg_overlay', 'global_pool'}
|
| 792 |
+
|
| 793 |
+
@register_model
def infinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs):
    """Infinity-2B variant: 32 blocks, width 2048, 16 heads, mlp_ratio 4."""
    extra = {key: val for key, val in kwargs.items() if key not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **extra)
|
| 795 |
+
|
| 796 |
+
@register_model
def infinity_20b(depth=58, embed_dim=4608, num_heads=4608//128, drop_path_rate=0.25, **kwargs):
    """Infinity-20B variant: 58 blocks, width 4608, 36 heads, mlp_ratio 4."""
    extra = {key: val for key, val in kwargs.items() if key not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **extra)
|
| 798 |
+
|
| 799 |
+
# model configuration for scaling Infinity transformer
def _infinity_scaling_variant(depth, embed_dim, num_heads, drop_path_rate, kwargs):
    """Shared builder for the depth-scaling study variants; strips timm-injected kwargs."""
    clean = {k: v for k, v in kwargs.items() if k not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **clean)

@register_model
def infinity_layer12(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, **kwargs):
    return _infinity_scaling_variant(depth, embed_dim, num_heads, drop_path_rate, kwargs)

@register_model
def infinity_layer16(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, **kwargs):
    return _infinity_scaling_variant(depth, embed_dim, num_heads, drop_path_rate, kwargs)

@register_model
def infinity_layer24(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, **kwargs):
    return _infinity_scaling_variant(depth, embed_dim, num_heads, drop_path_rate, kwargs)

@register_model
def infinity_layer32(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, **kwargs):
    return _infinity_scaling_variant(depth, embed_dim, num_heads, drop_path_rate, kwargs)

@register_model
def infinity_layer40(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, **kwargs):
    return _infinity_scaling_variant(depth, embed_dim, num_heads, drop_path_rate, kwargs)

@register_model
def infinity_layer48(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, **kwargs):
    return _infinity_scaling_variant(depth, embed_dim, num_heads, drop_path_rate, kwargs)
|
Infinity/infinity_vae_d32_reg.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7a37fa3ea1b2a1ebd23de61d91a5e68202825e5a67edaef4b7c55f5fd5b9cf26
|
| 3 |
+
size 1557324701
|
README.md
CHANGED
|
@@ -1,3 +1,162 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Infinity-2B GGUF with SageAttention
|
| 2 |
+
|
| 3 |
+
Unofficial Q8_0 GGUF quantization of Infinity-2B with **SageAttention** support for even faster generation.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
✨ **SageAttention Integration** - 2-5x faster than FlashAttention with automatic fallback
|
| 8 |
+
🎨 **Gradio Web UI** - Easy-to-use interface for image generation
|
| 9 |
+
💾 **Q8_0 Quantization** - ~75% memory reduction with minimal quality loss
|
| 10 |
+
🚀 **Optimized Inference** - T5 encoder on CPU, efficient VRAM usage
|
| 11 |
+
🔧 **GGUF Support** - On-the-fly dequantization with flexible deployment
|
| 12 |
+
|
| 13 |
+
## Quick Start
|
| 14 |
+
|
| 15 |
+
### Web UI (Recommended)
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
python gradio_webui.py --autoload
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Then open `http://127.0.0.1:7860` in your browser.
|
| 22 |
+
|
| 23 |
+
### Command Line
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
python generate_image_2b_q8_gguf.py \
|
| 27 |
+
--prompt "an astronaut riding a horse on the moon" \
|
| 28 |
+
--output output.png
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Installation
|
| 32 |
+
|
| 33 |
+
### 1. Basic Requirements
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
pip install -r Infinity/requirements.txt
|
| 37 |
+
pip install gradio gguf
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
### 2. Install SageAttention (Optional, Recommended)
|
| 41 |
+
|
| 42 |
+
For faster generation:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
pip install sageattention>=2.2.0 --no-build-isolation
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
**Requirements**: CUDA ≥12.0 (CUDA 12.8+ for Blackwell GPUs like RTX 50-series)
|
| 49 |
+
|
| 50 |
+
**Note**: SageAttention is optional. The code automatically falls back to:
|
| 51 |
+
1. SageAttention (if installed) - 2-5x faster ✨
|
| 52 |
+
2. FlashAttention (if available) - faster than PyTorch
|
| 53 |
+
3. PyTorch SDPA (always works) - built-in fallback
|
| 54 |
+
|
| 55 |
+
### 3. Download Models
|
| 56 |
+
|
| 57 |
+
You'll need:
|
| 58 |
+
- `infinity_2b_reg_Q8_0.gguf` - Infinity-2B model (~2.1 GB)
|
| 59 |
+
- `flan-t5-xl-encoder-Q8_0.gguf` - T5 text encoder (~1.0 GB)
|
| 60 |
+
- `Infinity/infinity_vae_d32_reg.pth` - VAE decoder (~0.5 GB)
|
| 61 |
+
|
| 62 |
+
## Memory Requirements
|
| 63 |
+
|
| 64 |
+
| Component | VRAM Usage |
|
| 65 |
+
|-----------|-----------|
|
| 66 |
+
| Infinity-2B (Q8_0) | ~2.5 GB |
|
| 67 |
+
| VAE | ~0.5 GB |
|
| 68 |
+
| Working Memory | ~1-2 GB |
|
| 69 |
+
| **Total (1M res)** | **~4-5 GB** |
|
| 70 |
+
|
| 71 |
+
**T5 encoder runs on CPU** to save VRAM!
|
| 72 |
+
|
| 73 |
+
Recommended: **8GB+ VRAM** for comfortable 1M (1024×1024) generation
|
| 74 |
+
|
| 75 |
+
## Web UI Features
|
| 76 |
+
|
| 77 |
+
The Gradio web interface provides:
|
| 78 |
+
|
| 79 |
+
- **Model Management**: Load models once, reuse for all generations
|
| 80 |
+
- **Full Parameter Control**: CFG scale, tau, resolution, aspect ratio, seed
|
| 81 |
+
- **Real-time Preview**: See your images as they generate
|
| 82 |
+
- **Progress Tracking**: Visual feedback during loading and generation
|
| 83 |
+
- **Clean Layout**: Model paths banner, settings on left, output on right
|
| 84 |
+
|
| 85 |
+
### Web UI Options
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
# Basic usage
|
| 89 |
+
python gradio_webui.py
|
| 90 |
+
|
| 91 |
+
# Auto-load models on startup (faster)
|
| 92 |
+
python gradio_webui.py --autoload
|
| 93 |
+
|
| 94 |
+
# Create public share link
|
| 95 |
+
python gradio_webui.py --share
|
| 96 |
+
|
| 97 |
+
# Custom port
|
| 98 |
+
python gradio_webui.py --server-port 8080
|
| 99 |
+
|
| 100 |
+
# Full options
|
| 101 |
+
python gradio_webui.py \
|
| 102 |
+
--autoload \
|
| 103 |
+
--server-port 7860 \
|
| 104 |
+
--infinity-gguf path/to/infinity.gguf \
|
| 105 |
+
--t5-gguf path/to/t5.gguf \
|
| 106 |
+
--vae-path path/to/vae.pth
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
## Command-Line Options
|
| 110 |
+
|
| 111 |
+
```bash
|
| 112 |
+
python generate_image_2b_q8_gguf.py [OPTIONS]
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
| Option | Description | Default |
|
| 116 |
+
|--------|-------------|---------|
|
| 117 |
+
| `--prompt TEXT` | Text prompt for image generation | "an astronaut..." |
|
| 118 |
+
| `--infinity-gguf PATH` | Path to Infinity GGUF file | infinity_2b_reg_Q8_0.gguf |
|
| 119 |
+
| `--t5-gguf PATH` | Path to T5 encoder GGUF | flan-t5-xl-encoder-Q8_0.gguf |
|
| 120 |
+
| `--vae-path PATH` | Path to VAE checkpoint | Infinity/infinity_vae_d32_reg.pth |
|
| 121 |
+
| `--output PATH` | Output image path | output.png |
|
| 122 |
+
| `--cfg-scale FLOAT` | CFG scale (1.0-10.0) | 3.0 |
|
| 123 |
+
| `--tau FLOAT` | Temperature (0.1-1.0) | 0.5 |
|
| 124 |
+
| `--seed INT` | Random seed for reproducibility | 42 |
|
| 125 |
+
| `--pn {0.06M,0.25M,1M}` | Resolution preset | 1M |
|
| 126 |
+
| `--aspect-ratio FLOAT` | Aspect ratio (height/width) | 1.0 |
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
## Technical Details
|
| 130 |
+
|
| 131 |
+
### Quantization
|
| 132 |
+
|
| 133 |
+
- **Q8_0 format**: 8-bit quantization with minimal quality loss
|
| 134 |
+
- **On-the-fly dequantization**: Using custom GGUFLinear layers
|
| 135 |
+
- **Memory savings**: ~75% reduction vs FP32 (~50% vs FP16)
|
| 136 |
+
- **Quality**: Nearly identical to FP16
|
| 137 |
+
|
| 138 |
+
### Architecture
|
| 139 |
+
|
| 140 |
+
- **Infinity-2B**: 2.0B parameters, embed_dim=2048, depth=32
|
| 141 |
+
- **T5-XL Encoder**: 2048-dim text embeddings
|
| 142 |
+
- **VAE**: d32 with dynamic resolution support
|
| 143 |
+
|
| 144 |
+
### GGUF Support
|
| 145 |
+
|
| 146 |
+
The implementation includes:
|
| 147 |
+
- Import utilities for GGUF tensors
|
| 148 |
+
- Custom `GGUFLinear` layers for on-the-fly dequantization
|
| 149 |
+
- Patched attention mechanisms for compatibility
|
| 150 |
+
- F16 dtype handling for head layers
|
| 151 |
+
|
| 152 |
+
See [patch_infinity_for_gguf.sh](patch_infinity_for_gguf.sh) for implementation details.
|
| 153 |
+
|
| 154 |
+
## Credits
|
| 155 |
+
|
| 156 |
+
- **Original Model**: [Infinity by FoundationVision](https://github.com/FoundationVision/Infinity)
|
| 157 |
+
- **SageAttention**: [thu-ml/SageAttention](https://github.com/thu-ml/SageAttention)
|
| 158 |
+
- **GGUF Format**: [ggerganov/ggml](https://github.com/ggerganov/ggml)
|
| 159 |
+
|
| 160 |
+
## License
|
| 161 |
+
|
| 162 |
+
MIT
|
flan-t5-xl-encoder-Q8_0.gguf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d212c960e07faf2323e2136cb03e62578a8f6862f13709f480684da9f5d9a2e6
|
| 3 |
+
size 1563507296
|
generate_image_2b_q8_gguf.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Generate images using quantized Infinity-2B model (GGUF format)
|
| 4 |
+
Loads T5 text encoder from GGUF on CPU, Infinity model from GGUF on GPU
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 10 |
+
|
| 11 |
+
# Add Infinity to Python path (assumes Infinity repo is in same directory as this script)
|
| 12 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 13 |
+
INFINITY_PATH = os.path.join(SCRIPT_DIR, 'Infinity')
|
| 14 |
+
if os.path.exists(INFINITY_PATH):
|
| 15 |
+
sys.path.insert(0, INFINITY_PATH)
|
| 16 |
+
else:
|
| 17 |
+
print(f"Warning: Infinity repo not found at {INFINITY_PATH}")
|
| 18 |
+
print("Please clone the Infinity repo and run patch_infinity_for_gguf.sh")
|
| 19 |
+
|
| 20 |
+
import time
|
| 21 |
+
import argparse
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
import numpy as np
|
| 25 |
+
import cv2
|
| 26 |
+
from typing import List
|
| 27 |
+
import gguf
|
| 28 |
+
|
| 29 |
+
# Import existing utilities
|
| 30 |
+
from infinity_gguf_utils import (
|
| 31 |
+
load_gguf_state_dict,
|
| 32 |
+
load_gguf_state_dict_with_params,
|
| 33 |
+
_replace_with_gguf_linear,
|
| 34 |
+
GGUFParameter,
|
| 35 |
+
dequantize_gguf_tensor,
|
| 36 |
+
GGUFLinear
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Import Infinity model and utilities
|
| 40 |
+
from infinity.models.infinity import Infinity
|
| 41 |
+
from infinity.models.bsq_vae.vae import vae_model
|
| 42 |
+
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
|
| 43 |
+
|
| 44 |
+
# Import transformers for tokenizer
|
| 45 |
+
from transformers import AutoTokenizer
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def load_t5_tokenizer_from_gguf(gguf_path):
    """Return a T5 tokenizer suitable for the GGUF-quantized text encoder.

    The T5 vocabulary is standard, so the stock google/t5-v1_1-xxl tokenizer
    is used instead of extracting tokenizer data from the GGUF file.

    Args:
        gguf_path: Path to the encoder GGUF file. Currently unused; kept for
            interface symmetry with the other loaders.

    Returns:
        A tokenizer instance with model_max_length capped at 512.
    """
    print("[Loading T5 Tokenizer]")
    # Use standard T5 tokenizer - the GGUF file should be compatible
    # We can use any T5-v1.1-xxl tokenizer since the vocab is standard
    try:
        from transformers import T5TokenizerFast
        tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl", legacy=True)
    except Exception:
        # Was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit.
        # Catch Exception only, then retry via AutoTokenizer (may hit local cache).
        print("Warning: Could not load T5 tokenizer from HuggingFace, trying local cache...")
        tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl", legacy=True)
    tokenizer.model_max_length = 512
    return tokenizer
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_t5_encoder_from_gguf(gguf_path, device='cpu'):
    """
    Load T5 encoder from GGUF file and keep on CPU
    Based on ComfyUI-GGUF loader implementation

    Steps:
      1. Read every tensor from the GGUF file, renaming llama.cpp-style T5
         keys to the HuggingFace naming scheme via T5_SD_MAP.
      2. Dequantize quantized tensors to float16 (the encoder runs fully
         dequantized, unlike the Infinity transformer).
      3. Build a transformers T5EncoderModel from a flan-t5-xl config and load
         the converted state dict into it with strict=False.

    Args:
        gguf_path: path to the T5 encoder GGUF file.
        device: device the encoder should live on ('cpu' by default to save VRAM).

    Returns:
        A frozen, eval-mode T5EncoderModel on `device`.
    """
    print(f"[Loading T5 Encoder from GGUF: {gguf_path}]")
    print(f"[T5 will be kept on {device}]")

    # Apply NumPy 2.0 compatibility patch if needed
    # (gguf's reader calls ndarray.newbyteorder, removed in NumPy 2.0;
    # re-imported locally so this patch is self-contained)
    import numpy as np
    if not hasattr(np.ndarray, 'newbyteorder'):
        def newbyteorder(self, new_order):
            return self.view(self.dtype.newbyteorder(new_order))
        np.ndarray.newbyteorder = newbyteorder

    # Load GGUF state dict
    from gguf import GGUFReader
    reader = GGUFReader(gguf_path)

    # Map llama.cpp T5 keys to HuggingFace T5 keys
    # (assumes the GGUF was exported with llama.cpp's T5 naming — TODO confirm
    # for GGUFs from other converters)
    T5_SD_MAP = {
        "enc.": "encoder.",
        ".blk.": ".block.",
        "token_embd": "shared",
        "output_norm": "final_layer_norm",
        "attn_q": "layer.0.SelfAttention.q",
        "attn_k": "layer.0.SelfAttention.k",
        "attn_v": "layer.0.SelfAttention.v",
        "attn_o": "layer.0.SelfAttention.o",
        "attn_norm": "layer.0.layer_norm",
        "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
        "ffn_up": "layer.1.DenseReluDense.wi_1",
        "ffn_down": "layer.1.DenseReluDense.wo",
        "ffn_gate": "layer.1.DenseReluDense.wi_0",
        "ffn_norm": "layer.1.layer_norm",
    }

    # Load and convert tensors
    state_dict = {}
    print("Loading T5 tensors from GGUF...")
    for tensor in reader.tensors:
        tensor_name = tensor.name

        # Apply key mapping (simple substring replacement over every rule)
        for old_key, new_key in T5_SD_MAP.items():
            tensor_name = tensor_name.replace(old_key, new_key)

        # Load tensor data (raw bytes for quantized types, values for F16/F32)
        torch_tensor = torch.from_numpy(np.array(tensor.data))

        # Determine shape — GGUF stores dimensions in reversed order;
        # only used in the non-quantized branch below
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))

        # Check if quantized (anything other than plain F32/F16)
        is_quantized = tensor.tensor_type not in {
            gguf.GGMLQuantizationType.F32,
            gguf.GGMLQuantizationType.F16
        }

        if is_quantized:
            # Dequantize to float16 for CPU inference
            # print(f" Dequantizing {tensor_name} ({tensor.tensor_type})...")
            param = GGUFParameter(torch_tensor, quant_type=tensor.tensor_type)
            dequant_tensor = dequantize_gguf_tensor(param, target_dtype=torch.float16)
            state_dict[tensor_name] = dequant_tensor.to(device)
        else:
            # Already F32 or F16; F32 is downcast to F16 to halve memory
            torch_tensor = torch_tensor.view(*shape)
            if tensor.tensor_type == gguf.GGMLQuantizationType.F32:
                state_dict[tensor_name] = torch_tensor.to(torch.float16).to(device)
            else:
                state_dict[tensor_name] = torch_tensor.to(device)

    print(f"Loaded {len(state_dict)} tensors for T5 encoder")

    # Load T5 model architecture from transformers
    from transformers import T5EncoderModel, T5Config

    # Create T5 config - for T5-XL (2048 dims, not XXL which is 4096)
    # Try to load from local directory first, fall back to download if needed
    try:
        config = T5Config.from_pretrained("./flan-t5-xl-official")
        print("Loaded T5 config from local directory")
    except Exception as e:
        print(f"Could not load config from local directory: {e}")
        print("Falling back to download T5 config...")
        config = T5Config.from_pretrained("google/flan-t5-xl")
        print("Downloaded T5 config from HuggingFace")

    # Create model (randomly initialized; real weights loaded next)
    model = T5EncoderModel(config)

    # Load state dict; strict=False because decoder weights are absent by design
    print("Loading state dict into T5 model...")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f" Missing keys: {missing[:5]}..." if len(missing) > 5 else f" Missing keys: {missing}")
    if unexpected:
        print(f" Unexpected keys: {unexpected[:5]}..." if len(unexpected) > 5 else f" Unexpected keys: {unexpected}")

    # Freeze for inference
    model.to(device)
    model.eval()
    model.requires_grad_(False)

    print(f"[T5 Encoder loaded successfully on {device}]")
    return model
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def load_infinity_from_gguf(gguf_path, vae, device='cuda', model_type='infinity_2b',
                            text_channels=2048, pn='1M'):
    """
    Load Infinity model from GGUF file

    Builds the Infinity-2B architecture, replaces its Linear layers with
    GGUFLinear wrappers (on-the-fly dequantization), then installs the GGUF
    tensors directly onto the matching parameters by attribute walking.
    load_state_dict is bypassed deliberately: quantized tensors do not share
    the target parameters' shapes/dtypes.

    Args:
        gguf_path: path to the Infinity GGUF checkpoint.
        vae: already-constructed VAE, passed to Infinity as vae_local.
        device: target device for the transformer.
        model_type: only 'infinity_2b' is supported; anything else raises ValueError.
        text_channels: T5 text-feature dimension fed to the model.
        pn: resolution preset string (e.g. '1M').

    Returns:
        The Infinity model in eval mode with GGUF weights installed.
    """
    print(f"[Loading Infinity-2B from GGUF: {gguf_path}]")

    # Model configuration for Infinity-2B
    if model_type == 'infinity_2b':
        kwargs_model = dict(
            depth=32,
            embed_dim=2048,
            num_heads=2048//128,  # 16 heads
            drop_path_rate=0.1,
            mlp_ratio=4,
            block_chunks=8
        )
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Create Infinity model
    text_maxlen = 512
    print("[Creating Infinity model architecture]")

    # NOTE(review): torch.cuda.amp.autocast is deprecated in recent torch in
    # favor of torch.amp.autocast('cuda', ...) — confirm target torch version
    # before changing.
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
        infinity_model = Infinity(
            vae_local=vae,
            text_channels=text_channels,
            text_maxlen=text_maxlen,
            shared_aln=True,
            raw_scale_schedule=None,
            checkpointing='full-block',
            customized_flash_attn=False,
            fused_norm=True,
            pad_to_multiplier=128,
            use_flex_attn=False,
            add_lvl_embeding_only_first_block=1,
            use_bit_label=1,
            rope2d_each_sa_layer=1,
            rope2d_normalized_by_hw=2,
            pn=pn,
            apply_spatial_patchify=0,
            inference_mode=True,
            train_h_div_w_list=[1.0],
            **kwargs_model,
        ).to(device=device)

    print(f"[Infinity model size: {sum(p.numel() for p in infinity_model.parameters())/1e9:.2f}B parameters]")

    # Convert transformer blocks to bfloat16 (matches inference autocast dtype)
    for block in infinity_model.unregistered_blocks:
        block.bfloat16()

    infinity_model.eval()
    infinity_model.requires_grad_(False)

    # Load GGUF weights with GGUFParameters (quantized tensors kept quantized)
    print("[Loading Infinity weights from GGUF]")
    state_dict = load_gguf_state_dict_with_params(gguf_path, device=device)

    # Replace Linear layers with GGUFLinear layers for on-the-fly dequantization
    print("[Replacing Linear layers with GGUFLinear layers]")
    infinity_model = _replace_with_gguf_linear(infinity_model, torch.bfloat16, state_dict, prefix="")

    # Load weights directly into the model (not using load_state_dict)
    print("[Loading weights into model]")
    skipped_keys = []
    for key, tensor in state_dict.items():
        # Find the module and parameter name ("a.b.weight" -> "a.b", "weight")
        parts = key.rsplit('.', 1)
        if len(parts) != 2:
            continue

        module_name, param_name = parts

        # Navigate to the module; give up on this key if any attribute is missing
        module = infinity_model
        for attr in module_name.split('.'):
            if hasattr(module, attr):
                module = getattr(module, attr)
            else:
                module = None
                break

        # Set the parameter
        if module is not None and hasattr(module, param_name):
            existing_param = getattr(module, param_name)

            # Get the shape of the tensor to load (quantized tensors carry their
            # logical shape in quant_shape, not in .shape)
            tensor_shape = tensor.shape
            if hasattr(tensor, 'quant_shape'):
                tensor_shape = tensor.quant_shape

            # Check if shapes match; mismatches are skipped rather than fatal
            if existing_param.shape != tensor_shape:
                print(f"[WARNING] Shape mismatch for {key}: expected {existing_param.shape}, got {tensor_shape}. Skipping.")
                skipped_keys.append(key)
                continue

            # Set the parameter (wrap plain tensors; keep Parameters as-is)
            if isinstance(tensor, torch.nn.Parameter):
                setattr(module, param_name, tensor)
            else:
                setattr(module, param_name, torch.nn.Parameter(tensor, requires_grad=False))

    if skipped_keys:
        print(f"[INFO] Skipped {len(skipped_keys)} parameters due to shape mismatches")

    # RNG used by the sampler during autoregressive inference
    infinity_model.rng = torch.Generator(device=device)

    print("[Infinity model loaded successfully]")
    return infinity_model
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def load_vae(vae_path, vae_type=32, device='cuda'):
    """Construct the BSQ-VAE and load its checkpoint.

    Args:
        vae_path: path to the .pth checkpoint.
        vae_type: codebook dimensionality (d32 -> 32, codebook size 2**32).
        device: device to place the VAE on.

    Returns:
        The VAE in test mode on `device`.
    """
    print(f"[Loading VAE from {vae_path}]")

    codebook_dim = vae_type
    channel_mults = [1, 2, 4, 4, 4]  # same multipliers for encoder and decoder

    vae = vae_model(
        vae_path,
        "dynamic",              # schedule_mode: dynamic-resolution schedule
        codebook_dim,
        2 ** codebook_dim,      # codebook_size
        patch_size=16,
        encoder_ch_mult=channel_mults,
        decoder_ch_mult=channel_mults,
        test_mode=True,
    ).to(device)

    print("[VAE loaded successfully]")
    return vae
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def encode_prompt(text_tokenizer, text_encoder, prompt, device='cuda'):
    """Tokenize `prompt`, run the T5 encoder, and pack features for Infinity.

    Returns a tuple (kv_compact, lens, cu_seqlens_k, Ltext):
      kv_compact   — float32 concatenation of the unpadded per-caption features
      lens         — list of real (unpadded) token counts
      cu_seqlens_k — int32 cumulative sequence lengths with a leading 0
      Ltext        — longest real sequence length in the batch
    """
    print(f"Encoding prompt: {prompt}")

    batch = text_tokenizer(
        text=[prompt],
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    # The encoder may live on CPU; keep the token tensors alongside it.
    ids_on_enc = batch.input_ids.to(text_encoder.device)
    mask = batch.attention_mask.to(text_encoder.device)

    with torch.no_grad():
        features = text_encoder(
            input_ids=ids_on_enc,
            attention_mask=mask,
        )['last_hidden_state'].float()

    # Move results onto the generation device for the Infinity transformer.
    features = features.to(device)
    mask = mask.to(device)

    lens: List[int] = mask.sum(dim=-1).tolist()
    cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
    Ltext = max(lens)

    # Strip padding: keep only the first len_i feature rows of each caption.
    kv_compact = torch.cat(
        [feat[:n] for n, feat in zip(lens, features.unbind(0))], dim=0
    )
    # Force float32 to avoid dtype mismatches downstream.
    kv_compact = kv_compact.to(torch.float32)

    return (kv_compact, lens, cu_seqlens_k, Ltext)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def generate_image(infinity_model, vae, text_tokenizer, text_encoder, prompt,
                   cfg_scale=3.0, tau=0.5, seed=None, scale_schedule=None,
                   vae_type=32, device='cuda'):
    """
    Generate an image using the Infinity model.

    Args:
        infinity_model: Loaded Infinity transformer.
        vae: Loaded VAE used to decode tokens to pixels.
        text_tokenizer: T5 tokenizer for the prompt.
        text_encoder: T5 encoder for the prompt.
        prompt: Text prompt.
        cfg_scale: Classifier-free guidance scale, applied at every scale.
        tau: Sampling temperature, applied at every scale.
        seed: Optional seed forwarded to the sampler (determinism is set up
            earlier in main()).
        scale_schedule: List of (t, h, w) scales. Required — the ``None``
            default is only a placeholder; see ``dynamic_resolution_h_w``.
        vae_type: VAE variant id (32 for the d32 checkpoint shipped here).
        device: Device used when encoding the prompt.

    Returns:
        The generated image tensor (consumed directly by ``cv2.imwrite`` in
        main(), so presumably uint8 HWC in BGR order — the web UI reverses
        the channel axis before handing it to PIL).

    Raises:
        ValueError: If ``scale_schedule`` is not provided.
    """
    # The original code crashed with `len(None)` when the default was used;
    # fail fast with a clear message instead.
    if scale_schedule is None:
        raise ValueError("scale_schedule must be provided (see dynamic_resolution_h_w)")

    print("[Starting image generation]")
    start_time = time.time()

    # Note: Deterministic mode is set early in main() if seed is provided
    if seed is not None:
        print(f"Using seed: {seed}")

    # Encode prompt
    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, device=device)

    # Per-scale cfg and tau lists (constant across the schedule)
    cfg_list = [cfg_scale] * len(scale_schedule)
    tau_list = [tau] * len(scale_schedule)

    print(f"CFG scale: {cfg_scale}, Tau: {tau}")
    print(f"Scale schedule: {scale_schedule}")

    # Generate under bfloat16 autocast; inference only.
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
        with torch.no_grad():
            gen_start = time.time()

            _, _, img_list = infinity_model.autoregressive_infer_cfg(
                vae=vae,
                scale_schedule=scale_schedule,
                label_B_or_BLT=text_cond_tuple,
                g_seed=seed,
                B=1,
                negative_label_B_or_BLT=None,
                force_gt_Bhw=None,
                cfg_sc=cfg_scale,
                cfg_list=cfg_list,
                tau_list=tau_list,
                top_k=900,
                top_p=0.97,
                returns_vemb=1,
                ratio_Bl1=None,
                gumbel=0,
                norm_cfg=False,
                cfg_exp_k=0.0,
                cfg_insertion_layer=[0],  # Must be a list
                vae_type=vae_type,
                softmax_merge_topk=-1,
                ret_img=True,
                trunk_scale=1000,
                gt_leak=0,
                gt_ls_Bl=None,
                inference_mode=True,
                sampling_per_bits=1,
            )

            gen_time = time.time() - gen_start

    # Single-image batch (B=1): take the only entry.
    img = img_list[0]

    total_time = time.time() - start_time
    print(f"[Generation complete! Total time: {total_time:.2f}s, Inference time: {gen_time:.2f}s]")

    return img
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def main():
    """CLI entry point: parse arguments, set up determinism, load the VAE /
    T5 / Infinity-2B GGUF models, generate one image, and save it to disk."""
    parser = argparse.ArgumentParser(description='Generate images with Infinity-2B GGUF')
    parser.add_argument('--prompt', type=str,
                        default='an astronaut riding a horse on the moon',
                        help='Text prompt for image generation')
    parser.add_argument('--infinity-gguf', type=str,
                        default='infinity_2b_reg_Q8_0.gguf',
                        help='Path to Infinity-2B GGUF file')
    parser.add_argument('--t5-gguf', type=str,
                        default='flan-t5-xl-encoder-Q8_0.gguf',
                        help='Path to T5 encoder GGUF file')
    parser.add_argument('--vae-path', type=str,
                        default='Infinity/infinity_vae_d32_reg.pth',
                        help='Path to VAE checkpoint')
    parser.add_argument('--output', type=str,
                        default='output.png',
                        help='Output image path')
    parser.add_argument('--cfg-scale', type=float, default=3.0,
                        help='Classifier-free guidance scale')
    parser.add_argument('--tau', type=float, default=0.5,
                        help='Temperature for self-attention')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--pn', type=str, default='1M',
                        choices=['0.06M', '0.25M', '1M'],
                        help='Resolution preset')
    parser.add_argument('--aspect-ratio', type=float, default=1.0,
                        help='Aspect ratio (height/width)')

    args = parser.parse_args()

    # Set deterministic mode early (before model loading) if seed is provided
    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        # Enable deterministic mode for cuDNN
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Try to enable full deterministic mode (warn_only avoids hard errors
        # for ops without deterministic implementations)
        try:
            torch.use_deterministic_algorithms(True, warn_only=True)
        except Exception as e:
            print(f"Warning: Could not enable full deterministic mode: {e}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Set CUDA seed after device is determined
    if args.seed is not None and device == 'cuda':
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

        # Control SDPA backend for determinism.
        # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in newer
        # torch releases; the try/except keeps a failure non-fatal.
        try:
            torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False)
            print(f"Deterministic mode enabled (seed={args.seed})")
        except Exception as e:
            print(f"Warning: Could not set SDPA backend: {e}")

    if device == 'cpu':
        print("WARNING: No GPU detected! This will be extremely slow.")

    # Determine scale schedule based on aspect ratio: snap the requested
    # ratio to the nearest supported template, then force t=1 on each scale.
    h_div_w_template = h_div_w_templates[
        np.argmin(np.abs(h_div_w_templates - args.aspect_ratio))
    ]
    scale_schedule = dynamic_resolution_h_w[h_div_w_template][args.pn]['scales']
    scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]

    print("\n" + "="*70)
    print("Infinity-2B GGUF Image Generation")
    print("="*70)

    # Load models
    print("\n[1/4] Loading VAE...")
    vae = load_vae(args.vae_path, vae_type=32, device=device)

    print("\n[2/4] Loading T5 Tokenizer...")
    text_tokenizer = load_t5_tokenizer_from_gguf(args.t5_gguf)

    print("\n[3/4] Loading T5 Encoder from GGUF (on CPU)...")
    text_encoder = load_t5_encoder_from_gguf(args.t5_gguf, device='cpu')

    print("\n[4/4] Loading Infinity-2B from GGUF...")
    infinity_model = load_infinity_from_gguf(
        args.infinity_gguf,
        vae=vae,
        device=device,
        model_type='infinity_2b',
        text_channels=2048,  # Model projects T5's 4096 internally
        pn=args.pn
    )

    print("\n" + "="*70)
    print("All models loaded successfully!")
    print("="*70)

    # Generate image
    print(f"\nGenerating image with prompt: '{args.prompt}'")
    generated_image = generate_image(
        infinity_model,
        vae,
        text_tokenizer,
        text_encoder,
        args.prompt,
        cfg_scale=args.cfg_scale,
        tau=args.tau,
        seed=args.seed,
        scale_schedule=scale_schedule,
        vae_type=32,
        device=device
    )

    # Save image. cv2.imwrite consumes BGR; the tensor is written as-is
    # (assumed already BGR — the web UI reverses channels for PIL; confirm).
    print(f"\nSaving image to {args.output}...")
    image_np = generated_image.cpu().numpy()
    cv2.imwrite(args.output, image_np)

    print(f"\n{'='*70}")
    print(f"✓ Image saved successfully to: {args.output}")
    print(f"{'='*70}\n")
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
# Script entry point: run the CLI generator.
if __name__ == '__main__':
    main()
|
gradio_webui.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Gradio Web UI for Infinity-2B GGUF Image Generation
|
| 4 |
+
Provides an easy-to-use interface for generating images with the quantized model
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 10 |
+
|
| 11 |
+
# Add Infinity to Python path
|
| 12 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 13 |
+
INFINITY_PATH = os.path.join(SCRIPT_DIR, 'Infinity')
|
| 14 |
+
if os.path.exists(INFINITY_PATH):
|
| 15 |
+
sys.path.insert(0, INFINITY_PATH)
|
| 16 |
+
|
| 17 |
+
import time
|
| 18 |
+
import argparse
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
import gradio as gr
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from datetime import datetime
|
| 24 |
+
|
| 25 |
+
# Import the generation functions from our existing script
|
| 26 |
+
from generate_image_2b_q8_gguf import (
|
| 27 |
+
load_t5_tokenizer_from_gguf,
|
| 28 |
+
load_t5_encoder_from_gguf,
|
| 29 |
+
load_infinity_from_gguf,
|
| 30 |
+
load_vae,
|
| 31 |
+
generate_image
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Global model storage
|
| 38 |
+
class ModelCache:
    """Process-wide holder for the loaded model components.

    A single shared instance (``model_cache``) lets the Gradio callbacks
    load the models once and reuse them across generations.
    """

    def __init__(self):
        # Nothing is loaded until load_models() populates these slots.
        self.vae = None
        self.text_tokenizer = None
        self.text_encoder = None
        self.infinity_model = None
        self.loaded = False
        # Prefer the GPU whenever torch can see one.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'


model_cache = ModelCache()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_models(infinity_gguf_path, t5_gguf_path, vae_path, pn='1M', progress=gr.Progress()):
    """
    Load every model component into the shared cache, reporting progress.

    Args:
        infinity_gguf_path: Path to the Infinity-2B GGUF checkpoint.
        t5_gguf_path: Path to the T5 encoder GGUF checkpoint.
        vae_path: Path to the VAE .pth checkpoint.
        pn: Resolution preset used when instantiating Infinity.
        progress: Gradio progress reporter.

    Returns:
        Human-readable status string for the UI.
    """
    global model_cache

    # Idempotent: a second click on "Load Models" is a no-op.
    if model_cache.loaded:
        return "✓ Models already loaded!"

    cache = model_cache

    progress(0, desc="Loading VAE...")
    cache.vae = load_vae(vae_path, vae_type=32, device=cache.device)

    progress(0.25, desc="Loading T5 Tokenizer...")
    cache.text_tokenizer = load_t5_tokenizer_from_gguf(t5_gguf_path)

    progress(0.5, desc="Loading T5 Encoder (on CPU)...")
    cache.text_encoder = load_t5_encoder_from_gguf(t5_gguf_path, device='cpu')

    progress(0.75, desc="Loading Infinity-2B from GGUF...")
    cache.infinity_model = load_infinity_from_gguf(
        infinity_gguf_path,
        vae=cache.vae,
        device=cache.device,
        model_type='infinity_2b',
        text_channels=2048,
        pn=pn
    )

    cache.loaded = True
    progress(1.0, desc="Complete!")

    return "✓ All models loaded successfully!"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def generate_image_gradio(
    prompt,
    cfg_scale,
    tau,
    seed,
    aspect_ratio,
    pn,
    use_random_seed,
    progress=gr.Progress()
):
    """
    Gradio callback: generate one image and return it with an info panel.

    Args:
        prompt: Text prompt.
        cfg_scale: Classifier-free guidance scale.
        tau: Sampling temperature.
        seed: Seed from the gr.Number widget.
        aspect_ratio: Height/width ratio to snap to a supported template.
        pn: Resolution preset ('0.06M' / '0.25M' / '1M').
        use_random_seed: If True, ignore `seed` and draw a fresh one.
        progress: Gradio progress reporter.

    Returns:
        Tuple of (PIL.Image or None, markdown status/error string).
    """
    global model_cache

    if not model_cache.loaded:
        return None, "❌ Please load models first!"

    try:
        # Use random seed if requested
        if use_random_seed:
            seed = np.random.randint(0, 2**31 - 1)

        # Set seed for reproducibility
        if seed is not None:
            seed = int(seed)  # gr.Number/np may hand back a non-int; torch/np need int
            torch.manual_seed(seed)
            np.random.seed(seed)
            if model_cache.device == 'cuda':
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

        # Determine scale schedule based on aspect ratio
        h_div_w_template = h_div_w_templates[
            np.argmin(np.abs(h_div_w_templates - aspect_ratio))
        ]
        scale_schedule = dynamic_resolution_h_w[h_div_w_template][pn]['scales']
        scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]

        progress(0.1, desc="Encoding prompt...")
        start_time = time.time()

        progress(0.3, desc="Generating image (this may take a while)...")

        # Generate image
        img_np = generate_image(
            model_cache.infinity_model,
            model_cache.vae,
            model_cache.text_tokenizer,
            model_cache.text_encoder,
            prompt,
            cfg_scale=cfg_scale,
            tau=tau,
            seed=seed,
            scale_schedule=scale_schedule,
            vae_type=32,
            device=model_cache.device
        )

        progress(0.9, desc="Converting to PIL Image...")

        # Convert to PIL Image (RGB)
        img_np = img_np.cpu().numpy()
        # OpenCV-style output is BGR; reverse the channel axis to RGB.
        # The reversed slice is a negative-stride view, which
        # Image.fromarray rejects — materialize a contiguous copy first.
        img_rgb = np.ascontiguousarray(img_np[:, :, ::-1])
        pil_image = Image.fromarray(img_rgb.astype(np.uint8))

        elapsed_time = time.time() - start_time

        # Get resolution
        h, w = img_np.shape[:2]

        info = f"""✓ Generation complete!

**Time**: {elapsed_time:.2f}s
**Resolution**: {w}x{h}
**Seed**: {seed}
**CFG Scale**: {cfg_scale}
**Tau**: {tau}
**Aspect Ratio**: {aspect_ratio:.2f}
**PN**: {pn}"""

        progress(1.0, desc="Done!")

        return pil_image, info

    except Exception as e:
        import traceback
        error_msg = f"❌ Error during generation:\n{str(e)}\n\n{traceback.format_exc()}"
        return None, error_msg
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def create_ui():
    """
    Create the Gradio UI.

    Returns:
        gr.Blocks app: model-path banner + load button on top, generation
        settings on the left, output image and run info on the right.
    """
    # Create Blocks without theme for compatibility with older Gradio versions
    with gr.Blocks(title="Infinity-2B GGUF Generator") as demo:
        gr.Markdown("# 🎨 Infinity-2B GGUF Image Generator")

        # Model paths banner at the top
        with gr.Row():
            infinity_gguf = gr.Textbox(
                label="Infinity-2B GGUF",
                value="infinity_2b_reg_Q8_0.gguf",
                scale=2
            )

            t5_gguf = gr.Textbox(
                label="T5 GGUF",
                value="flan-t5-xl-encoder-Q8_0.gguf",
                scale=2
            )

            vae_path = gr.Textbox(
                label="VAE Checkpoint",
                value="Infinity/infinity_vae_d32_reg.pth",
                scale=2
            )

            # Preset used at load time (baked into the Infinity instance)
            pn_load = gr.Dropdown(
                label="Resolution Preset",
                choices=['0.06M', '0.25M', '1M'],
                value='1M',
                scale=1
            )

            load_btn = gr.Button("🚀 Load Models", variant="primary", scale=1)

        load_status = gr.Textbox(label="Status", interactive=False, show_label=False)

        # Main content area
        with gr.Row():
            # Left column: Generation settings
            with gr.Column(scale=1):
                gr.Markdown("### Generation Settings")

                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    value="an astronaut riding a horse on the moon",
                    lines=3
                )

                with gr.Row():
                    cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=10.0,
                        value=3.0,
                        step=0.5,
                        label="CFG Scale",
                        info="Higher = stronger prompt adherence"
                    )

                    tau = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.5,
                        step=0.05,
                        label="Tau (Temperature)",
                        info="Lower = more deterministic"
                    )

                with gr.Row():
                    aspect_ratio = gr.Slider(
                        minimum=0.5,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Aspect Ratio (H/W)",
                        info="1.0 = square, >1.0 = portrait, <1.0 = landscape"
                    )

                    # Per-generation preset passed to generate_image_gradio
                    pn = gr.Dropdown(
                        label="Resolution Preset",
                        choices=['0.06M', '0.25M', '1M'],
                        value='1M',
                        info="Higher = better quality but slower"
                    )

                with gr.Row():
                    seed = gr.Number(
                        label="Seed",
                        value=42,
                        precision=0,
                        info="For reproducible results"
                    )

                    use_random_seed = gr.Checkbox(
                        label="Random Seed",
                        value=False,
                        info="Generate random seed each time"
                    )

                generate_btn = gr.Button("✨ Generate Image", variant="primary", size="lg")

            # Right column: Output
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Image",
                    type="pil",
                    height=600
                )
                output_info = gr.Markdown("Generate an image to see details here.")

        # Wire up events
        load_btn.click(
            fn=load_models,
            inputs=[infinity_gguf, t5_gguf, vae_path, pn_load],
            outputs=[load_status]
        )

        generate_btn.click(
            fn=generate_image_gradio,
            inputs=[prompt, cfg_scale, tau, seed, aspect_ratio, pn, use_random_seed],
            outputs=[output_image, output_info]
        )

    return demo
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def main():
    """CLI entry point: parse server options, optionally pre-load the models,
    build the Gradio UI, and launch the web server."""
    parser = argparse.ArgumentParser(description='Infinity-2B GGUF Gradio Web UI')
    parser.add_argument('--share', action='store_true', help='Create a public share link')
    parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Server name')
    parser.add_argument('--server-port', type=int, default=7860, help='Server port')
    parser.add_argument('--autoload', action='store_true', help='Auto-load models on startup')
    parser.add_argument('--infinity-gguf', type=str, default='infinity_2b_reg_Q8_0.gguf')
    parser.add_argument('--t5-gguf', type=str, default='flan-t5-xl-encoder-Q8_0.gguf')
    parser.add_argument('--vae-path', type=str, default='Infinity/infinity_vae_d32_reg.pth')
    args = parser.parse_args()

    # Optionally populate the shared model cache before the UI comes up.
    if args.autoload:
        print("Auto-loading models...")
        load_models(args.infinity_gguf, args.t5_gguf, args.vae_path)

    # Build the UI, then announce where it will be served.
    demo = create_ui()

    banner = "=" * 70
    print("\n" + banner)
    print("Starting Infinity-2B GGUF Web UI")
    print(banner)
    print(f"Server: http://{args.server_name}:{args.server_port}")
    if args.share:
        print("Creating public share link...")
    print(banner + "\n")

    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share,
        inbrowser=True
    )
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# Script entry point: start the web UI.
if __name__ == '__main__':
    main()
|
infinity_2b_reg_Q8_0.gguf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:747220c5030d342f0195f34eb5c21ebb75b2bb855df96a848544be29f00326bc
|
| 3 |
+
size 2374494496
|
infinity_gguf_utils.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GGUF utilities for Infinity model inference
|
| 4 |
+
Includes GGUFParameter, dequantization functions, and GGUFLinear layer
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
# Monkey patch for NumPy 2.0 compatibility (must be done before importing gguf)
# NumPy 2.0 removed ndarray.newbyteorder(); the gguf reader still calls it,
# so reconstruct it from the dtype-level byte-order swap that NumPy kept.
# NOTE(review): setting an attribute on the C-defined ndarray type may raise
# TypeError on some NumPy builds — confirm on a NumPy 2.x environment.
if not hasattr(np.ndarray, 'newbyteorder'):
    def newbyteorder(self, new_order):
        # Reinterpret the same buffer with a byte-swapped dtype (no copy).
        return self.view(self.dtype.newbyteorder(new_order))
    np.ndarray.newbyteorder = newbyteorder
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import gguf
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Dequantization constants (llama.cpp k-quant layout)
QK_K = 256  # elements per k-quant super-block
K_SCALE_SIZE = 12  # bytes of packed 6-bit scales/mins per Q4_K/Q5_K super-block
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def to_uint32(x):
    """Reassemble rows of four little-endian bytes into 32-bit integers.

    Expects a 2-D tensor whose rows are 4 raw bytes each; returns a column
    tensor (int32 storage holding the unsigned little-endian value).
    """
    b = x.view(torch.uint8).to(torch.int32)
    value = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16) | (b[:, 3] << 24)
    return value.unsqueeze(1)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def split_block_dims(blocks, *args):
    """Split `blocks` along dim 1 into chunks of the given widths.

    The trailing chunk implicitly receives whatever remains after the
    explicitly sized leading chunks.
    """
    total = blocks.shape[1]
    widths = [*args, total - sum(args)]
    return torch.split(blocks, widths, dim=1)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
    """Dequantize Q8_0 blocks: one fp16 scale followed by 32 int8 quants.

    Each row of `blocks` holds the raw bytes of one block; the result is
    scale * quant for every element, broadcast across the row.
    """
    scale_bytes, quants = torch.split(blocks, [2, blocks.shape[1] - 2], dim=1)
    scale = scale_bytes.view(torch.float16).to(dtype)
    return scale * quants.view(torch.int8)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q6_K blocks.

    Per 256-element super-block: 128 bytes of low 4-bit quants (ql),
    64 bytes of high 2-bit quants (qh), 16 int8 sub-block scales, and
    one fp16 super-block scale d (llama.cpp k-quant layout).
    """
    n_blocks = blocks.shape[0]
    ql, qh, scales, d = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    # Effective scale per sub-block of 16 values: d * scales[i].
    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    # Unpack two 4-bit values per ql byte and four 2-bit values per qh byte.
    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    # Recombine into signed 6-bit values centred on zero ([-32, 31]).
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_scale_min(scales):
    """Unpack the 12-byte Q4_K/Q5_K scale block into 8 scales and 8 mins.

    The first four scales/mins live in the low 6 bits of the first 8 bytes;
    the last four are assembled from the final 4 bytes plus the high bits of
    the first 8 (llama.cpp k-quant packing).

    Args:
        scales: uint8 tensor of shape (n_blocks, K_SCALE_SIZE).

    Returns:
        Tuple (sc, mn) of uint8 tensors, each of shape (n_blocks, 8).
    """
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    # `mn` rather than `min` so the builtin is not shadowed.
    mn = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), mn.reshape((n_blocks, 8)))
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q5_K blocks.

    Per 256-element super-block: fp16 scale d, fp16 min dmin, 12 bytes of
    packed 6-bit sub-block scales/mins, 32 bytes of high bits (qh) and
    128 bytes of low 4-bit quants (qs) — llama.cpp k-quant layout.
    """
    n_blocks = blocks.shape[0]
    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    # Effective per-sub-block scale and offset.
    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    # Two 4-bit low values per qs byte; one high bit per value from qh.
    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    # 5-bit quant = high bit (<<4) on top of the 4 low bits.
    q = ql | (qh << 4)

    return (d * q - dm).reshape((n_blocks, QK_K))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q4_K blocks.

    Per 256-element super-block: fp16 scale d, fp16 min dmin, 12 bytes of
    packed 6-bit sub-block scales/mins, then 128 bytes of 4-bit quants
    (two values per byte) — llama.cpp k-quant layout.
    """
    n_blocks = blocks.shape[0]
    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    # Effective per-sub-block scale and offset.
    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    # Unpack two 4-bit values per byte (low nibble first).
    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
    """Expand raw bfloat16 bytes to float32.

    A bfloat16 value is exactly the top 16 bits of the matching float32,
    so widening to int32 and shifting left by 16 reconstructs the full
    float32 bit pattern. The extra parameters exist only for signature
    parity with the quantized paths and are ignored.
    """
    widened = blocks.view(torch.int16).to(torch.int32)
    return (widened << 16).view(torch.float32)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# Mapping of quantization types to dequantization functions.
# GGML_QUANT_SIZES maps quant type -> (block_size in elements, type_size in bytes).
GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
DEQUANTIZE_FUNCTIONS = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _quant_shape_from_byte_shape(shape, type_size, block_size):
|
| 137 |
+
"""Calculate dequantized shape from quantized byte shape"""
|
| 138 |
+
return (*shape[:-1], shape[-1] // type_size * block_size)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def dequantize_gguf_tensor(tensor, target_dtype=None):
    """
    Dequantize a GGUF tensor to regular torch tensor

    Args:
        tensor: GGUFParameter or regular tensor
        target_dtype: Target dtype for output (default: float32)

    Returns:
        Regular torch tensor
    """
    # If not quantized (no GGUF metadata attached), just return the tensor
    if not hasattr(tensor, "quant_type"):
        return tensor.to(target_dtype) if target_dtype else tensor

    quant_type = tensor.quant_type

    # If F32 or F16, the payload is already float data; just convert normally
    if quant_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
        return tensor.to(target_dtype) if target_dtype else tensor

    # Get dequantization function
    if quant_type not in DEQUANTIZE_FUNCTIONS:
        raise ValueError(f"Unsupported quantization type: {quant_type}")

    dequant_fn = DEQUANTIZE_FUNCTIONS[quant_type]
    block_size, type_size = GGML_QUANT_SIZES[quant_type]

    # Prepare tensor for dequantization: reinterpret the storage as raw
    # bytes and recover the logical (dequantized) shape from the byte shape.
    tensor_bytes = tensor.view(torch.uint8)
    shape = _quant_shape_from_byte_shape(tensor_bytes.shape, type_size, block_size)

    n_blocks = tensor_bytes.numel() // type_size
    blocks = tensor_bytes.reshape((n_blocks, type_size))

    # Dequantize block-by-block, then restore the logical shape
    dtype = target_dtype if target_dtype else torch.float32
    dequant = dequant_fn(blocks, block_size, type_size, dtype=dtype)
    dequant = dequant.reshape(shape)

    return dequant
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class GGUFParameter(torch.nn.Parameter):
    """
    Custom Parameter class for GGUF quantized tensors.
    Stores quantization metadata alongside the data.

    The underlying storage holds the raw quantized bytes; ``quant_type``
    records the GGML quantization format and ``quant_shape`` the logical
    (dequantized) shape derived from the byte layout.
    """
    def __new__(cls, data, requires_grad=False, quant_type=None):
        # NOTE(review): quant_type=None would raise KeyError on the
        # GGML_QUANT_SIZES lookup below — callers must always pass a real
        # quantization type; confirm this is intentional.
        data = data if data is not None else torch.empty(0)
        # Store byte shape before creating parameter
        byte_shape = data.shape
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.quant_type = quant_type
        block_size, type_size = GGML_QUANT_SIZES[quant_type]
        self.quant_shape = _quant_shape_from_byte_shape(byte_shape, type_size, block_size)
        return self

    @property
    def shape(self):
        """Return the dequantized shape instead of byte shape"""
        # quant_shape is set in __new__; the hasattr guard protects against
        # tensor-subclass copies created without going through __new__.
        if hasattr(self, 'quant_shape'):
            return self.quant_shape
        # Fallback: get shape from parent class without causing recursion
        # (self.shape would re-enter this property).
        return object.__getattribute__(self, 'data').shape if hasattr(self, 'data') else torch.Size()
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix=""):
    """
    Replace nn.Linear layers with GGUF Linear layers for on-the-fly dequantization.
    Based on ComfyUI-WanVideoWrapper implementation.

    Args:
        model: Module whose children are rewritten in place.
        compute_dtype: Fallback dtype for dequantized weights when the input
            dtype is not a floating-point type.
        state_dict: Parameter-name mapping; GGUFParameter entries mark which
            Linear layers need conversion.
        prefix: Dotted parameter-name prefix of ``model`` within ``state_dict``.

    Returns:
        The mutated model (leaf modules return None, as before).
    """
    def _should_convert_to_gguf(state_dict, prefix):
        # A layer is convertible when its weight entry is a quantized GGUFParameter.
        weight_key = prefix + "weight"
        return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter)

    def _make_wrapped_forward(module_type):
        # Factory so each replaced module captures ITS OWN original class.
        # BUGFIX: a closure defined directly in the loop below late-binds
        # `module_type` to the LAST iteration's value, so every custom
        # Linear subclass would be treated as whatever class happened to be
        # converted last (breaking e.g. the SharedAdaLin reshape).
        def wrapped_forward(self, *args, **kwargs):
            input_tensor = args[0] if args else None
            if input_tensor is not None and hasattr(input_tensor, 'dtype'):
                target_dtype = input_tensor.dtype if input_tensor.dtype in [torch.float16, torch.bfloat16, torch.float32] else compute_dtype
            else:
                target_dtype = compute_dtype

            # Dequantize weights (and bias, when the bias is quantized too)
            dequant_weight = dequantize_gguf_tensor(self.weight, target_dtype=target_dtype)
            dequant_bias = None
            if self.bias is not None:
                if isinstance(self.bias, GGUFParameter):
                    dequant_bias = dequantize_gguf_tensor(self.bias, target_dtype=target_dtype)
                else:
                    dequant_bias = self.bias

            # Perform linear operation
            linear_output = torch.nn.functional.linear(input_tensor, dequant_weight, dequant_bias)

            # SharedAdaLin splits its output into 6 AdaLN chunks.
            if module_type.__name__ == 'SharedAdaLin':
                C = dequant_weight.shape[0] // 6
                return linear_output.reshape(-1, 1, 6, C)

            return linear_output
        return wrapped_forward

    has_children = list(model.children())
    if not has_children:
        return

    try:
        from accelerate import init_empty_weights
        use_accelerate = True
    except ImportError:
        use_accelerate = False

    for name, module in model.named_children():
        module_prefix = prefix + name + "."
        _replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix)

        if (
            isinstance(module, nn.Linear)
            and not isinstance(module, GGUFLinear)
            and _should_convert_to_gguf(state_dict, module_prefix)
        ):
            # Get correct dimensions from the GGUF parameter shape
            # (the raw byte shape would give wrong in/out features).
            weight_param = state_dict[module_prefix + "weight"]
            if hasattr(weight_param, 'quant_shape'):
                out_features, in_features = weight_param.quant_shape
            else:
                out_features, in_features = weight_param.shape

            # Custom Linear subclasses (e.g. SharedAdaLin) override forward
            # and need a wrapper that reproduces the custom behavior on top
            # of on-the-fly dequantization.
            module_type = type(module)
            has_custom_forward = (
                module_type != nn.Linear and
                hasattr(module_type, 'forward') and
                module_type.forward is not nn.Linear.forward
            )

            if has_custom_forward:
                from types import MethodType

                new_module = GGUFLinear(
                    in_features,
                    out_features,
                    module.bias is not None,
                    compute_dtype=compute_dtype,
                )
                new_module.forward = MethodType(_make_wrapped_forward(module_type), new_module)
            else:
                # Standard GGUFLinear replacement; build on the meta device
                # when accelerate is available to avoid allocating real weights.
                if use_accelerate:
                    with init_empty_weights():
                        new_module = GGUFLinear(
                            in_features,
                            out_features,
                            module.bias is not None,
                            compute_dtype=compute_dtype,
                        )
                else:
                    new_module = GGUFLinear(
                        in_features,
                        out_features,
                        module.bias is not None,
                        compute_dtype=compute_dtype,
                    )

            model._modules[name] = new_module
            model._modules[name].source_cls = type(module)
            model._modules[name].requires_grad_(False)

    return model
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class GGUFLinear(nn.Linear):
    """
    Custom Linear layer that dequantizes GGUF weights on-the-fly.
    Compatible with Infinity model architecture.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        compute_dtype=None,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        # Fallback dtype used when the input dtype is not floating-point.
        self.compute_dtype = compute_dtype if compute_dtype else torch.float32

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with on-the-fly dequantization.

        Args:
            input: Input tensor

        Returns:
            Output tensor after linear transformation
        """
        # Match the input's float dtype where possible, otherwise fall back
        # to the configured compute dtype.
        float_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        target_dtype = input.dtype if input.dtype in float_dtypes else self.compute_dtype

        # Dequantize weight (and bias, if present) to the target dtype.
        # GGUF linear weights are assumed stored as (out, in), so no
        # transpose is needed for F.linear.
        weight = dequantize_gguf_tensor(self.weight, target_dtype=target_dtype)
        if self.bias is not None:
            bias = dequantize_gguf_tensor(self.bias, target_dtype=target_dtype)
        else:
            bias = None

        return torch.nn.functional.linear(input, weight, bias)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def load_gguf_state_dict_with_params(gguf_path, device='cuda'):
    """
    Load a GGUF file and return a state dict with GGUFParameters for
    quantized tensors. For use with _replace_with_gguf_linear.
    """
    from gguf import GGUFReader

    unquantized_types = {
        gguf.GGMLQuantizationType.F32,
        gguf.GGMLQuantizationType.F16,
    }

    state_dict = {}
    for tensor in GGUFReader(gguf_path).tensors:
        torch_tensor = torch.from_numpy(np.array(tensor.data)).to(device)

        if tensor.tensor_type not in unquantized_types:
            # Quantized payload: keep the raw bytes wrapped in a
            # GGUFParameter so dequantization can happen on the fly.
            state_dict[tensor.name] = GGUFParameter(torch_tensor, quant_type=tensor.tensor_type)
        else:
            # F32/F16: restore the logical torch shape (GGUF stores the
            # dimensions in reversed order) and store as a plain Parameter.
            logical_shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
            torch_tensor = torch_tensor.view(*logical_shape)
            if tensor.tensor_type == gguf.GGMLQuantizationType.F32:
                state_dict[tensor.name] = nn.Parameter(torch_tensor.float())
            else:
                state_dict[tensor.name] = nn.Parameter(torch_tensor.half())

    return state_dict
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def load_gguf_state_dict(gguf_path):
    """
    Load GGUF file and create state dict with GGUFParameters.

    Args:
        gguf_path: Path to GGUF file

    Returns:
        state_dict: Dictionary mapping tensor names to GGUFParameters
            (quantized tensors) or regular CPU tensors (F32/F16).
    """
    from gguf import GGUFReader

    reader = GGUFReader(gguf_path)
    state_dict = {}

    for tensor in reader.tensors:
        # Check if tensor is quantized (anything other than plain F32/F16)
        is_quantized = tensor.tensor_type not in {
            gguf.GGMLQuantizationType.F32,
            gguf.GGMLQuantizationType.F16
        }

        torch_tensor = torch.from_numpy(np.array(tensor.data)).to('cpu')

        if is_quantized:
            # Quantized: wrap the raw bytes in a GGUFParameter for
            # on-the-fly dequantization later.
            state_dict[tensor.name] = GGUFParameter(torch_tensor, quant_type=tensor.tensor_type)
        else:
            # BUGFIX: restore the logical torch shape for F32/F16 tensors.
            # GGUF stores dimensions in reversed order; the previous code
            # kept the on-disk layout, inconsistent with
            # load_gguf_state_dict_with_params and unusable with
            # Module.load_state_dict.
            logical_shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
            state_dict[tensor.name] = torch_tensor.view(*logical_shape)

    return state_dict
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def replace_linear_with_gguf(model, state_dict, compute_dtype=torch.float32, prefix=""):
    """
    Recursively replace nn.Linear layers with GGUFLinear layers
    where the corresponding weight in state_dict is a GGUFParameter.

    Args:
        model: PyTorch model
        state_dict: State dict with GGUFParameters
        compute_dtype: Dtype to use for computation
        prefix: Dotted name prefix of ``model`` within the full state dict.
            Used internally by the recursion; leave empty at the top level.

    Returns:
        Modified model with GGUFLinear layers
    """
    for name, module in model.named_children():
        full_name = prefix + name
        # BUGFIX: thread the full dotted path through the recursion. The
        # previous code looked up only the leaf name (via get_module_prefix),
        # so nested Linear layers never matched their state-dict keys and
        # were silently left unconverted.
        replace_linear_with_gguf(module, state_dict, compute_dtype, full_name + ".")

        # Check if this is a Linear layer with quantized weights
        # (skip layers already converted, matching _replace_with_gguf_linear).
        if isinstance(module, nn.Linear) and not isinstance(module, GGUFLinear):
            weight_key = f"{full_name}.weight"

            if weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter):
                # Replace with GGUFLinear of identical dimensions/bias.
                gguf_linear = GGUFLinear(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                    compute_dtype=compute_dtype
                )
                setattr(model, name, gguf_linear)

    return model
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def get_module_prefix(model, module_name):
    """Helper to get the full state-dict prefix for a module.

    NOTE: simplified — returns the bare child name as-is. For nested model
    structures a dotted path would be required; adjust to your model layout.
    """
    return module_name
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
if __name__ == "__main__":
    # Smoke check: module imported cleanly and the dequantization dispatch
    # table is populated.
    print("GGUF utilities loaded successfully!")
    print(f"Supported quantization types: {list(DEQUANTIZE_FUNCTIONS.keys())}")
|