STELLA EN 400M v5 - AOTInductor Optimized

Technical Optimizations

Compiler-Friendly Optimizations

  • Static shapes: All tensor shapes determined at initialization
  • Packed QKV projections: Single GEMM operation for Q, K, V computation
  • Boolean SDPA masks: Optimized attention masks for Flash Attention
  • Static RoPE cache: Precomputed cos/sin embeddings in FP32 precision
  • Static position buffers: Precomputed position and token type tensors
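The packed-QKV idea can be sketched as follows. This is an illustrative module with hypothetical names, not the model's actual classes: one GEMM produces Q, K, and V together, which is then split into three views.

```python
import torch
import torch.nn as nn

class PackedQKV(nn.Module):
    """Illustrative packed QKV projection (hypothetical, for exposition)."""
    def __init__(self, hidden: int, heads: int):
        super().__init__()
        self.heads = heads
        self.head_dim = hidden // heads
        # Single weight of shape (3 * hidden, hidden) replaces three
        # separate q/k/v projection matrices.
        self.qkv = nn.Linear(hidden, 3 * hidden)

    def forward(self, x: torch.Tensor):
        b, s, h = x.shape
        qkv = self.qkv(x)               # one GEMM: (b, s, 3 * hidden)
        q, k, v = qkv.chunk(3, dim=-1)  # split back into Q, K, V
        # Reshape each to (b, heads, s, head_dim) for SDPA.
        shape = (b, s, self.heads, self.head_dim)
        return tuple(t.view(shape).transpose(1, 2) for t in (q, k, v))

x = torch.randn(2, 16, 64)
q, k, v = PackedQKV(64, 8)(x)
print(q.shape)  # torch.Size([2, 8, 16, 8])
```

Fusing the three projections into one matrix multiply gives the compiler a single large GEMM to tune instead of three smaller ones.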

Mathematical Preservations

  • Original LayerNorm: Preserves trained normalization behavior
  • Original GELU activation: Maintains trained activation patterns
  • Bias terms preserved: Keeps all trained bias parameters
  • Dropout disabled: all dropout probabilities set to zero for inference
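Disabling dropout falls out of keeping the model in eval mode. A minimal check with a generic `nn.Dropout` (not the model's own layers):

```python
import torch
import torch.nn as nn

# nn.Dropout is a no-op when module.training is False:
# .eval() flips the flag, so inputs pass through unchanged.
drop = nn.Dropout(p=0.1)
x = torch.ones(4)

drop.eval()
y = drop(x)
print(torch.equal(y, x))  # True: identity in inference mode
```

In train mode the same layer would randomly zero elements and rescale the rest by 1 / (1 - p).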

Compilation Settings

  • torch.compile: mode="max-autotune", fullgraph=True, dynamic=False
  • Flash Attention: Enabled with Math attention fallback
  • CUDA Graphs: Enabled for consistent performance
  • Precision: FP32 for accuracy, BF16 supported for speed

Usage

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Load model and tokenizer
model = AutoModel.from_pretrained("stella_en_400M_v5", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("stella_en_400M_v5", trust_remote_code=True)

# Configure for optimal performance
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(True)         # Math fallback
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# Compile model
model = model.eval().cuda()
compiled_model = torch.compile(
    model, 
    mode="max-autotune", 
    fullgraph=True, 
    dynamic=False
)

# Inference with proper mean pooling
texts = ["Your text here", "Another text sample"]
with torch.inference_mode():
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt")
    inputs = {k: v.cuda() for k, v in inputs.items()}
    
    outputs = compiled_model(**inputs)
    hidden_states = outputs.last_hidden_state
    
    # Mean pooling
    mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
    embeddings = torch.sum(hidden_states * mask_expanded, 1) / torch.clamp(mask_expanded.sum(1), min=1e-9)
    embeddings = F.normalize(embeddings, dim=-1)
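Because the embeddings above are L2-normalized, cosine similarity reduces to a plain dot product. A self-contained sketch, with random vectors standing in for real model outputs:

```python
import torch
import torch.nn.functional as F

# Stand-in for model outputs: random vectors, normalized the same way
# as the embeddings produced above.
embeddings = F.normalize(torch.randn(2, 1024), dim=-1)

# For unit vectors, cosine similarity is just a matrix product.
similarity = embeddings @ embeddings.T   # (num_texts, num_texts)
print(similarity.diagonal())             # self-similarity is 1.0 per row
```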

Compatibility

  • PyTorch: 2.1+ with torch.compile support
  • Hardware: CUDA-capable GPUs (Ampere+ recommended for BF16)
  • Precision: FP32 (guaranteed accuracy), BF16 (optimal performance)
  • Backends: Optimized for Flash Attention and Triton kernels
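One way to use BF16 without converting the weights is autocast: parameters stay in FP32 while matmuls run in bfloat16. The sketch below runs on CPU for portability; on a CUDA device you would use `torch.autocast("cuda", dtype=torch.bfloat16)` around the compiled model call instead.

```python
import torch

# Autocast keeps the stored tensors in FP32 but executes matmuls
# in bfloat16 inside the context.
w = torch.randn(8, 8)   # FP32 "weight"
x = torch.randn(4, 8)   # FP32 input
with torch.autocast("cpu", dtype=torch.bfloat16):
    y = x @ w.T         # computed in bfloat16 under autocast
print(y.dtype)  # torch.bfloat16
```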

Forked from: NovaSearch/stella_en_400M_v5

Model size: 0.4B params · Tensor type: F32 (Safetensors)