leideng/QCFuse / srt /layers /sparse_pooler.py
leideng's picture
download
raw
3.42 kB
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from sglang.srt.model_executor.model_runner import ForwardBatch
@dataclass
class SparseEmbeddingOutput:
    """Result container returned by ``SparsePooler.forward``.

    Holds a single dense tensor. ``SparsePooler`` fills it with one row per
    sequence in the batch and one column per vocabulary entry.
    """

    # Expected shape [batch_size, vocab_size].
    embeddings: torch.Tensor
class SparsePooler(nn.Module):
    """Pool packed per-token hidden states into vocabulary-sized vectors.

    Steps performed by ``forward``:
      1. A ``hidden_size -> 1`` linear projection followed by ReLU gives every
         token a non-negative scalar weight.
      2. Each weight is written into its sequence's row, at the column equal
         to that token's id.
      3. When an id occurs more than once in a sequence, the largest weight
         is kept (max pooling).
      4. The result is a dense ``[batch_size, vocab_size]`` tensor. It is zero
         at every id that did not occur or that only received zero weight.

    Attributes:
        vocab_size: Number of output columns, copied from ``config.vocab_size``.
        sparse_linear: ``nn.Linear(config.hidden_size, 1)`` producing the
            per-token weight before ReLU.
        _weights_loaded: Becomes True once ``load_weights`` succeeds.
            ``forward`` raises until then.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        # Validate up front so a misconfigured model fails with a message
        # naming the config type, rather than a bare AttributeError later.
        if not hasattr(config, "vocab_size"):
            raise AttributeError(
                f"Config {type(config)} missing required 'vocab_size' attribute"
            )
        if not hasattr(config, "hidden_size"):
            raise AttributeError(
                f"Config {type(config)} missing required 'hidden_size' attribute"
            )

        self.vocab_size = config.vocab_size
        self.sparse_linear = nn.Linear(config.hidden_size, 1)
        # Guards forward() against running with randomly initialised weights.
        self._weights_loaded = False

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> SparseEmbeddingOutput:
        """Max-pool token weights into a per-sequence vocabulary vector.

        Args:
            hidden_states: Packed hidden states of every sequence, concatenated
                along dim 0, shape [total_tokens, hidden_size].
            forward_batch: Supplies ``extend_seq_lens`` (tokens per sequence,
                used as the repeat counts) and ``input_ids`` (the token id of
                each packed position).

        Returns:
            SparseEmbeddingOutput whose ``embeddings`` has shape
            [len(extend_seq_lens), vocab_size].

        Raises:
            ValueError: If ``load_weights`` has not been called.
        """
        if not self._weights_loaded:
            raise ValueError(
                "Sparse pooling weights not loaded. Call load_weights() first"
            )

        # [total_tokens, 1] -> [total_tokens]; ReLU makes every weight >= 0.
        scores = F.relu(self.sparse_linear(hidden_states)).squeeze(-1)

        seq_lens = forward_batch.extend_seq_lens
        num_seqs = len(seq_lens)

        # Output row of each packed token, e.g. seq_lens [2, 3] -> [0, 0, 1, 1, 1].
        row_of_token = torch.repeat_interleave(
            torch.arange(num_seqs, device=hidden_states.device),
            seq_lens,
        )

        pooled = torch.zeros(
            num_seqs,
            self.vocab_size,
            dtype=scores.dtype,
            device=scores.device,
        )

        # Turn (row, token_id) into an offset in the flattened buffer and keep
        # the largest score at each offset. The scores are >= 0 and the buffer
        # starts at zero, so including the initial zeros in the max
        # (scatter_reduce_'s default) cannot change the result.
        target = row_of_token * self.vocab_size + forward_batch.input_ids
        pooled.view(-1).scatter_reduce_(0, target, scores, reduce="amax")

        return SparseEmbeddingOutput(embeddings=pooled)

    def load_weights(self, state_dict: dict):
        """Load ``state_dict`` into ``sparse_linear`` and enable ``forward``.

        Intended to be called by the owning model. ``state_dict`` is presumably
        the Linear's ``weight``/``bias`` entries; verify against the caller.
        The flag is set only after ``load_state_dict`` returns, so a failed
        load leaves ``forward`` disabled.
        """
        self.sparse_linear.load_state_dict(state_dict)
        self._weights_loaded = True

Xet Storage Details

Size:
3.42 kB
·
Xet hash:
f7b95e132b2e3a2c9dbce8a91c4c82d80ce05fdc6e980de05ece4124efbd0a3e

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.