---
license: apache-2.0
tags:
- neuron
- aws-inferentia
- inf2
- moe
- pre-compiled
- neuronx-distributed-inference
base_model: arcee-ai/Trinity-Nano-Preview
library_name: neuronx-distributed-inference
---

# Trinity-Nano Pre-Compiled for AWS Inferentia2 (TP=1)

Pre-compiled and pre-sharded [Trinity-Nano-Preview](https://huggingface.co/arcee-ai/Trinity-Nano-Preview) (~6B total, ~1B active MoE) for AWS Neuron SDK 2.28, ready to load on **inf2.xlarge** (16 GB system RAM) or any larger Inferentia2/Trainium instance.

## Why Pre-Sharded?

The standard NxDI load path loads the full HuggingFace checkpoint (~12 GB in bf16) into CPU RAM to convert and shard the weights. On inf2.xlarge (16 GB system RAM), RSS climbs past 15 GB and the process is OOM-killed. With pre-sharded weights, NxDI skips that step entirely and reads directly from the per-rank sharded files, peaking at only **1.4 GB RSS** (12.6% of system RAM).

## Contents

| File | Size | Description |
|------|------|-------------|
| `model.pt` | 49 MB | Compiled Neuron NEFF graphs |
| `neuron_config.json` | 9 KB | NxDI configuration (TP=1, BS=1, seq_len=2048, bf16) |
| `weights/tp0_sharded_checkpoint.safetensors` | 12 GB | Pre-sharded model weights for rank 0 |

## Performance

Measured on inf2.xlarge (TP=1 on a single NeuronCore, 16 GB system RAM):

| Metric | Value |
|--------|-------|
| Time to first token (TTFT) | 706 ms |
| Token generation latency (TKG, per token) | 9.0 ms |
| Throughput | 112 tok/s |
| Load time | 18.4 s |
| Peak RSS | 1.39 GB |

## Quick Start

### Prerequisites

- An AWS instance with Inferentia2: inf2.xlarge, inf2.8xlarge, or larger
- [Deep Learning AMI Neuron (Ubuntu 24.04) 20260227](https://aws.amazon.com/marketplace/) (SDK 2.28)
- Activate the pre-installed venv: `source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate`

### 1. Clone the model implementation

The Trinity Neuron implementation is not yet merged into the main NxDI repo.
Use the contrib branch from the fork:

```bash
git clone --branch contrib/trinity-model --single-branch \
  https://github.com/jimburtoft/neuronx-distributed-inference.git nxdi-trinity
```

### 2. Download this artifact and the base model config/tokenizer

```python
from huggingface_hub import snapshot_download

# Download the pre-compiled artifact (model.pt + sharded weights)
snapshot_download(
    "jburtoft/Trinity-Nano-Neuron-TP1",
    local_dir="/home/ubuntu/Trinity-Nano-Neuron-TP1",
)

# Download config + tokenizer only (no model weights needed)
snapshot_download(
    "arcee-ai/Trinity-Nano-Preview",
    local_dir="/home/ubuntu/Trinity-Nano-Preview",
    ignore_patterns=["*.safetensors", "*.bin", "*.pt", "*.gguf"],
)
```

### 3. Load and run inference

```python
import sys

import torch
from transformers import AutoTokenizer
from neuronx_distributed_inference.models.config import MoENeuronConfig

# Point to the Trinity implementation from the cloned repo
sys.path.insert(0, "/home/ubuntu/nxdi-trinity/contrib/models/Trinity/src")
from modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig

# Build model with save_sharded_checkpoint=True (must match compilation)
neuron_config = MoENeuronConfig(
    tp_degree=1,
    batch_size=1,
    seq_len=2048,
    torch_dtype=torch.bfloat16,
    save_sharded_checkpoint=True,
)
config = TrinityInferenceConfig.from_pretrained(
    "/home/ubuntu/Trinity-Nano-Preview",
    neuron_config=neuron_config,
)
model = NeuronTrinityForCausalLM("/home/ubuntu/Trinity-Nano-Preview", config)
model.load("/home/ubuntu/Trinity-Nano-Neuron-TP1")

# Tokenize
tokenizer = AutoTokenizer.from_pretrained(
    "/home/ubuntu/Trinity-Nano-Preview", trust_remote_code=True
)
prompt = "Hello, how are you today?"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids

# Generate: first run the full prompt once (context encoding)
model.reset()
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)
seq_ids = torch.arange(1)
with torch.no_grad():
    outputs = model(input_ids, position_ids=position_ids, seq_ids=seq_ids)
logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
next_token = torch.argmax(logits[:, -1, :], dim=-1)
print(f"Prompt: {prompt}")
print(f"Next token: {tokenizer.decode(next_token)}")

# Autoregressive generation: feed one token at a time at the next position
generated = [next_token.unsqueeze(0)]
for i in range(31):
    pos = torch.tensor([[input_ids.shape[1] + i]])
    with torch.no_grad():
        outputs = model(generated[-1], position_ids=pos, seq_ids=seq_ids)
    logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
    next_token = torch.argmax(logits[:, -1, :], dim=-1)
    generated.append(next_token.unsqueeze(0))

text = tokenizer.decode(torch.cat(generated, dim=1)[0], skip_special_tokens=True)
print(f"Generated: {text}")
```

## Compilation Details

| Parameter | Value |
|-----------|-------|
| SDK | 2.28 (NxDI 0.8.16251, neuronx-cc 2.23.6484, torch-neuronx 2.9.0.2.12) |
| TP degree | 1 |
| Batch size | 1 |
| Sequence length | 2048 |
| Dtype | bfloat16 |
| `save_sharded_checkpoint` | True |

## Compiling Your Own

Compilation reads and shards the full checkpoint in CPU RAM, so to compile for a different configuration (e.g., TP=2, BS=4) you need an instance with more system memory than inf2.xlarge, such as inf2.8xlarge or trn2.3xlarge:

```python
import sys

import torch
from neuronx_distributed_inference.models.config import MoENeuronConfig

sys.path.insert(0, "/path/to/nxdi-trinity/contrib/models/Trinity/src")
from modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig

neuron_config = MoENeuronConfig(
    tp_degree=1,       # Adjust as needed
    batch_size=1,      # Adjust as needed
    seq_len=2048,      # Adjust as needed
    torch_dtype=torch.bfloat16,
    save_sharded_checkpoint=True,  # Required for pre-sharded deployment
)
config = TrinityInferenceConfig.from_pretrained(
    "/path/to/Trinity-Nano-Preview", neuron_config=neuron_config
)
model = NeuronTrinityForCausalLM("/path/to/Trinity-Nano-Preview", config)
model.compile("/path/to/compiled-output")
# Output: model.pt, neuron_config.json, weights/tp{rank}_sharded_checkpoint.safetensors
```

## Base Model

- **Model**: [arcee-ai/Trinity-Nano-Preview](https://huggingface.co/arcee-ai/Trinity-Nano-Preview)
- **Architecture**: MoE (128 experts, top-8 active, 1 shared expert)
- **Parameters**: ~6B total, ~1B active per token
- **License**: Apache 2.0

## Model Implementation

The NeuronX Distributed Inference implementation for Trinity is available at [github.com/jimburtoft/neuronx-distributed-inference](https://github.com/jimburtoft/neuronx-distributed-inference/tree/contrib/trinity-model/contrib/models/Trinity) (branch: `contrib/trinity-model`).

This implementation supports all three Trinity model sizes (Nano, Mini, Large) with a single unified `modeling_trinity.py`.
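The architecture line under Base Model (128 routed experts, top-8 active, plus 1 shared expert) helps explain why only ~1B of the ~6B parameters are active per token: each token passes through the router, the shared expert, and just 8 of the 128 routed experts. The pure-Python sketch below illustrates that routing pattern with toy sizes. It is a generic top-k MoE illustration, not Trinity's or NxDI's implementation. The helper names and the toy hidden size are hypothetical, each expert is a single random linear map rather than an MLP, and normalizing the gates with a softmax over the selected logits is an assumption (Trinity's actual router may score and normalize differently).

```python
import math
import random

NUM_EXPERTS = 128  # routed experts, as listed under "Base Model"
TOP_K = 8          # routed experts active per token
HIDDEN = 4         # hypothetical toy hidden size, for illustration only


def make_linear(rng, dim):
    """Return a random dim x dim weight matrix (toy stand-in for an expert MLP)."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(dim)]


def matvec(weight, vec):
    """Multiply a matrix (nested lists) by a vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in weight]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def moe_forward(token, router, experts, shared_expert, top_k=TOP_K):
    """Route one token to its top-k experts plus the always-on shared expert.

    Returns (output vector, selected expert indices, gate weights).
    """
    # One router logit per routed expert.
    logits = matvec(router, token)
    # Keep the top-k experts by logit, highest first.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Assumption: softmax over the selected logits only, so the gates sum to 1.
    gates = softmax([logits[i] for i in top])
    # The shared expert runs for every token.
    out = matvec(shared_expert, token)
    # Only the selected experts are evaluated; the other 120 are skipped.
    for gate, idx in zip(gates, top):
        expert_out = matvec(experts[idx], token)
        out = [o + gate * e for o, e in zip(out, expert_out)]
    return out, top, gates


rng = random.Random(0)
router = [[rng.gauss(0, 1) for _ in range(HIDDEN)] for _ in range(NUM_EXPERTS)]
experts = [make_linear(rng, HIDDEN) for _ in range(NUM_EXPERTS)]
shared_expert = make_linear(rng, HIDDEN)

token = [rng.gauss(0, 1) for _ in range(HIDDEN)]
output, selected, gate_weights = moe_forward(token, router, experts, shared_expert)
print(len(output), len(selected), round(sum(gate_weights), 6))
```

Only the 8 selected expert matrices are read for a given token, which is where the compute saving comes from. All 128 experts must still be stored, which is why the full bf16 checkpoint is ~12 GB even though per-token compute tracks the ~1B active parameters.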