STELLA EN 400M v5 - AOTInductor Optimized
Technical Optimizations
Compiler-Friendly Optimizations
- Static shapes: All tensor shapes determined at initialization
- Packed QKV projections: Single GEMM operation for Q, K, V computation
- Boolean SDPA masks: Optimized attention masks for Flash Attention
- Static RoPE cache: Precomputed cos/sin embeddings in FP32 precision
- Static position buffers: Precomputed position and token type tensors
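The packed-QKV point above can be sketched as follows. This is a minimal illustration of the idea (names and dimensions are hypothetical, not the model's actual code): one fused linear layer produces Q, K, and V in a single GEMM instead of three separate projections.

```python
import torch
import torch.nn as nn

hidden, n_heads = 64, 4
head_dim = hidden // n_heads

# Single weight matrix covering Q, K, and V -> one GEMM per forward pass
qkv_proj = nn.Linear(hidden, 3 * hidden)

x = torch.randn(2, 10, hidden)   # (batch, seq, hidden)
qkv = qkv_proj(x)                # (batch, seq, 3 * hidden) in one matmul
q, k, v = qkv.chunk(3, dim=-1)   # split back into the three projections

# Reshape one of them for multi-head attention: (batch, heads, seq, head_dim)
q = q.view(2, 10, n_heads, head_dim).transpose(1, 2)
```

Fusing the three projections keeps the compiler's graph smaller and feeds the GPU one large, efficient matmul instead of three small ones.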
Mathematical Preservations
- Original LayerNorm: Preserves trained normalization behavior
- Original GELU activation: Maintains trained activation patterns
- Bias terms preserved: Keeps all trained bias parameters
- Dropout disabled: Zero dropout for inference mode only
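The dropout point is easy to verify in isolation: in eval mode, `nn.Dropout` is the identity, so disabling it for inference changes nothing mathematically while keeping the trained weights intact. A minimal check:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.1)
drop.eval()  # inference mode: dropout becomes a no-op

x = torch.randn(4, 8)
out = drop(x)  # no elements are zeroed, no rescaling is applied
```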
Compilation Settings
- torch.compile:
mode="max-autotune", fullgraph=True, dynamic=False
- Flash Attention: Enabled with Math attention fallback
- CUDA Graphs: Enabled for consistent performance
- Precision: FP32 for accuracy, BF16 supported for speed
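For the BF16 path, one option is autocast around the forward pass. The sketch below demonstrates the mechanism on a stand-in `nn.Linear` (CPU, so it runs anywhere); with the actual model you would use `device_type="cuda"` around the compiled forward call:

```python
import torch
import torch.nn as nn

layer = nn.Linear(64, 64)  # hypothetical stand-in for a transformer block
x = torch.randn(8, 64)

# Autocast downcasts matmul-heavy ops to bfloat16 while weights stay FP32,
# trading a small amount of accuracy for throughput.
with torch.inference_mode(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)
```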
Usage
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Load model and tokenizer
model = AutoModel.from_pretrained("stella_en_400M_v5", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("stella_en_400M_v5", trust_remote_code=True)
# Configure for optimal performance
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')
# Compile model
model = model.eval().cuda()
compiled_model = torch.compile(
    model,
    mode="max-autotune",
    fullgraph=True,
    dynamic=False,
)
# Inference with proper mean pooling
texts = ["Your text here", "Another text sample"]
with torch.inference_mode():
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
    inputs = {k: v.cuda() for k, v in inputs.items()}
    outputs = compiled_model(**inputs)
    hidden_states = outputs.last_hidden_state

    # Mean pooling over non-padding tokens
    mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
    embeddings = torch.sum(hidden_states * mask_expanded, 1) / torch.clamp(mask_expanded.sum(1), min=1e-9)
    embeddings = F.normalize(embeddings, dim=-1)
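Because the pooled embeddings are L2-normalized, cosine similarity between texts reduces to a plain matrix product. The sketch below uses a random stand-in for the model output to show the pattern:

```python
import torch
import torch.nn.functional as F

# Random stand-in for the normalized sentence embeddings above
embeddings = F.normalize(torch.randn(2, 1024), dim=-1)

# Pairwise cosine similarities; the diagonal is each text's self-similarity
scores = embeddings @ embeddings.T
```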
Compatibility
- PyTorch: 2.1+ with torch.compile support
- Hardware: CUDA-capable GPUs (Ampere+ recommended for BF16)
- Precision: FP32 (guaranteed accuracy), BF16 (optimal performance)
- Backends: Optimized for Flash Attention and Triton kernels
Forked from: NovaSearch/stella_en_400M_v5
Evaluation results
All scores are self-reported on MTEB test sets.

| Task | Metric | Score |
| --- | --- | --- |
| AmazonCounterfactualClassification (en) | accuracy | 92.358 |
| AmazonCounterfactualClassification (en) | ap | 70.813 |
| AmazonCounterfactualClassification (en) | ap_weighted | 70.813 |
| AmazonCounterfactualClassification (en) | f1 | 88.951 |
| AmazonCounterfactualClassification (en) | f1_weighted | 92.686 |
| AmazonCounterfactualClassification (en) | main_score | 92.358 |
| AmazonPolarityClassification | accuracy | 97.195 |
| AmazonPolarityClassification | ap | 96.082 |
| AmazonPolarityClassification | ap_weighted | 96.082 |
| AmazonPolarityClassification | f1 | 97.194 |