# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange, repeat
from inference.common import Modality, VarlenHandler, is_hopper_arch
from inference.infra.parallelism import ulysses_scheduler
from magi_compiler import magi_compile
from magi_compiler.api import magi_register_custom_op
from magi_compiler.config import CompileConfig
from torch import Tensor
from torch.nn import Parameter
@dataclass
class FFAHandler:
    """Bundle of arguments describing a range-based attention call.

    Within this file only ``q_ranges`` and ``k_ranges`` are read (by
    ``Attention.forward`` when local attention is enabled); the remaining
    fields are carried along for callers outside this file.
    """

    # Query ranges; forwarded to the flex flash-attention op.
    q_ranges: Tensor
    # Key ranges, paired one-to-one with ``q_ranges``.
    k_ranges: Tensor
    # NOTE(review): presumably the longest query / key range length -- confirm with the producer.
    max_seqlen_q: int
    max_seqlen_k: int
    # Not read in this file.
    attn_type_map: Tensor
    # Not read in this file.
    softmax_scale: float
# Activation functions selectable for the MLP block.
class MLPActivationType(Enum):
    """Names of the activation functions that ``create_activation_func`` can build."""

    # Gated activation: consumes interleaved (gate, linear) channel pairs.
    SWIGLU7 = "swiglu7"
    # Non-gated activation: applied element-wise.
    GELU7 = "gelu7"
def swiglu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None):
    """Clamped SwiGLU over interleaved (gate, linear) channels.

    Even-indexed channels of the last dimension form the gate and odd-indexed
    channels form the linear branch, so the output has half as many channels
    as the input. The computation runs in float32.

    Args:
        x: Input tensor; its last dimension holds the interleaved pairs.
        alpha: Slope applied inside the sigmoid of the gate.
        limit: Upper clamp for the gate and symmetric clamp for the linear branch.
        out_dtype: Result dtype; defaults to the dtype of ``x``.

    Returns:
        ``gate * sigmoid(alpha * gate) * (linear + 1)`` cast to the result dtype.
    """
    result_dtype = out_dtype if out_dtype is not None else x.dtype
    x_fp32 = x.to(torch.float32)

    # The gate is only bounded from above; the linear branch is bounded on both sides.
    gate = x_fp32[..., 0::2].clamp(max=limit)
    linear = x_fp32[..., 1::2].clamp(min=-limit, max=limit)

    swish = gate * torch.sigmoid(alpha * gate)
    # The linear branch carries an extra +1 bias (from GPT-OSS).
    return (swish * (linear + 1)).to(result_dtype)
def gelu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None):
    """Clamped sigmoid-GELU: ``y * sigmoid(alpha * y)`` with ``y = min(x, limit)``.

    The computation runs in float32; unlike ``swiglu7`` there is no linear
    branch, so the output keeps the input shape.

    Args:
        x: Input tensor.
        alpha: Slope applied inside the sigmoid.
        limit: Upper clamp applied to the input before the activation.
        out_dtype: Result dtype; defaults to the dtype of ``x``.
    """
    result_dtype = out_dtype if out_dtype is not None else x.dtype
    # Only an upper bound is applied; large negative inputs pass through unclamped.
    clipped = x.to(torch.float32).clamp(max=limit)
    activated = clipped * torch.sigmoid(alpha * clipped)
    return activated.to(result_dtype)
def create_activation_func(activation_type: MLPActivationType) -> Callable:
    """Return the activation callable matching ``activation_type``.

    Raises:
        ValueError: If ``activation_type`` is not a known ``MLPActivationType``.
    """
    if activation_type == MLPActivationType.SWIGLU7:
        return swiglu7
    if activation_type == MLPActivationType.GELU7:
        return gelu7
    raise ValueError(f"Unknown activation type: {activation_type}")
class ModalityDispatcher:
    """Groups tokens by modality id and splits/merges them per group.

    Construction sorts the modality ids once and stores:

    * ``permute_mapping``: indices that reorder tokens so equal ids are contiguous,
    * ``inv_permute_mapping``: indices that undo that reordering,
    * ``permuted_modality_mapping``: the ids in the reordered layout,
    * ``group_size`` / ``group_size_cpu``: number of tokens per modality id.
    """

    permuted_modality_mapping: torch.Tensor
    group_size: torch.Tensor
    group_size_cpu: list[int]
    num_modalities: int

    def __init__(self, modality_mapping: torch.Tensor, num_modalities: int):
        """Precompute the permutations and per-modality token counts.

        Args:
            modality_mapping: 1-D tensor of integer modality ids, one per token.
            num_modalities: Minimum number of groups reported in ``group_size``.
        """
        self.modality_mapping = modality_mapping
        self.num_modalities = num_modalities
        self.permuted_modality_mapping = self._precompute_permute_mapping(modality_mapping)

        counts = torch.bincount(self.permuted_modality_mapping, minlength=num_modalities)
        self.group_size = counts.to(torch.int32)
        # Plain Python ints so torch.split can be called without reading from the device.
        self.group_size_cpu: list[int] = [int(n) for n in self.group_size.to("cpu").tolist()]

    def _precompute_permute_mapping(self, modality_mapping):
        """Store forward/inverse permutations and return the sorted modality ids."""
        forward = torch.argsort(modality_mapping)
        self.permute_mapping = forward
        # Argsort of a permutation yields its inverse.
        self.inv_permute_mapping = torch.argsort(forward)
        return modality_mapping[forward]

    def dispatch(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Split ``x`` along dim 0 into one chunk per modality.

        ``x`` is expected to already be in permuted (grouped) order.
        """
        return list(torch.split(x, self.group_size_cpu, dim=0))

    def undispatch(self, *processed_groups: list[torch.Tensor]) -> torch.Tensor:
        """Concatenate the per-modality tensors (one positional argument each) along dim 0."""
        return torch.cat(processed_groups, dim=0)

    @staticmethod
    def permute(x: torch.Tensor, permute_mapping: torch.Tensor) -> torch.Tensor:
        """Index ``x`` along dim 0 with the forward permutation."""
        return x[permute_mapping]

    @staticmethod
    def inv_permute(x: torch.Tensor, inv_permute_mapping: torch.Tensor) -> torch.Tensor:
        """Index ``x`` along dim 0 with the inverse permutation."""
        return x[inv_permute_mapping]
def freq_bands(
    num_bands: int, temperature: float = 10000.0, step: int = 2, device: Optional[torch.device] = None
) -> torch.Tensor:
    """Return inverse-frequency bands ``temperature ** -(i / num_bands)``.

    ``i`` runs over ``range(0, num_bands, step)``, so the result holds
    ``ceil(num_bands / step)`` float32 values.
    """
    indices = torch.arange(0, num_bands, step, dtype=torch.int64, device=device)
    exponents = indices.to(torch.float32) / num_bands
    return 1.0 / (temperature**exponents)
def rotate_half(x, interleaved=False):
    """Rotate channel pairs of the last dimension by 90 degrees.

    With ``interleaved=False`` the pairs are (first half, second half) and the
    result is ``cat(-second, first)``. With ``interleaved=True`` the pairs are
    adjacent (even, odd) channels and each pair becomes ``(-odd, even)``.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        # Stack as [..., d, 2], then merge the last two axes to re-interleave.
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """Apply rotary embedding to the leading ``rotary_dim`` channels of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Channels beyond ``rotary_dim`` are returned unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]

    # Duplicate each frequency to cover both members of a pair and add a
    # singleton head axis so the tables broadcast over heads.
    if interleaved:
        expand_pattern = "... d -> ... 1 (d 2)"
    else:
        expand_pattern = "... d -> ... 1 (2 d)"
    cos = repeat(cos, expand_pattern)
    sin = repeat(sin, expand_pattern)

    rotary_part = x[..., :rotary_dim]
    passthrough = x[..., rotary_dim:]
    rotated = rotary_part * cos + rotate_half(rotary_part, interleaved) * sin
    return torch.cat([rotated, passthrough], dim=-1)
class ElementWiseFourierEmbed(nn.Module):
    """Sin/cos Fourier features of per-token (time, row, col) coordinates."""

    def __init__(
        self,
        dim: int,
        max_res: int = 224,
        temperature: float = 10000.0,
        in_pixels: bool = True,
        linear_bands: bool = False,
        learnable: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
            dim: Feature dimension used to size the bands; ``dim // 8`` bands are
                created, so ``forward`` returns ``6 * (dim // 8)`` channels
            max_res: Max pixel-frequency resolution for pixel-domain bands (stored only)
            temperature: Temperature in inverse-frequency mode
            in_pixels: True -> pixel-frequency bands (not implemented, raises),
                False -> inverse-frequency bands
            linear_bands: Whether pixel-frequency bands are linearly spaced (stored only)
            learnable: Whether frequency bands are trainable
            device: Device on which the default bands are created
            dtype: Dtype of the default bands

        Raises:
            NotImplementedError: If ``in_pixels`` is True.
        """
        super().__init__()
        self.dim = dim
        self.in_pixels = in_pixels
        self.learnable = learnable
        self.temperature = temperature
        self.max_res = max_res
        self.linear_bands = linear_bands
        self.device = device
        self.dtype = dtype
        # Make frequency bands trainable or register as buffer
        bands = self.get_default_bands()
        if self.learnable:
            self.bands = nn.Parameter(bands)
        else:
            self.register_buffer("bands", bands)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coords: [L,9], column order (time, row, col, T, H, W, ref_T, ref_H, ref_W)
        Returns:
            emb: [L, 6 * B] with B = number of bands. The first 3 * B columns are
            the sines for (t, h, w) and the last 3 * B columns the matching cosines.
        """
        # Use slicing instead of unbind + stack to reduce intermediates
        coords_xyz = coords[:, :3]  # [L,3] -> (t, h, w)
        sizes = coords[:, 3:6]  # [L,3] -> (T, H, W)
        refs = coords[:, 6:9]  # [L,3] -> (ref_T, ref_H, ref_W)
        # Compute scale factors
        scales = (refs - 1) / (sizes - 1)  # [L,3]
        # NOTE: if both ref and size are 1, scale is fixed to 1; otherwise invalid
        scales[(refs == 1) & (sizes == 1)] = 1
        assert not scales.isnan().any(), "scales has nan"
        assert not scales.isinf().any(), "scales has inf"
        # Center alignment: apply to h,w only (not time)
        centers = (sizes - 1) / 2  # [L,3]
        centers[:, 0] = 0  # Do not center the time dimension
        coords_xyz = coords_xyz - centers  # [L,3]
        # Project to frequency bands in one shot: [L,3,B]
        proj = coords_xyz.unsqueeze(-1) * scales.unsqueeze(-1) * self.bands
        # Compute sin & cos and concatenate
        sin_proj = proj.sin()  # [L,3,B]
        cos_proj = proj.cos()
        return torch.cat((sin_proj, cos_proj), dim=1).flatten(1)

    def reset_parameters(self):
        """Restore ``self.bands`` to the default frequency bands in place."""
        bands = self.get_default_bands()
        # When ``learnable`` is set, ``bands`` is a leaf Parameter that requires
        # grad; an in-place copy is only legal with autograd recording disabled.
        with torch.no_grad():
            self.bands.copy_(bands)

    def get_default_bands(self):
        """Build the default bands; only inverse-frequency mode is implemented."""
        if self.in_pixels:
            raise NotImplementedError("in_pixels are not implemented yet")
        else:
            bands = freq_bands(self.dim // 8, temperature=self.temperature, step=1, device=self.device).to(self.dtype)
        return bands
class MultiModalityRMSNorm(nn.Module):
    """RMS normalisation with one gain vector per modality.

    The stored ``weight`` is an offset from one: the applied gain is
    ``weight + 1``, so the zero-initialised parameter gives an identity gain.
    With ``num_modality > 1`` the weight holds ``num_modality`` gain vectors
    back to back, and the input must already be grouped by modality.
    """

    __constants__ = ["dim", "eps", "num_modality"]
    dim: int
    eps: float
    num_modality: int

    def __init__(self, dim: int, eps: float = 1e-6, device: torch.device | None = None, num_modality: int = 1):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.num_modality = num_modality
        self.weight = torch.nn.Parameter(torch.zeros(dim * num_modality, device=device, dtype=torch.float32))
        # Bind the forward implementation once, instead of branching on every call.
        self.forward = self.forward_multi_experts if num_modality > 1 else self.forward_single_expert
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the gain offsets (effective gain of one)."""
        nn.init.zeros_(self.weight)

    def rms(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` in float32 divided by its root-mean-square over the last dim."""
        x_fp32 = x.float()
        mean_square = torch.mean(x_fp32**2, dim=-1, keepdim=True)
        return x_fp32 * torch.rsqrt(mean_square + self.eps)

    def forward_multi_experts(self, x: torch.Tensor, modality_dispatcher: ModalityDispatcher) -> torch.Tensor:
        """Normalise ``x``, then scale each modality group by its own gain."""
        groups = modality_dispatcher.dispatch(self.rms(x))
        gains = self.weight.chunk(self.num_modality, dim=0)
        for idx, gain in enumerate(gains):
            groups[idx] = groups[idx] * (gain + 1)
        return modality_dispatcher.undispatch(*groups).to(x.dtype)

    def forward_single_expert(self, x: torch.Tensor, modality_dispatcher: Optional[ModalityDispatcher] = None) -> torch.Tensor:
        """Normalise ``x`` and scale by the single gain; the dispatcher is unused."""
        return (self.rms(x) * (self.weight + 1)).to(x.dtype)
class _BF16ComputeLinear(torch.autograd.Function):
    """Linear op that casts its operands to a compute dtype before the matmul.

    Only ``forward`` is defined; no backward is implemented in this class.
    """

    @staticmethod
    def forward(
        ctx,
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        output_dtype: Optional[torch.dtype],
        compute_dtype: torch.dtype = torch.bfloat16,
    ):
        """Compute ``input @ weight.T (+ bias)`` in ``compute_dtype``.

        Args:
            input: Activations whose last dimension matches ``weight.shape[1]``.
            weight: Matrix of shape ``(out_features, in_features)``.
            bias: Optional vector added to the matmul result.
            output_dtype: Dtype the result is cast to.
            compute_dtype: Dtype used for the matmul and the bias add.
        """
        result = torch.matmul(input.to(compute_dtype), weight.to(compute_dtype).t())
        if bias is not None:
            result = result + bias.to(compute_dtype)
        return result.to(output_dtype)
class BaseLinear(nn.Module):
    """Linear layer whose parameters are stored, and matmul computed, in bfloat16.

    The weight has shape ``(out_features * num_experts, in_features)``. This
    base class applies the whole weight in a single matmul and ignores the
    modality dispatcher.
    """

    __constants__ = ["in_features", "out_features", "num_layers", "num_experts"]
    in_features: int
    out_features: int
    num_layers_for_initialization: int
    num_experts: int
    weight: Tensor

    def __init__(
        self, in_features, out_features, num_layers_for_initialization, num_experts, bias=True, device=None, dtype=None
    ):
        """Allocate uninitialised parameters.

        NOTE(review): ``dtype`` is accepted but not used; parameters are always
        created as bfloat16. Values come from ``torch.empty`` -- presumably they
        are filled by checkpoint loading; confirm.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_layers_for_initialization = num_layers_for_initialization
        self.num_experts = num_experts
        self.use_bias = bias

        total_out_features = out_features * num_experts
        param_kwargs = {"device": device, "dtype": torch.bfloat16}
        self.weight = Parameter(torch.empty((total_out_features, in_features), **param_kwargs))
        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(torch.empty(total_out_features, **param_kwargs))

    def forward(
        self,
        input: torch.Tensor,
        output_dtype: Optional[torch.dtype] = None,
        modality_dispatcher: Optional[ModalityDispatcher] = None,
    ) -> torch.Tensor:
        """Apply the layer; the result dtype defaults to ``input.dtype``."""
        if output_dtype is None:
            output_dtype = input.dtype
        return _BF16ComputeLinear.apply(input, self.weight, self.bias, output_dtype, torch.bfloat16)
class NativeMoELinear(BaseLinear):
    """Per-modality linear layer: each token group uses its own weight slice."""

    def forward(
        self,
        input: torch.Tensor,
        output_dtype: Optional[torch.dtype] = None,
        modality_dispatcher: Optional[ModalityDispatcher] = None,
    ) -> torch.Tensor:
        """Apply expert ``i`` to the ``i``-th modality group and re-concatenate.

        ``input`` must already be grouped by modality (permuted order), because
        the dispatcher splits it into contiguous chunks along dim 0.
        """
        result_dtype = input.dtype if output_dtype is None else output_dtype
        token_groups = modality_dispatcher.dispatch(input)  # type: ignore

        expert_weights = self.weight.chunk(self.num_experts, dim=0)
        has_bias = self.bias is not None
        expert_biases = self.bias.chunk(self.num_experts, dim=0) if has_bias else None

        for idx in range(self.num_experts):
            token_groups[idx] = _BF16ComputeLinear.apply(
                token_groups[idx],
                expert_weights[idx],
                expert_biases[idx] if has_bias else None,
                result_dtype,
                torch.bfloat16,
            )
        return modality_dispatcher.undispatch(*token_groups)  # type: ignore
def create_linear(
    in_features, out_features, num_layers=1, num_experts=1, bias=True, device=None, dtype=None
) -> BaseLinear | NativeMoELinear:
    """Build a ``BaseLinear`` for a single expert, otherwise a ``NativeMoELinear``."""
    linear_cls = BaseLinear if num_experts == 1 else NativeMoELinear
    return linear_cls(in_features, out_features, num_layers, num_experts, bias, device, dtype)
# Availability flags for the optional attention backends: each is True when the
# named package can be located (nothing is imported here).
HAS_MAGI_ATTENTION, HAS_FA3 = (
    importlib.util.find_spec(package) is not None for package in ("magi_attention", "flash_attn_interface")
)
@magi_register_custom_op(name="infra::flash_attn_func", is_subgraph_boundary=True)
def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Run dense flash attention with the best available backend.

    Uses ``flash_attn_interface`` (FA3) when that package is installed and
    ``is_hopper_arch()`` is true; otherwise falls back to the ``flash_attn``
    package. The backend is imported lazily at call time.
    """
    use_fa3 = HAS_FA3 and is_hopper_arch()
    if use_fa3:
        from flash_attn_interface import flash_attn_func as attn_impl
    else:
        from flash_attn.flash_attn_interface import flash_attn_func as attn_impl
    return attn_impl(query, key, value)
def _split_q_range_with_no_overlap(
    q_ranges: torch.Tensor, k_ranges: torch.Tensor
) -> Tuple[List[List[int]], List[List[List[int]]]]:
    """Cut possibly overlapping query ranges into disjoint segments.

    Every distinct boundary value in ``q_ranges`` becomes a cut point. For each
    segment between two consecutive cut points, the key ranges of all original
    query ranges that fully contain the segment are collected, in input order.

    Returns:
        ``(segments, key_ranges_per_segment)``: the ``[start, end]`` segments
        covered by at least one query range, and for each of them the list of
        ``[k_start, k_end]`` key ranges to attend to.
    """
    boundaries = torch.unique(q_ranges, sorted=True).tolist()
    range_pairs = list(zip(q_ranges.tolist(), k_ranges.tolist()))

    segments: List[List[int]] = []
    keys_per_segment: List[List[List[int]]] = []
    for seg_start, seg_end in zip(boundaries[:-1], boundaries[1:]):
        covering_keys = [
            k_range for (q_start, q_end), k_range in range_pairs if q_start <= seg_start and seg_end <= q_end
        ]
        # Segments that no query range covers are dropped.
        if covering_keys:
            segments.append([seg_start, seg_end])
            keys_per_segment.append(covering_keys)
    return segments, keys_per_segment
def _flash_attn_with_correction(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: List[List[int]], k_range_list: List[List[List[int]]]
):
    """Range-based attention assembled from dense flash-attention calls.

    For each query segment, attention is computed separately against every key
    range paired with it, and the partial results are merged using their
    log-sum-exp (LSE) values so the result is normalised over all key ranges.

    Args:
        query, key, value: Sliced along dim 0 by the ranges and given a leading
            singleton dim for the kernel.
            NOTE(review): assumes (tokens, heads, head_dim) layout -- confirm.
        q_ranges: ``[start, end]`` query segments.
        k_range_list: For each query segment, its ``[start, end]`` key ranges.

    Returns:
        ``(output, output_lse)``: ``output`` has the shape of ``query``;
        ``output_lse`` is float32 with shape ``(query.shape[0], query.shape[1])``.
        Rows not covered by any segment stay zero.
    """
    output = torch.zeros_like(query)
    output_lse = torch.zeros((query.shape[0], query.shape[1]), dtype=torch.float32, device=query.device)
    # Local import: within this function the name refers to the flash_attn
    # package function, not the custom op defined at module level.
    from flash_attn.flash_attn_interface import flash_attn_func
    for q_range, k_ranges in zip(q_ranges, k_range_list):
        q_start, q_end = q_range
        # Running merged output / LSE for this query segment.
        qo_out, qo_lse = None, None
        for k_range in k_ranges:
            k_start, k_end = k_range
            # With return_attn_probs=True the call yields three values; the
            # second is used below as the softmax log-sum-exp.
            cur_qo_out, cur_qo_lse, _ = flash_attn_func(
                query[q_start:q_end].unsqueeze(0),
                key[k_start:k_end].unsqueeze(0),
                value[k_start:k_end].unsqueeze(0),
                return_attn_probs=True,
            )
            cur_qo_out, cur_qo_lse = cur_qo_out.squeeze(0), cur_qo_lse.squeeze(0)
            if qo_out is None:
                qo_out = cur_qo_out
                qo_lse = cur_qo_lse
            else:
                # +inf LSE entries are mapped to -inf so they get zero weight below.
                # NOTE(review): presumably the kernel's marker for rows without keys -- confirm.
                qo_lse[qo_lse == torch.inf] = -torch.inf
                cur_qo_lse[cur_qo_lse == torch.inf] = -torch.inf
                # Subtract the running maximum before exponentiating (standard
                # log-sum-exp stabilisation).
                max_lse = torch.max(qo_lse, cur_qo_lse)
                qo_se, cur_qo_se = torch.exp(qo_lse - max_lse), torch.exp(cur_qo_lse - max_lse)
                sum_se = qo_se + cur_qo_se
                # Each partial output is weighted by its share of the softmax mass.
                qo_scale, cur_qo_scale = qo_se / sum_se, cur_qo_se / sum_se
                # The scales are transposed and given a trailing axis so they
                # broadcast over the outputs' last dimension.
                qo_out = qo_out * qo_scale.permute(1, 0).unsqueeze(-1) + cur_qo_out * cur_qo_scale.permute(1, 0).unsqueeze(-1)
                qo_lse = torch.log(sum_se) + max_lse
        output[q_start:q_end] = qo_out
        output_lse[q_start:q_end, :] = qo_lse.permute(1, 0)
    return output, output_lse
def _custom_flex_flash_attn_func(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor, **kwargs
):
    """Fallback range-based attention built on dense flash attention.

    The query ranges are first cut into disjoint segments; each segment is then
    attended against its key ranges and the partial results are merged. Extra
    keyword arguments are accepted and ignored.
    """
    disjoint_q_ranges, k_ranges_per_segment = _split_q_range_with_no_overlap(q_ranges, k_ranges)
    return _flash_attn_with_correction(query, key, value, disjoint_q_ranges, k_ranges_per_segment)
def _flex_flash_attn_func_infer_output_meta(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Describe the outputs of ``flex_flash_attn_func`` without running it.

    Returns uninitialised placeholders: one shaped and typed like ``query``, and
    a float32 tensor of shape ``(query.shape[0], query.shape[1])`` for the LSE.
    """
    num_tokens, num_heads = query.shape[0], query.shape[1]
    lse_meta = torch.empty((num_tokens, num_heads), dtype=torch.float32, device=query.device)
    return torch.empty_like(query), lse_meta
@magi_register_custom_op(
    name="infra::flex_flash_attn_func",
    mutates_args=(),
    infer_output_meta_fn=_flex_flash_attn_func_infer_output_meta,
    is_subgraph_boundary=True,
)
def flex_flash_attn_func(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Range-based attention returning ``(output, lse)``.

    Dispatches to ``magi_attention`` when that package is installed and
    ``is_hopper_arch()`` is true; otherwise uses the fallback in this file.
    """
    if not (HAS_MAGI_ATTENTION and is_hopper_arch()):
        return _custom_flex_flash_attn_func(query, key, value, q_ranges, k_ranges)

    from magi_attention.api import flex_flash_attn_func as magi_flex_flash_attn_func

    return magi_flex_flash_attn_func(query, key, value, q_ranges, k_ranges)
def _attention_with_cp_infer_output_meta(q: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Describe the output of the ``*_with_cp`` attention ops.

    Returns an uninitialised bfloat16 tensor shaped like ``q`` with dim 0
    squeezed (dim 0 is only removed when its size is 1).
    """
    placeholder = torch.empty_like(q, dtype=torch.bfloat16)
    return placeholder.squeeze(0)
@magi_register_custom_op(
    name="infra::flash_attn_with_cp",
    mutates_args=(),
    infer_output_meta_fn=_attention_with_cp_infer_output_meta,
    is_subgraph_boundary=True,
)
def flash_attn_with_cp(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cp_split_sizes: List[int]) -> torch.Tensor:
    """Dense flash attention with optional context-parallel all-to-all.

    Inputs are cast to bfloat16. When the context-parallel world size exceeds
    one, q/k/v are exchanged before attention (scatter heads, gather sequence)
    and the result is exchanged back afterwards (scatter sequence, gather heads).

    Args:
        q, k, v: Attention inputs with a leading dim of size 1, which is
            squeezed around the exchange and from the returned tensor.
        cp_split_sizes: Split sizes forwarded to the all-to-all primitives.
    """
    from inference.infra.distributed import get_cp_group, get_cp_world_size
    from inference.infra.parallelism.all_to_all_primitive import batch_scatter_head_gather_seqlen, scatter_seqlen_gather_head

    q = q.to(torch.bfloat16)
    k = k.to(torch.bfloat16)
    v = v.to(torch.bfloat16)

    cp_world_size = get_cp_world_size()
    context_parallel = cp_world_size > 1

    if context_parallel:
        local_qkv = [q.squeeze(0), k.squeeze(0), v.squeeze(0)]
        q, k, v = batch_scatter_head_gather_seqlen(local_qkv, cp_split_sizes, get_cp_group())
        # Restore the leading dim that the attention op expects.
        q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)

    self_attn_out = torch.ops.infra.flash_attn_func(q, k, v).squeeze(0)

    if context_parallel:
        self_attn_out = scatter_seqlen_gather_head(self_attn_out, cp_split_sizes, get_cp_group(), async_op=False)
        # Fold the per-rank chunks of the leading axis into the head axis.
        self_attn_out = rearrange(self_attn_out, "(cp sq) hn hd -> sq (cp hn) hd", cp=cp_world_size)
    return self_attn_out
@magi_register_custom_op(
    name="infra::flex_flash_attn_with_cp",
    mutates_args=(),
    infer_output_meta_fn=_attention_with_cp_infer_output_meta,
    is_subgraph_boundary=True,
)
def flex_flash_attn_with_cp(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_ranges: torch.Tensor,
    k_ranges: torch.Tensor,
    cp_split_sizes: List[int],
) -> torch.Tensor:
    """Range-based flash attention with optional context-parallel all-to-all.

    Inputs are cast to bfloat16 and their leading dim is squeezed. When the
    context-parallel world size exceeds one, q/k/v are exchanged before
    attention and the output is exchanged back afterwards. The LSE returned by
    the attention op is discarded.
    """
    from inference.infra.distributed import get_cp_group, get_cp_world_size
    from inference.infra.parallelism.all_to_all_primitive import batch_scatter_head_gather_seqlen, scatter_seqlen_gather_head

    q = q.to(torch.bfloat16).squeeze(0)
    k = k.to(torch.bfloat16).squeeze(0)
    v = v.to(torch.bfloat16).squeeze(0)

    cp_world_size = get_cp_world_size()
    context_parallel = cp_world_size > 1

    if context_parallel:
        q, k, v = batch_scatter_head_gather_seqlen([q, k, v], cp_split_sizes, get_cp_group())

    out, _ = torch.ops.infra.flex_flash_attn_func(q, k, v, q_ranges=q_ranges, k_ranges=k_ranges)

    if context_parallel:
        out = scatter_seqlen_gather_head(out, cp_split_sizes, get_cp_group(), async_op=False)
        # Fold the per-rank chunks of the leading axis into the head axis.
        out = rearrange(out, "(cp sq) hn hd -> sq (cp hn) hd", cp=cp_world_size)
    return out
@dataclass
class AttentionConfig:
    """Static settings consumed by ``Attention``."""

    # Width of the residual stream (input and output of the block).
    hidden_size: int
    # Number of query heads and of key/value heads.
    num_heads_q: int
    num_heads_kv: int
    # Channels per head.
    head_dim: int
    # Forwarded to ``create_linear`` as ``dtype``.
    params_dtype: torch.dtype
    # Not read anywhere in this file.
    checkpoint_qk_layernorm_rope: bool
    # Number of modality-specific experts for the norms and linear layers.
    num_modality: int
    # Forwarded to ``create_linear`` as ``num_layers``.
    num_layers: int
    # Use the range-based attention op instead of dense attention.
    use_local_attn: bool = False
    # Add one sigmoid gate per query head to the attention output.
    enable_attn_gating: bool = False
class Attention(torch.nn.Module):
    """Pre-norm self-attention with per-modality projections and q/k RMS norm.

    Tokens enter and leave in permuted (modality-grouped) order. Q, K and V are
    mapped back to the original token order for rotary embedding and attention,
    and the attention output is permuted again before the output projection.
    """

    config: AttentionConfig

    def __init__(self, config: AttentionConfig):
        super().__init__()
        self.config = config
        num_modality = config.num_modality
        self.q_size = config.num_heads_q * config.head_dim
        self.kv_size = config.num_heads_kv * config.head_dim
        # One gate logit per query head when gating is enabled, otherwise none.
        self.gating_size = config.num_heads_q if config.enable_attn_gating else 0

        self.pre_norm = MultiModalityRMSNorm(config.hidden_size, eps=1e-6, num_modality=num_modality)
        # Fused projection producing [q | k | v | gate] along the feature dim.
        self.linear_qkv = create_linear(
            config.hidden_size,
            config.num_heads_q * config.head_dim + config.num_heads_kv * config.head_dim * 2 + self.gating_size,
            num_experts=num_modality,
            bias=False,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
        )
        self.linear_proj = create_linear(
            config.num_heads_q * config.head_dim,
            config.hidden_size,
            bias=False,
            num_experts=num_modality,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
        )
        self.q_norm = MultiModalityRMSNorm(config.head_dim, num_modality=num_modality)
        self.k_norm = MultiModalityRMSNorm(config.head_dim, num_modality=num_modality)

    def reset_parameters(self):
        """Re-initialise the output projection if it provides a dedicated hook."""
        output_layer_reset = getattr(self.linear_proj, "reset_parameters_output_layer", None)
        if hasattr(self.linear_proj, "reset_parameters_output_layer"):
            output_layer_reset()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rope: torch.Tensor,
        permute_mapping: torch.Tensor,
        inv_permute_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
        modality_dispatcher: ModalityDispatcher,
        cp_split_sizes: List[int],
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: Token features in permuted order.
            rope: Rotary table; split in half along the last dim into (sin, cos).
            permute_mapping: Original -> permuted index mapping.
            inv_permute_mapping: Permuted -> original index mapping.
            varlen_handler: Unused by this method.
            local_attn_handler: Supplies ``q_ranges`` / ``k_ranges`` for local attention.
            modality_dispatcher: Splits tokens per modality for norms and projections.
            cp_split_sizes: Split sizes forwarded to the attention ops.

        Returns:
            The output projection of the attention result, in permuted order.
        """
        cfg = self.config

        normed = self.pre_norm(hidden_states, modality_dispatcher=modality_dispatcher).to(torch.bfloat16)
        qkv: torch.Tensor = self.linear_qkv(normed, modality_dispatcher=modality_dispatcher).to(torch.float32)

        section_sizes = [self.q_size, self.kv_size, self.kv_size, self.gating_size]
        q, k, v, g = torch.split(qkv, section_sizes, dim=1)
        q = q.view(-1, cfg.num_heads_q, cfg.head_dim)
        k = k.view(-1, cfg.num_heads_kv, cfg.head_dim)
        v = v.view(-1, cfg.num_heads_kv, cfg.head_dim)
        g = g.view(k.shape[0], cfg.num_heads_q, -1)

        # q/k normalisation is per modality, so it runs while still grouped.
        q = self.q_norm(q, modality_dispatcher=modality_dispatcher)
        k = self.k_norm(k, modality_dispatcher=modality_dispatcher)

        # Back to original token order, with a leading dim of one for the kernels.
        q = ModalityDispatcher.inv_permute(q, inv_permute_mapping).unsqueeze(0)
        k = ModalityDispatcher.inv_permute(k, inv_permute_mapping).unsqueeze(0)
        v = ModalityDispatcher.inv_permute(v, inv_permute_mapping).unsqueeze(0)

        sin_emb, cos_emb = rope.tensor_split(2, -1)
        q = apply_rotary_emb_torch(q, cos_emb, sin_emb)
        k = apply_rotary_emb_torch(k, cos_emb, sin_emb)

        if not cfg.use_local_attn:
            self_attn_out = flash_attn_with_cp(q, k, v, cp_split_sizes)
        else:
            self_attn_out = flex_flash_attn_with_cp(
                q, k, v, local_attn_handler.q_ranges, local_attn_handler.k_ranges, cp_split_sizes
            )

        # Regroup by modality; the gate logits never left the grouped order.
        self_attn_out = ModalityDispatcher.permute(self_attn_out, permute_mapping)
        if cfg.enable_attn_gating:
            self_attn_out = self_attn_out * torch.sigmoid(g)

        flat_out = self_attn_out.view(-1, cfg.num_heads_q * cfg.head_dim).to(torch.bfloat16)
        return self.linear_proj(flat_out, modality_dispatcher=modality_dispatcher)
@dataclass
class MLPConfig:
    """Static settings consumed by ``MLP``."""

    # Width of the residual stream (input and output of the block).
    hidden_size: int
    # Width fed to the down projection (after the activation).
    intermediate_size: int
    # Which activation ``create_activation_func`` should build.
    activation_type: MLPActivationType
    # Forwarded to ``create_linear`` as ``dtype``.
    params_dtype: torch.dtype
    # Number of modality-specific experts for the norm and linear layers.
    num_modality: int = 1
    # Forwarded to ``create_linear`` as ``num_layers``.
    num_layers: int = 1
    # When True the up projection is twice as wide (gate and linear branches).
    gated_act: bool = False
class MLP(torch.nn.Module):
    """Pre-norm feed-forward block with per-modality projections.

    Data flow: RMS norm -> up (and gate) projection -> activation -> down
    projection. Linear layers receive bfloat16 inputs; the activation and the
    returned tensor are float32.
    """

    config: MLPConfig

    def __init__(self, config: MLPConfig):
        super().__init__()
        # Store the config so the declared ``config`` attribute actually exists,
        # matching ``Attention`` and ``Adapter``.
        self.config = config
        num_experts = config.num_modality
        self.pre_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=config.num_modality)
        # A gated activation consumes (gate, linear) pairs, so the up projection
        # must emit twice the intermediate width.
        intermediate_size_up = config.intermediate_size * 2 if config.gated_act else config.intermediate_size
        self.up_gate_proj = create_linear(
            config.hidden_size,
            intermediate_size_up,
            bias=False,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
            num_experts=num_experts,
        )
        self.down_proj = create_linear(
            config.intermediate_size,
            config.hidden_size,
            bias=False,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
            num_experts=num_experts,
        )
        self.activation_func = create_activation_func(config.activation_type)

    def forward(self, x: torch.Tensor, modality_dispatcher: ModalityDispatcher) -> torch.Tensor:
        """Apply the feed-forward block.

        Args:
            x: Token features in permuted (modality-grouped) order.
            modality_dispatcher: Splits tokens per modality for the norm and
                the linear layers.

        Returns:
            A float32 tensor with ``hidden_size`` features per token.
        """
        x = self.pre_norm(x, modality_dispatcher=modality_dispatcher).to(torch.bfloat16)
        x = self.up_gate_proj(x, modality_dispatcher=modality_dispatcher).to(torch.float32)
        x = self.activation_func(x).to(torch.bfloat16)
        x = self.down_proj(x, modality_dispatcher=modality_dispatcher).to(torch.float32)
        return x

    def extra_repr(self) -> str:
        """Show the projection weight shapes in the module repr."""
        return f"{self.up_gate_proj.weight.shape=}, {self.down_proj.weight.shape=}"
@dataclass
class AdapterConfig:
    """Static settings consumed by ``Adapter``."""

    # Output width of every modality embedder.
    hidden_size: int
    # ``hidden_size // num_attention_heads`` sizes the positional embedding.
    num_attention_heads: int
    # Number of leading input channels read for each modality.
    text_in_channels: int
    video_in_channels: int
    audio_in_channels: int
    # Not read by ``Adapter`` in this file (its embedders are created as float32).
    params_dtype: torch.dtype
class Adapter(torch.nn.Module):
    """Embeds raw text/audio/video tokens into the hidden size and builds the rotary table."""

    config: AdapterConfig

    def __init__(self, config: AdapterConfig):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        self.video_embedder = nn.Linear(config.video_in_channels, hidden, bias=True, dtype=torch.float32)
        self.text_embedder = nn.Linear(config.text_in_channels, hidden, bias=True, dtype=torch.float32)
        self.audio_embedder = nn.Linear(config.audio_in_channels, hidden, bias=True, dtype=torch.float32)
        head_dim = hidden // config.num_attention_heads
        self.rope = ElementWiseFourierEmbed(head_dim, in_pixels=False, learnable=False)

    def forward(
        self,
        x: torch.Tensor,
        coords_mapping: torch.Tensor,
        video_mask: torch.Tensor,
        audio_mask: torch.Tensor,
        text_mask: torch.Tensor,
    ):
        """Embed every token with the embedder of its modality.

        Args:
            x: Raw token features; only the first ``*_in_channels`` columns of a
                row are read, depending on the row's modality.
            coords_mapping: Per-token coordinates passed to the positional embedder.
            video_mask, audio_mask, text_mask: Boolean row selectors per modality.

        Returns:
            ``(embedded, rope)``: the embedded tokens (same dtype and device as
            ``x``; rows selected by no mask stay zero) and the positional table.
        """
        cfg = self.config
        rope = self.rope(coords_mapping)
        output_x = torch.zeros(x.shape[0], cfg.hidden_size, device=x.device, dtype=x.dtype)

        per_modality = (
            (text_mask, self.text_embedder, cfg.text_in_channels),
            (audio_mask, self.audio_embedder, cfg.audio_in_channels),
            (video_mask, self.video_embedder, cfg.video_in_channels),
        )
        for mask, embedder, in_channels in per_modality:
            output_x[mask] = embedder(x[mask, :in_channels])
        return output_x, rope
class TransFormerLayer(torch.nn.Module):
    """One transformer layer: attention and MLP, each added back residually.

    Per-layer behaviour is selected by membership of ``layer_idx`` in the
    config's ``mm_layers`` (three modality experts instead of one),
    ``local_attn_layers``, ``post_norm_layers`` and ``gelu7_layers``.
    """

    def __init__(self, config: Any, layer_idx: int):
        super().__init__()
        num_modality = 3 if layer_idx in config.mm_layers else 1
        self.post_norm = layer_idx in config.post_norm_layers

        self.attention: Attention = Attention(
            AttentionConfig(
                hidden_size=config.hidden_size,
                num_heads_q=config.num_heads_q,
                num_heads_kv=config.num_heads_kv,
                head_dim=config.head_dim,
                params_dtype=config.params_dtype,
                checkpoint_qk_layernorm_rope=config.checkpoint_qk_layernorm_rope,
                num_modality=num_modality,
                num_layers=config.num_layers,
                use_local_attn=layer_idx in config.local_attn_layers,
                enable_attn_gating=config.enable_attn_gating,
            )
        )

        if layer_idx in config.gelu7_layers:
            activation_type = MLPActivationType.GELU7
            gated_act = False
            intermediate_size = config.hidden_size * 4
        else:
            activation_type = MLPActivationType.SWIGLU7
            gated_act = True
            # Two thirds of 4 * hidden, rounded down to a multiple of 4.
            intermediate_size = int(config.hidden_size * 4 * 2 / 3) // 4 * 4

        self.mlp: MLP = MLP(
            MLPConfig(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                activation_type=activation_type,
                params_dtype=config.params_dtype,
                num_modality=num_modality,
                num_layers=config.num_layers,
                gated_act=gated_act,
            )
        )

        if self.post_norm:
            self.attn_post_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=num_modality)
            self.mlp_post_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=num_modality)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rope: torch.Tensor,
        permute_mapping: torch.Tensor,
        inv_permute_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
        modality_dispatcher: ModalityDispatcher,
        cp_split_sizes: List[int],
    ) -> torch.Tensor:
        """Apply attention then MLP, normalising each branch output when ``post_norm`` is set."""
        attn_branch = self.attention(
            hidden_states,
            rope,
            permute_mapping,
            inv_permute_mapping,
            varlen_handler,
            local_attn_handler,
            modality_dispatcher,
            cp_split_sizes,
        )
        if self.post_norm:
            attn_branch = self.attn_post_norm(attn_branch, modality_dispatcher=modality_dispatcher)
        hidden_states = hidden_states + attn_branch

        mlp_branch = self.mlp(hidden_states, modality_dispatcher)
        if self.post_norm:
            mlp_branch = self.mlp_post_norm(mlp_branch, modality_dispatcher=modality_dispatcher)
        return hidden_states + mlp_branch
# True until ``config_patch`` has been called once.
is_base_model = True


def config_patch(compile_config: CompileConfig) -> CompileConfig:
    """Adjust the compile config depending on how often this hook has run.

    The first call returns the config untouched and clears the module-level
    ``is_base_model`` flag. Every later call sets the GPU-resident weight ratio
    to zero before returning the config.
    """
    global is_base_model
    if not is_base_model:
        # Fully offload SR model for memory-constrained GPU
        compile_config.offload_config.gpu_resident_weight_ratio = 0.0
        return compile_config
    is_base_model = False
    return compile_config
@magi_compile(config_patch=config_patch)
class TransformerBlock(torch.nn.Module):
    """Stack of ``model_config.num_layers`` transformer layers applied in order."""

    def __init__(self, model_config: Any):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransFormerLayer(model_config, layer_idx) for layer_idx in range(model_config.num_layers)]
        )

    def forward(
        self,
        x: torch.Tensor,
        rope: torch.Tensor,
        permute_mapping: torch.Tensor,
        inv_permute_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
        modality_dispatcher: ModalityDispatcher,
        cp_split_sizes: List[int],
    ) -> torch.Tensor:
        """Feed ``x`` through every layer; all other arguments are shared by the layers."""
        shared_args = (
            rope,
            permute_mapping,
            inv_permute_mapping,
            varlen_handler,
            local_attn_handler,
            modality_dispatcher,
            cp_split_sizes,
        )
        for layer in self.layers:
            x = layer(x, *shared_args)
        return x
@dataclass
class TransformerConfig:
    """Subset of the model configuration kept on ``DiTModel``."""

    # Width of the residual stream.
    hidden_size: int
    # Channel counts of the raw per-modality inputs and outputs.
    video_in_channels: int
    audio_in_channels: int
    text_in_channels: int
    # Dtype the embedded tokens are cast to before the transformer block.
    params_dtype: torch.dtype
    # Set by ``DiTModel`` but not read in this file.
    post_process_dtype: torch.dtype
class DiTModel(torch.nn.Module):
    """Multi-modal transformer: adapter -> transformer block -> video/audio output heads."""

    config: TransformerConfig

    def __init__(self, model_config: Any):
        super().__init__()
        self.config = TransformerConfig(
            hidden_size=model_config.hidden_size,
            video_in_channels=model_config.video_in_channels,
            audio_in_channels=model_config.audio_in_channels,
            text_in_channels=model_config.text_in_channels,
            params_dtype=model_config.params_dtype,
            post_process_dtype=torch.float32,
        )
        self.adapter: Adapter = Adapter(
            AdapterConfig(
                hidden_size=model_config.hidden_size,
                num_attention_heads=model_config.num_heads_q,
                text_in_channels=model_config.text_in_channels,
                video_in_channels=model_config.video_in_channels,
                audio_in_channels=model_config.audio_in_channels,
                params_dtype=torch.float32,
            )
        )
        self.block: TransformerBlock = TransformerBlock(model_config=model_config)

        hidden = self.config.hidden_size
        self.final_norm_video = MultiModalityRMSNorm(hidden)
        self.final_norm_audio = MultiModalityRMSNorm(hidden)
        self.final_linear_video = nn.Linear(hidden, self.config.video_in_channels, bias=False, dtype=torch.float32)
        self.final_linear_audio = nn.Linear(hidden, self.config.audio_in_channels, bias=False, dtype=torch.float32)

    def forward(
        self,
        x: torch.Tensor,
        coords_mapping: torch.Tensor,
        modality_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
    ):
        """Run the full model on a packed token sequence.

        Args:
            x: Raw token features, one row per token.
            coords_mapping: Per-token coordinates for the positional embedding.
            modality_mapping: Per-token modality ids, compared against ``Modality``.
            varlen_handler: Passed through to the transformer block.
            local_attn_handler: Ranges for the layers that use local attention.

        Returns:
            A tensor with ``max(video_in_channels, audio_in_channels)`` columns.
            Video rows hold the video head output, audio rows the audio head
            output (each in its leading columns); everything else is zero.
        """
        cfg = self.config

        # Distribute the inputs through the Ulysses scheduler.
        x = ulysses_scheduler().dispatch(x)
        coords_mapping = ulysses_scheduler().dispatch(coords_mapping)
        modality_mapping = ulysses_scheduler().dispatch(modality_mapping)
        cp_split_sizes = ulysses_scheduler().cp_split_sizes

        modality_dispatcher = ModalityDispatcher(modality_mapping, 3)
        permute_mapping = modality_dispatcher.permute_mapping
        inv_permute_mapping = modality_dispatcher.inv_permute_mapping

        video_mask = modality_mapping == Modality.VIDEO
        audio_mask = modality_mapping == Modality.AUDIO
        text_mask = modality_mapping == Modality.TEXT

        x, rope = self.adapter(x, coords_mapping, video_mask, audio_mask, text_mask)
        x = x.to(cfg.params_dtype)

        # The block works on modality-grouped tokens; undo the grouping afterwards.
        x = ModalityDispatcher.permute(x, permute_mapping)
        x = self.block(
            x,
            rope,
            permute_mapping=permute_mapping,
            inv_permute_mapping=inv_permute_mapping,
            varlen_handler=varlen_handler,
            local_attn_handler=local_attn_handler,
            modality_dispatcher=modality_dispatcher,
            cp_split_sizes=cp_split_sizes,
        )
        x = ModalityDispatcher.inv_permute(x, inv_permute_mapping)

        video_tokens = x[video_mask].to(self.final_norm_video.weight.dtype)
        x_video = self.final_linear_video(self.final_norm_video(video_tokens))

        audio_tokens = x[audio_mask].to(self.final_norm_audio.weight.dtype)
        x_audio = self.final_linear_audio(self.final_norm_audio(audio_tokens))

        out_channels = max(cfg.video_in_channels, cfg.audio_in_channels)
        x_out = torch.zeros(x.shape[0], out_channels, device=x.device, dtype=x.dtype)
        x_out[video_mask, : cfg.video_in_channels] = x_video
        x_out[audio_mask, : cfg.audio_in_channels] = x_audio

        return ulysses_scheduler().undispatch(x_out)
|