Instructions to use VECTORVV1/vector-V4-Pro with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use VECTORVV1/vector-V4-Pro with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("text-generation", model="VECTORVV1/vector-V4-Pro")

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("VECTORVV1/vector-V4-Pro")
model = AutoModelForCausalLM.from_pretrained("VECTORVV1/vector-V4-Pro")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use VECTORVV1/vector-V4-Pro with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "VECTORVV1/vector-V4-Pro"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "VECTORVV1/vector-V4-Pro",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker
docker model run hf.co/VECTORVV1/vector-V4-Pro
- SGLang
How to use VECTORVV1/vector-V4-Pro with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "VECTORVV1/vector-V4-Pro" \
  --host 0.0.0.0 \
  --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "VECTORVV1/vector-V4-Pro",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "VECTORVV1/vector-V4-Pro" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "VECTORVV1/vector-V4-Pro",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
- Docker Model Runner
How to use VECTORVV1/vector-V4-Pro with Docker Model Runner:
docker model run hf.co/VECTORVV1/vector-V4-Pro
| import math | |
| from dataclasses import dataclass | |
| from typing import Tuple, Optional, Literal | |
| from functools import lru_cache | |
| from contextlib import contextmanager | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import torch.distributed as dist | |
| from kernel import act_quant, fp4_act_quant, fp8_gemm, fp4_gemm, sparse_attn, hc_split_sinkhorn | |
| world_size = 1 | |
| rank = 0 | |
| block_size = 128 | |
| fp4_block_size = 32 | |
| default_dtype = torch.bfloat16 | |
| scale_fmt = None | |
| scale_dtype = torch.float32 | |
| def set_dtype(dtype): | |
| """Temporarily override torch default dtype, restoring it on exit (even if an exception occurs).""" | |
| prev = torch.get_default_dtype() | |
| torch.set_default_dtype(dtype) | |
| try: | |
| yield | |
| finally: | |
| torch.set_default_dtype(prev) | |
| class ModelArgs: | |
| """Model hyperparameters. Field names match the config JSON keys.""" | |
| max_batch_size: int = 4 | |
| max_seq_len: int = 4096 | |
| dtype: Literal["bf16", "fp8"] = "fp8" | |
| scale_fmt: Literal[None, "ue8m0"] = "ue8m0" | |
| expert_dtype: Literal[None, "fp4"] = None | |
| scale_dtype: Literal["fp32", "fp8"] = "fp8" | |
| vocab_size: int = 129280 | |
| dim: int = 4096 | |
| moe_inter_dim: int = 4096 | |
| n_layers: int = 7 | |
| n_hash_layers: int = 0 | |
| n_mtp_layers: int = 1 | |
| n_heads: int = 64 | |
| # moe | |
| n_routed_experts: int = 8 | |
| n_shared_experts: int = 1 | |
| n_activated_experts: int = 2 | |
| score_func: Literal["softmax", "sigmoid", "sqrtsoftplus"] = "sqrtsoftplus" | |
| route_scale: float = 1. | |
| swiglu_limit: float = 0. | |
| # mqa | |
| q_lora_rank: int = 1024 | |
| head_dim: int = 512 | |
| rope_head_dim: int = 64 | |
| norm_eps: float = 1e-6 | |
| o_groups: int = 8 | |
| o_lora_rank: int = 1024 | |
| window_size: int = 128 | |
| compress_ratios: Tuple[int] = (0, 0, 4, 128, 4, 128, 4, 0) | |
| # yarn | |
| compress_rope_theta: float = 40000.0 | |
| original_seq_len: int = 0 | |
| rope_theta: float = 10000.0 | |
| rope_factor: float = 40 | |
| beta_fast: int = 32 | |
| beta_slow: int = 1 | |
| # index | |
| index_n_heads: int = 64 | |
| index_head_dim: int = 128 | |
| index_topk: int = 512 | |
| # hc | |
| hc_mult: int = 4 | |
| hc_sinkhorn_iters: int = 20 | |
| hc_eps: float = 1e-6 | |
| class ParallelEmbedding(nn.Module): | |
| """Embedding sharded along the vocab dimension. Each rank holds vocab_size // world_size rows. | |
| Out-of-range indices are zero-masked before all_reduce to combine partial embeddings.""" | |
| def __init__(self, vocab_size: int, dim: int): | |
| super().__init__() | |
| self.vocab_size = vocab_size | |
| self.dim = dim | |
| assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})" | |
| self.part_vocab_size = (vocab_size // world_size) | |
| self.vocab_start_idx = rank * self.part_vocab_size | |
| self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size | |
| self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim)) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| if world_size > 1: | |
| mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) | |
| x = x - self.vocab_start_idx | |
| x[mask] = 0 | |
| y = F.embedding(x, self.weight) | |
| if world_size > 1: | |
| y[mask] = 0 | |
| dist.all_reduce(y) | |
| return y | |
| def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: | |
| """Dispatches to fp4_gemm / fp8_gemm / F.linear based on weight dtype. | |
| For quantized weights, x is first quantized to FP8 via act_quant.""" | |
| assert bias is None | |
| if weight.dtype == torch.float4_e2m1fn_x2: | |
| x, s = act_quant(x, block_size, scale_fmt, scale_dtype) | |
| return fp4_gemm(x, s, weight, weight.scale, scale_dtype) | |
| elif weight.dtype == torch.float8_e4m3fn: | |
| x, s = act_quant(x, block_size, scale_fmt, scale_dtype) | |
| return fp8_gemm(x, s, weight, weight.scale, scale_dtype) | |
| else: | |
| return F.linear(x, weight) | |
| class Linear(nn.Module): | |
| """Linear layer supporting BF16, FP8, and FP4 weight formats with per-block scaling.""" | |
| def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): | |
| super().__init__() | |
| self.in_features = in_features | |
| self.out_features = out_features | |
| dtype = dtype or default_dtype | |
| if dtype == torch.float4_e2m1fn_x2: | |
| # FP4: weight is [out, in//2] in float4_e2m1fn_x2, logically [out, in] in fp4 | |
| # Scale is [out, in//32] in float8_e8m0fnu (1 scale per 32 fp4 elements along K) | |
| self.weight = nn.Parameter(torch.empty(out_features, in_features // 2, dtype=torch.float4_e2m1fn_x2)) | |
| scale_out_features = out_features | |
| scale_in_features = in_features // fp4_block_size | |
| self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu)) | |
| elif dtype == torch.float8_e4m3fn: | |
| self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) | |
| scale_out_features = (out_features + block_size - 1) // block_size | |
| scale_in_features = (in_features + block_size - 1) // block_size | |
| self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float8_e8m0fnu)) | |
| else: | |
| self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) | |
| self.register_parameter("scale", None) | |
| if bias: | |
| self.bias = nn.Parameter(torch.empty(out_features)) | |
| else: | |
| self.register_parameter("bias", None) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return linear(x, self.weight, self.bias) | |
| class ColumnParallelLinear(Linear): | |
| """Shards output dim across TP ranks. No all-reduce needed on output.""" | |
| def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): | |
| assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" | |
| self.part_out_features = out_features // world_size | |
| super().__init__(in_features, self.part_out_features, bias, dtype) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return linear(x, self.weight, self.bias) | |
| class RowParallelLinear(Linear): | |
| """Shards input dim across TP ranks. All-reduce on output to sum partial results.""" | |
| def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): | |
| assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})" | |
| self.part_in_features = in_features // world_size | |
| super().__init__(self.part_in_features, out_features, bias, dtype) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| y = linear(x, self.weight, None) | |
| if world_size > 1: | |
| y = y.float() | |
| dist.all_reduce(y) | |
| if self.bias is not None: | |
| y += self.bias | |
| return y.type_as(x) | |
| class RMSNorm(nn.Module): | |
| def __init__(self, dim: int, eps: float = 1e-6): | |
| super().__init__() | |
| self.dim = dim | |
| self.eps = eps | |
| # rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient. | |
| self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) | |
| def forward(self, x: torch.Tensor): | |
| dtype = x.dtype | |
| x = x.float() | |
| var = x.square().mean(-1, keepdim=True) | |
| x = x * torch.rsqrt(var + self.eps) | |
| return (self.weight * x).to(dtype) | |
| def precompute_freqs_cis(dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow) -> torch.Tensor: | |
| """Precomputes complex exponentials for rotary embeddings with YaRN scaling. | |
| When original_seq_len > 0, applies frequency interpolation with a smooth | |
| linear ramp between beta_fast and beta_slow correction ranges.""" | |
| def find_correction_dim(num_rotations, dim, base, max_seq_len): | |
| return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base)) | |
| def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): | |
| low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) | |
| high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) | |
| return max(low, 0), min(high, dim-1) | |
| def linear_ramp_factor(min, max, dim): | |
| if min == max: | |
| max += 0.001 | |
| linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) | |
| ramp_func = torch.clamp(linear_func, 0, 1) | |
| return ramp_func | |
| freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) | |
| if original_seq_len > 0: | |
| low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_seq_len) | |
| smooth = 1 - linear_ramp_factor(low, high, dim // 2) | |
| freqs = freqs / factor * (1 - smooth) + freqs * smooth | |
| t = torch.arange(seqlen) | |
| freqs = torch.outer(t, freqs) | |
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) | |
| return freqs_cis | |
| def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False) -> torch.Tensor: | |
| """Applies rotary positional embeddings in-place. Uses conjugate for inverse (de-rotation).""" | |
| y = x | |
| x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2))) | |
| if inverse: | |
| freqs_cis = freqs_cis.conj() | |
| if x.ndim == 3: | |
| freqs_cis = freqs_cis.view(1, x.size(1), x.size(-1)) | |
| else: | |
| freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) | |
| x = torch.view_as_real(x * freqs_cis).flatten(-2) | |
| y.copy_(x) | |
| return y | |
| def rotate_activation(x: torch.Tensor) -> torch.Tensor: | |
| """Applies randomized Hadamard rotation to spread information across dims before FP8 quant.""" | |
| assert x.dtype == torch.bfloat16 | |
| from fast_hadamard_transform import hadamard_transform | |
| return hadamard_transform(x, scale=x.size(-1) ** -0.5) | |
| def get_window_topk_idxs(window_size: int, bsz: int, seqlen: int, start_pos: int): | |
| if start_pos >= window_size - 1: | |
| start_pos %= window_size | |
| matrix = torch.cat([torch.arange(start_pos + 1, window_size), torch.arange(0, start_pos + 1)], dim=0) | |
| elif start_pos > 0: | |
| matrix = F.pad(torch.arange(start_pos + 1), (0, window_size - start_pos - 1), value=-1) | |
| else: | |
| base = torch.arange(seqlen).unsqueeze(1) | |
| matrix = (base - window_size + 1).clamp(0) + torch.arange(min(seqlen, window_size)) | |
| matrix = torch.where(matrix > base, -1, matrix) | |
| return matrix.unsqueeze(0).expand(bsz, -1, -1) | |
| def get_compress_topk_idxs(ratio: int, bsz: int, seqlen: int, start_pos: int, offset: int): | |
| if start_pos > 0: | |
| matrix = torch.arange(0, (start_pos + 1) // ratio) + offset | |
| else: | |
| matrix = torch.arange(seqlen // ratio).repeat(seqlen, 1) | |
| mask = matrix >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio | |
| matrix = torch.where(mask, -1, matrix + offset) | |
| return matrix.unsqueeze(0).expand(bsz, -1, -1) | |
| class Compressor(nn.Module): | |
| """Compresses KV cache via learned gated pooling over `compress_ratio` consecutive tokens. | |
| When overlap=True (ratio==4), uses overlapping windows for smoother compression boundaries.""" | |
| def __init__(self, args: ModelArgs, compress_ratio: int = 4, head_dim: int = 512, rotate: bool = False): | |
| super().__init__() | |
| self.dim = args.dim | |
| self.head_dim = head_dim | |
| self.rope_head_dim = args.rope_head_dim | |
| self.nope_head_dim = head_dim - args.rope_head_dim | |
| self.compress_ratio = compress_ratio | |
| self.overlap = compress_ratio == 4 | |
| self.rotate = rotate | |
| coff = 1 + self.overlap | |
| self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32)) | |
| # wkv and wgate in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient. | |
| # When overlap, the first half of dims is for overlapping compression, second half for normal. | |
| self.wkv = Linear(self.dim, coff * self.head_dim, dtype=torch.float32) | |
| self.wgate = Linear(self.dim, coff * self.head_dim, dtype=torch.float32) | |
| self.norm = RMSNorm(self.head_dim, args.norm_eps) | |
| self.kv_cache: torch.Tensor = None # assigned lazily from Attention.kv_cache | |
| # State buffers for decode-phase incremental compression. | |
| # With overlap: state[:, :ratio] = overlapping window, state[:, ratio:] = current window. | |
| self.register_buffer("kv_state", torch.zeros(args.max_batch_size, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32), persistent=False) | |
| self.register_buffer("score_state", torch.full((args.max_batch_size, coff * compress_ratio, coff * self.head_dim), float("-inf"), dtype=torch.float32), persistent=False) | |
| self.freqs_cis: torch.Tensor = None | |
| def overlap_transform(self, tensor: torch.Tensor, value=0): | |
| # tensor: [b,s,r,2d] | |
| b, s, _, _ = tensor.size() | |
| ratio, d = self.compress_ratio, self.head_dim | |
| new_tensor = tensor.new_full((b, s, 2 * ratio, d), value) | |
| new_tensor[:, :, ratio:] = tensor[:, :, :, d:] | |
| new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d] | |
| return new_tensor | |
| def forward(self, x: torch.Tensor, start_pos: int): | |
| assert self.kv_cache is not None | |
| bsz, seqlen, _ = x.size() | |
| ratio, overlap, d, rd = self.compress_ratio, self.overlap, self.head_dim, self.rope_head_dim | |
| dtype = x.dtype | |
| # compression need fp32 | |
| x = x.float() | |
| kv = self.wkv(x) | |
| score = self.wgate(x) | |
| if start_pos == 0: | |
| should_compress = seqlen >= ratio | |
| remainder = seqlen % ratio | |
| cutoff = seqlen - remainder | |
| offset = ratio if overlap else 0 | |
| if overlap and cutoff >= ratio: | |
| self.kv_state[:bsz, :ratio] = kv[:, cutoff-ratio : cutoff] | |
| self.score_state[:bsz, :ratio] = score[:, cutoff-ratio : cutoff] + self.ape | |
| if remainder > 0: | |
| kv, self.kv_state[:bsz, offset : offset+remainder] = kv.split([cutoff, remainder], dim=1) | |
| self.score_state[:bsz, offset : offset+remainder] = score[:, cutoff:] + self.ape[:remainder] | |
| score = score[:, :cutoff] | |
| kv = kv.unflatten(1, (-1, ratio)) | |
| score = score.unflatten(1, (-1, ratio)) + self.ape | |
| if overlap: | |
| kv = self.overlap_transform(kv, 0) | |
| score = self.overlap_transform(score, float("-inf")) | |
| kv = (kv * score.softmax(dim=2)).sum(dim=2) | |
| else: | |
| should_compress = (start_pos + 1) % self.compress_ratio == 0 | |
| score += self.ape[start_pos % ratio] | |
| if overlap: | |
| self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1) | |
| self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1) | |
| if should_compress: | |
| kv_state = torch.cat([self.kv_state[:bsz, :ratio, :d], self.kv_state[:bsz, ratio:, d:]], dim=1) | |
| score_state = torch.cat([self.score_state[:bsz, :ratio, :d], self.score_state[:bsz, ratio:, d:]], dim=1) | |
| kv = (kv_state * score_state.softmax(dim=1)).sum(dim=1, keepdim=True) | |
| self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:] | |
| self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:] | |
| else: | |
| self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1) | |
| self.score_state[:bsz, start_pos % ratio] = score.squeeze(1) | |
| if should_compress: | |
| kv = (self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1)).sum(dim=1, keepdim=True) | |
| if not should_compress: | |
| return | |
| kv = self.norm(kv.to(dtype)) | |
| if start_pos == 0: | |
| freqs_cis = self.freqs_cis[:cutoff:ratio] | |
| else: | |
| freqs_cis = self.freqs_cis[start_pos + 1 - self.compress_ratio].unsqueeze(0) | |
| apply_rotary_emb(kv[..., -rd:], freqs_cis) | |
| if self.rotate: | |
| kv = rotate_activation(kv) | |
| fp4_act_quant(kv, fp4_block_size, True) | |
| else: | |
| act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True) | |
| if start_pos == 0: | |
| self.kv_cache[:bsz, :seqlen // ratio] = kv | |
| else: | |
| self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1) | |
| return kv | |
| class Indexer(torch.nn.Module): | |
| """Selects top-k compressed KV positions for sparse attention via learned scoring. | |
| Has its own Compressor (with Hadamard rotation) to build compressed KV for scoring.""" | |
| def __init__(self, args: ModelArgs, compress_ratio: int = 4): | |
| super().__init__() | |
| self.dim = args.dim | |
| self.n_heads = args.index_n_heads | |
| self.n_local_heads = args.index_n_heads // world_size | |
| self.head_dim = args.index_head_dim | |
| self.rope_head_dim = args.rope_head_dim | |
| self.index_topk = args.index_topk | |
| self.q_lora_rank = args.q_lora_rank | |
| self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim) | |
| self.weights_proj = ColumnParallelLinear(self.dim, self.n_heads, dtype=torch.bfloat16) | |
| self.softmax_scale = self.head_dim ** -0.5 | |
| self.compress_ratio = compress_ratio | |
| self.compressor = Compressor(args, compress_ratio, self.head_dim, True) | |
| self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len // compress_ratio, self.head_dim), persistent=False) | |
| self.freqs_cis = None | |
| def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, offset: int): | |
| bsz, seqlen, _ = x.size() | |
| freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen] | |
| ratio = self.compress_ratio | |
| rd = self.rope_head_dim | |
| end_pos = start_pos + seqlen | |
| if self.compressor.kv_cache is None: | |
| self.compressor.kv_cache = self.kv_cache | |
| self.compressor.freqs_cis = self.freqs_cis | |
| q = self.wq_b(qr) | |
| q = q.unflatten(-1, (self.n_local_heads, self.head_dim)) | |
| apply_rotary_emb(q[..., -rd:], freqs_cis) | |
| q = rotate_activation(q) | |
| # use fp4 simulation for q and kv in indexer | |
| fp4_act_quant(q, fp4_block_size, True) | |
| self.compressor(x, start_pos) | |
| weights = self.weights_proj(x) * (self.softmax_scale * self.n_heads ** -0.5) | |
| # We performed QAT here, kv could also use fp8 format, though current implementation uses bf16 | |
| index_score = torch.einsum("bshd,btd->bsht", q, self.kv_cache[:bsz, :end_pos // ratio]) | |
| index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2) | |
| if world_size > 1: | |
| dist.all_reduce(index_score) | |
| if start_pos == 0: | |
| mask = torch.arange(seqlen // ratio).repeat(seqlen, 1) >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio | |
| index_score += torch.where(mask, float("-inf"), 0) | |
| topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1] | |
| if start_pos == 0: | |
| mask = topk_idxs >= torch.arange(1, seqlen + 1).unsqueeze(1) // ratio | |
| topk_idxs = torch.where(mask, -1, topk_idxs + offset) | |
| else: | |
| topk_idxs += offset | |
| return topk_idxs | |
| class Attention(nn.Module): | |
| """Multi-head Latent Attention (MLA) with sliding window + optional KV compression. | |
| Uses low-rank Q projection (wq_a -> q_norm -> wq_b) and grouped low-rank O projection.""" | |
| def __init__(self, layer_id: int, args: ModelArgs): | |
| super().__init__() | |
| self.layer_id = layer_id | |
| self.dim = args.dim | |
| self.n_heads = args.n_heads | |
| self.n_local_heads = args.n_heads // world_size | |
| self.q_lora_rank = args.q_lora_rank | |
| self.o_lora_rank = args.o_lora_rank | |
| self.head_dim = args.head_dim | |
| self.rope_head_dim = args.rope_head_dim | |
| self.nope_head_dim = args.head_dim - args.rope_head_dim | |
| self.n_groups = args.o_groups | |
| self.n_local_groups = self.n_groups // world_size | |
| self.window_size = args.window_size | |
| self.compress_ratio = args.compress_ratios[layer_id] | |
| self.eps = args.norm_eps | |
| self.attn_sink = nn.Parameter(torch.empty(self.n_local_heads, dtype=torch.float32)) | |
| self.wq_a = Linear(self.dim, self.q_lora_rank) | |
| self.q_norm = RMSNorm(self.q_lora_rank, self.eps) | |
| self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.head_dim) | |
| self.wkv = Linear(self.dim, self.head_dim) | |
| self.kv_norm = RMSNorm(self.head_dim, self.eps) | |
| self.wo_a = ColumnParallelLinear(self.n_heads * self.head_dim // self.n_groups, self.n_groups * args.o_lora_rank, dtype=torch.bfloat16) | |
| self.wo_b = RowParallelLinear(self.n_groups * args.o_lora_rank, self.dim) | |
| self.softmax_scale = self.head_dim ** -0.5 | |
| if self.compress_ratio: | |
| self.compressor = Compressor(args, self.compress_ratio, self.head_dim) | |
| if self.compress_ratio == 4: | |
| self.indexer = Indexer(args, self.compress_ratio) | |
| else: | |
| self.indexer = None | |
| kv_cache_size = args.window_size + (args.max_seq_len // self.compress_ratio if self.compress_ratio else 0) | |
| self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, kv_cache_size, self.head_dim), persistent=False) | |
| if self.compress_ratio: | |
| original_seq_len, rope_theta = args.original_seq_len, args.compress_rope_theta | |
| else: | |
| # disable YaRN and use base rope_theta in pure sliding-window attention | |
| original_seq_len, rope_theta = 0, args.rope_theta | |
| freqs_cis = precompute_freqs_cis(self.rope_head_dim, args.max_seq_len, original_seq_len, | |
| rope_theta, args.rope_factor, args.beta_fast, args.beta_slow) | |
| self.register_buffer("freqs_cis", freqs_cis, persistent=False) | |
| def forward(self, x: torch.Tensor, start_pos: int): | |
| bsz, seqlen, _ = x.size() | |
| freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen] | |
| win = self.window_size | |
| ratio = self.compress_ratio | |
| rd = self.rope_head_dim | |
| if self.compress_ratio and self.compressor.kv_cache is None: | |
| self.compressor.kv_cache = self.kv_cache[:, win:] | |
| self.compressor.freqs_cis = self.freqs_cis | |
| if self.indexer is not None: | |
| self.indexer.freqs_cis = self.freqs_cis | |
| # q | |
| qr = q = self.q_norm(self.wq_a(x)) | |
| q = self.wq_b(q).unflatten(-1, (self.n_local_heads, self.head_dim)) | |
| q *= torch.rsqrt(q.square().mean(-1, keepdim=True) + self.eps) | |
| apply_rotary_emb(q[..., -rd:], freqs_cis) | |
| # win kv & topk_idxs | |
| kv = self.wkv(x) | |
| kv = self.kv_norm(kv) | |
| apply_rotary_emb(kv[..., -rd:], freqs_cis) | |
| # FP8-simulate non-rope dims to match QAT; rope dims stay bf16 for positional precision | |
| act_quant(kv[..., :-rd], 64, scale_fmt, scale_dtype, True) | |
| topk_idxs = get_window_topk_idxs(win, bsz, seqlen, start_pos) | |
| if self.compress_ratio: | |
| offset = kv.size(1) if start_pos == 0 else win | |
| if self.indexer is not None: | |
| compress_topk_idxs = self.indexer(x, qr, start_pos, offset) | |
| else: | |
| compress_topk_idxs = get_compress_topk_idxs(ratio, bsz, seqlen, start_pos, offset) | |
| topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1) | |
| topk_idxs = topk_idxs.int() | |
| # compress kv & attn | |
| if start_pos == 0: | |
| if seqlen <= win: | |
| self.kv_cache[:bsz, :seqlen] = kv | |
| else: | |
| cutoff = seqlen % win | |
| self.kv_cache[:bsz, cutoff: win], self.kv_cache[:bsz, :cutoff] = kv[:, -win:].split([win - cutoff, cutoff], dim=1) | |
| if self.compress_ratio: | |
| if (kv_compress := self.compressor(x, start_pos)) is not None: | |
| kv = torch.cat([kv, kv_compress], dim=1) | |
| # We performed QAT here, kv could also use fp8 format, though current implementation uses bf16 | |
| o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale) | |
| else: | |
| self.kv_cache[:bsz, start_pos % win] = kv.squeeze(1) | |
| if self.compress_ratio: | |
| self.compressor(x, start_pos) | |
| o = sparse_attn(q, self.kv_cache[:bsz], self.attn_sink, topk_idxs, self.softmax_scale) | |
| apply_rotary_emb(o[..., -rd:], freqs_cis, True) | |
| # o | |
| o = o.view(bsz, seqlen, self.n_local_groups, -1) | |
| wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) | |
| # NOTE: wo_a is FP8 in checkpoint; could do FP8 einsum here for better perf, | |
| # but using BF16 for simplicity. | |
| o = torch.einsum("bsgd,grd->bsgr", o, wo_a) | |
| x = self.wo_b(o.flatten(2)) | |
| return x | |
| class Gate(nn.Module): | |
| """MoE gating: computes expert routing scores and selects top-k experts. | |
| Supports hash-based routing (first n_hash_layers) where expert indices are | |
| predetermined per token ID, and score-based routing (remaining layers).""" | |
| def __init__(self, layer_id: int, args: ModelArgs): | |
| super().__init__() | |
| self.dim = args.dim | |
| self.topk = args.n_activated_experts | |
| self.score_func = args.score_func | |
| self.route_scale = args.route_scale | |
| self.hash = layer_id < args.n_hash_layers | |
| self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) | |
| if self.hash: | |
| self.tid2eid = nn.Parameter(torch.empty(args.vocab_size, args.n_activated_experts, dtype=torch.int32), requires_grad=False) | |
| self.bias = None | |
| else: | |
| self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) | |
| def forward(self, x: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: | |
| scores = linear(x.float(), self.weight.float()) | |
| if self.score_func == "softmax": | |
| scores = scores.softmax(dim=-1) | |
| elif self.score_func == "sigmoid": | |
| scores = scores.sigmoid() | |
| else: | |
| scores = F.softplus(scores).sqrt() | |
| original_scores = scores | |
| # Bias shifts scores for expert selection (topk) but does not affect routing weights. | |
| if self.bias is not None: | |
| scores = scores + self.bias | |
| if self.hash: | |
| indices = self.tid2eid[input_ids] | |
| else: | |
| indices = scores.topk(self.topk, dim=-1)[1] | |
| weights = original_scores.gather(1, indices) | |
| if self.score_func != "softmax": | |
| weights /= weights.sum(dim=-1, keepdim=True) | |
| weights *= self.route_scale | |
| return weights, indices | |
| class Expert(nn.Module): | |
| """Single MoE expert: SwiGLU FFN (w1, w2, w3). Computation in float32 for stability.""" | |
| def __init__(self, dim: int, inter_dim: int, dtype=None, swiglu_limit=0): | |
| super().__init__() | |
| self.w1 = Linear(dim, inter_dim, dtype=dtype) | |
| self.w2 = Linear(inter_dim, dim, dtype=dtype) | |
| self.w3 = Linear(dim, inter_dim, dtype=dtype) | |
| self.swiglu_limit = swiglu_limit | |
| def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: | |
| dtype = x.dtype | |
| gate = self.w1(x).float() | |
| up = self.w3(x).float() | |
| if self.swiglu_limit > 0: | |
| up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit) | |
| gate = torch.clamp(gate, max=self.swiglu_limit) | |
| x = F.silu(gate) * up | |
| if weights is not None: | |
| x = weights * x | |
| return self.w2(x.to(dtype)) | |
| class MoE(nn.Module): | |
| """Mixture-of-Experts: gate routes each token to top-k routed experts + 1 shared expert. | |
| Experts are sharded across TP ranks; each rank handles n_routed_experts // world_size experts.""" | |
| def __init__(self, layer_id: int, args: ModelArgs): | |
| super().__init__() | |
| self.layer_id = layer_id | |
| self.dim = args.dim | |
| assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})" | |
| self.n_routed_experts = args.n_routed_experts | |
| self.n_local_experts = args.n_routed_experts // world_size | |
| self.n_activated_experts = args.n_activated_experts | |
| self.experts_start_idx = rank * self.n_local_experts | |
| self.experts_end_idx = self.experts_start_idx + self.n_local_experts | |
| self.gate = Gate(layer_id, args) | |
| expert_dtype = torch.float4_e2m1fn_x2 if args.expert_dtype == "fp4" else None | |
| self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None | |
| for i in range(self.n_routed_experts)]) | |
| assert args.n_shared_experts == 1 | |
| # no swiglu_limit | |
| self.shared_experts = Expert(args.dim, args.moe_inter_dim) | |
| def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: | |
| shape = x.size() | |
| x = x.view(-1, self.dim) | |
| weights, indices = self.gate(x, input_ids.flatten()) | |
| y = torch.zeros_like(x, dtype=torch.float32) | |
| counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist() | |
| for i in range(self.experts_start_idx, self.experts_end_idx): | |
| if counts[i] == 0: | |
| continue | |
| expert = self.experts[i] | |
| idx, top = torch.where(indices == i) | |
| y[idx] += expert(x[idx], weights[idx, top, None]) | |
| if world_size > 1: | |
| dist.all_reduce(y) | |
| y += self.shared_experts(x) | |
| return y.type_as(x).view(shape) | |
class Block(nn.Module):
    """Transformer block wired with Hyper-Connections (HC) mixing.

    Instead of a single residual stream, HC carries `hc_mult` parallel copies
    of the hidden state. `hc_pre` collapses the copies into one stream via a
    learned weighted sum (pre-weights produced by Sinkhorn normalization);
    `hc_post` expands the sublayer output back to `hc_mult` copies using
    learned post-weights plus a combination matrix over the residual copies.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.layer_id = layer_id
        self.norm_eps = args.norm_eps
        self.attn = Attention(layer_id, args)
        self.ffn = MoE(layer_id, args)
        self.attn_norm = RMSNorm(args.dim, self.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, self.norm_eps)
        hc_mult = args.hc_mult
        self.hc_mult = hc_mult
        self.hc_sinkhorn_iters = args.hc_sinkhorn_iters
        self.hc_eps = args.hc_eps
        # Mixing rows: hc pre-weights + hc post-weights + hc*hc combination matrix.
        n_mix = (2 + hc_mult) * hc_mult
        flat_dim = hc_mult * args.dim
        # HC parameters are kept in fp32 regardless of the model default dtype.
        with set_dtype(torch.float32):
            self.hc_attn_fn = nn.Parameter(torch.empty(n_mix, flat_dim))
            self.hc_ffn_fn = nn.Parameter(torch.empty(n_mix, flat_dim))
            self.hc_attn_base = nn.Parameter(torch.empty(n_mix))
            self.hc_ffn_base = nn.Parameter(torch.empty(n_mix))
            self.hc_attn_scale = nn.Parameter(torch.empty(3))
            self.hc_ffn_scale = nn.Parameter(torch.empty(3))

    def hc_pre(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
        """Collapse hc copies to one stream.

        x: [b,s,hc,d]; hc_fn: [mix_hc, hc*d]; hc_scale: [3]; hc_base: [mix_hc].
        Returns (y [b,s,d], post [b,s,hc], comb [b,s,hc,hc]).
        """
        orig_shape, orig_dtype = x.size(), x.dtype
        flat = x.flatten(2).float()
        # RMS-normalize the mixing logits, not the hidden state itself.
        inv_rms = torch.rsqrt(flat.square().mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(flat, hc_fn) * inv_rms
        pre, post, comb = hc_split_sinkhorn(mixes, hc_scale, hc_base, self.hc_mult, self.hc_sinkhorn_iters, self.hc_eps)
        collapsed = torch.sum(pre.unsqueeze(-1) * flat.view(orig_shape), dim=2)
        return collapsed.to(orig_dtype), post, comb

    def hc_post(self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor):
        """Expand one stream back to hc copies.

        x: [b,s,d]; residual: [b,s,hc,d]; post: [b,s,hc]; comb: [b,s,hc,hc].
        Returns y: [b,s,hc,d].
        """
        expanded = post.unsqueeze(-1) * x.unsqueeze(-2)
        mixed_residual = torch.sum(comb.unsqueeze(-1) * residual.unsqueeze(-2), dim=2)
        return (expanded + mixed_residual).type_as(x)

    def forward(self, x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor]) -> torch.Tensor:
        # Attention sublayer with HC collapse/expand around it.
        skip = x
        x, post, comb = self.hc_pre(x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base)
        x = self.attn(self.attn_norm(x), start_pos)
        x = self.hc_post(x, skip, post, comb)
        # MoE feed-forward sublayer, same HC pattern.
        skip = x
        x, post, comb = self.hc_pre(x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base)
        x = self.ffn(self.ffn_norm(x), input_ids)
        return self.hc_post(x, skip, post, comb)
class ParallelHead(nn.Module):
    """Vocabulary-parallel LM head with an HC collapse stage.

    The vocabulary dimension is sharded across ranks: each rank holds
    `vocab_size // world_size` rows of the output projection, and the full
    logits are reassembled with an all-gather in `forward`.
    """

    def __init__(self, vocab_size: int, dim: int, norm_eps: float = 1e-6, hc_eps: float = 1e-6):
        super().__init__()
        # Guard against silent truncation: integer division would otherwise drop
        # the last `vocab_size % world_size` vocabulary entries from the logits.
        if vocab_size % world_size != 0:
            raise ValueError(f"vocab_size ({vocab_size}) must be divisible by world_size ({world_size})")
        self.vocab_size = vocab_size
        self.dim = dim
        self.norm_eps = norm_eps
        self.hc_eps = hc_eps
        self.part_vocab_size = vocab_size // world_size
        # lm_head in the checkpoint is stored in bf16, while the parameter here is
        # stored in fp32 for easier computation of logits later.
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim, dtype=torch.float32))

    def get_logits(self, x):
        # Only the final sequence position is projected (next-token logits).
        return F.linear(x[:, -1].float(), self.weight)

    def forward(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, norm: RMSNorm):
        # x: [b,s,hc,d] -> logits: [b, vocab_size]
        x = self.hc_head(x, hc_fn, hc_scale, hc_base)
        logits = self.get_logits(norm(x))
        if world_size > 1:
            # Each rank produced a vocab shard; gather and concatenate them.
            shards = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(shards, logits)
            logits = torch.cat(shards, dim=-1)
        return logits

    def hc_head(self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor):
        """Collapse the hc copies with a sigmoid-gated weighted sum (no Sinkhorn here).

        x: [b,s,hc,d]; hc_fn: [hc, hc*d]; hc_base: [hc]. Returns [b,s,d].
        """
        shape, dtype = x.size(), x.dtype
        x = x.flatten(2).float()
        rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps)
        mixes = F.linear(x, hc_fn) * rsqrt
        # hc_eps keeps every copy's weight strictly positive.
        pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps
        y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=2)
        return y.to(dtype)
class MTPBlock(Block):
    """Multi-Token-Prediction block: a regular Block preceded by an
    embedding/hidden-state fusion step and followed by its own HC head
    that emits logits directly."""

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__(layer_id, args)
        self.e_proj = Linear(args.dim, args.dim)
        self.h_proj = Linear(args.dim, args.dim)
        self.enorm = RMSNorm(args.dim, args.norm_eps)
        self.hnorm = RMSNorm(args.dim, args.norm_eps)
        self.norm = RMSNorm(args.dim, args.norm_eps)
        hc_mult = args.hc_mult
        self.hc_mult = hc_mult
        # HC-head parameters are kept in fp32 regardless of the model default dtype.
        with set_dtype(torch.float32):
            self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_mult * args.dim))
            self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
            self.hc_head_scale = nn.Parameter(torch.empty(1))
        # Shared with the trunk model; wired up externally after construction.
        self.embed: Optional[ParallelEmbedding] = None
        self.head: Optional[ParallelHead] = None

    def forward(self, x: torch.Tensor, start_pos: int, input_ids: torch.Tensor) -> torch.Tensor:
        # x: [b,s,hc,d] -> logits via this block's own HC head.
        assert self.embed is not None and self.head is not None
        emb = self.enorm(self.embed(input_ids))
        hidden = self.hnorm(x)
        # Fuse token embeddings (broadcast over the hc axis) with hidden copies.
        fused = self.e_proj(emb).unsqueeze(2) + self.h_proj(hidden)
        out = super().forward(fused, start_pos, input_ids)
        return self.head(out, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
class Transformer(nn.Module):
    """Full DeepSeek-V4 model: embed -> HC-expand -> N blocks -> HC-head -> logits.

    NOTE: __init__ deliberately sets module-level globals (world_size, rank,
    default_dtype, scale_fmt, scale_dtype) before calling super().__init__(),
    since submodule construction reads them.
    """

    def __init__(self, args: ModelArgs):
        global world_size, rank, default_dtype, scale_fmt, scale_dtype
        if dist.is_initialized():
            world_size, rank = dist.get_world_size(), dist.get_rank()
        else:
            world_size, rank = 1, 0
        default_dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        scale_fmt = "ue8m0" if args.scale_dtype == "fp8" else args.scale_fmt
        scale_dtype = torch.float8_e8m0fnu if args.scale_dtype == "fp8" else torch.float32
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.norm_eps = args.norm_eps
        self.hc_eps = args.hc_eps
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList(Block(i, args) for i in range(args.n_layers))
        self.norm = RMSNorm(args.dim, self.norm_eps)
        self.head = ParallelHead(args.vocab_size, args.dim, self.norm_eps, self.hc_eps)
        self.mtp = torch.nn.ModuleList()
        for offset in range(args.n_mtp_layers):
            mtp_block = MTPBlock(args.n_layers + offset, args)
            # MTP blocks share the trunk embedding table and LM head.
            mtp_block.embed = self.embed
            mtp_block.head = self.head
            self.mtp.append(mtp_block)
        hc_mult = args.hc_mult
        self.hc_mult = hc_mult
        # HC-head parameters stay in fp32 regardless of the model default dtype.
        with set_dtype(torch.float32):
            self.hc_head_fn = nn.Parameter(torch.empty(hc_mult, hc_mult * args.dim))
            self.hc_head_base = nn.Parameter(torch.empty(hc_mult))
            self.hc_head_scale = nn.Parameter(torch.empty(1))

    def forward(self, input_ids: torch.Tensor, start_pos: int = 0):
        h = self.embed(input_ids)
        # Expand the single stream into hc_mult Hyper-Connection copies.
        h = h.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
        for layer in self.layers:
            h = layer(h, start_pos, input_ids)
        return self.head(h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm)
| if __name__ == "__main__": | |
| torch.set_default_dtype(torch.bfloat16) | |
| torch.set_default_device("cuda") | |
| torch.manual_seed(0) | |
| args = ModelArgs(n_hash_layers=0) | |
| x = torch.randint(0, args.vocab_size, (2, 128)) | |
| model = Transformer(args) | |
| print(model(x).size()) | |
| for i in range(128, 150): | |
| print(i, model(x[:, 0:1], i).size()) | |
| h = torch.randn(2, 128, args.hc_mult, args.dim) | |
| mtp = model.mtp[0] | |
| print(mtp(h, 0, x).size()) | |
| print(mtp(h[:, 0:1], 1, x[:, 0:1]).size()) | |