# Quickstart

🤗 Optimum Neuron makes AWS accelerator adoption seamless for Hugging Face users with **drop-in replacements** for standard training and inference components.

***📋 Need to set up your environment first?** Check out our [Getting Started on EC2](getting-started-on-ec2) page for complete installation and AWS setup instructions.*

**Key Features:**

- 🔄 **Drop-in replacement** for standard Transformers training and inference
- ⚡ **Distributed training** support with minimal code changes
- 🎯 **Optimized models** for AWS accelerators
- 🚀 **Production-ready** inference with compiled models
## Training

Training on AWS Trainium requires only minimal changes to your existing code; simply swap in Optimum Neuron's drop-in replacements:
```python
import torch
import torch_xla.runtime as xr

from datasets import load_dataset
from transformers import AutoTokenizer

# Optimum Neuron's drop-in replacements for standard training components
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainingArguments
from optimum.neuron.models.training import NeuronModelForCausalLM


def format_dolly_dataset(example):
    """Format Dolly dataset into instruction-following format."""
    instruction = f"### Instruction\n{example['instruction']}"
    context = f"### Context\n{example['context']}" if example["context"] else None
    response = f"### Answer\n{example['response']}"

    # Combine all parts with double newlines
    parts = [instruction, context, response]
    return "\n\n".join(part for part in parts if part)


def main():
    # Load instruction-following dataset
    dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

    # Model configuration
    model_id = "Qwen/Qwen3-1.7B"
    output_dir = "qwen3-1.7b-finetuned"

    # Setup tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # Configure training for Trainium
    training_args = NeuronTrainingArguments(
        learning_rate=1e-4,
        tensor_parallel_size=8,  # Split model across 8 accelerators
        per_device_train_batch_size=1,  # Batch size per device
        gradient_accumulation_steps=8,
        logging_steps=1,
        output_dir=output_dir,
    )

    # Load model optimized for Trainium
    model = NeuronModelForCausalLM.from_pretrained(
        model_id,
        training_args.trn_config,
        dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",  # Enable flash attention
    )

    # Setup supervised fine-tuning
    sft_config = NeuronSFTConfig(
        max_seq_length=2048,
        packing=True,  # Pack multiple samples for efficiency
        **training_args.to_dict(),
    )

    # Initialize trainer and start training
    trainer = NeuronSFTTrainer(
        model=model,
        args=sft_config,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=format_dolly_dataset,
    )
    trainer.train()

    # Share your model with the community
    trainer.push_to_hub(
        commit_message="Fine-tuned on Databricks Dolly dataset",
        blocking=True,
        model_name=output_dir,
    )

    if xr.local_ordinal() == 0:
        print(f"Training complete! Model saved to {output_dir}")


if __name__ == "__main__":
    main()
```
This example demonstrates supervised fine-tuning on the [Databricks Dolly dataset](https://huggingface.co/datasets/databricks/databricks-dolly-15k) using `NeuronSFTTrainer` and `NeuronModelForCausalLM`, the Trainium-optimized versions of the standard Transformers components.
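To make the "drop-in" claim concrete, here is a rough sketch of what the equivalent plain Transformers/TRL setup looks like. This snippet is not part of the example above, and the exact `SFTConfig` argument names depend on your TRL version, so treat it as illustrative:

```python
# Illustrative GPU equivalent using plain TRL (hypothetical, for comparison):
# essentially only the imports and argument classes change versus the
# Trainium script above.
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B")
sft_config = SFTConfig(
    output_dir="qwen3-1.7b-finetuned",
    packing=True,  # Same packing behavior as NeuronSFTConfig
)
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset,  # Same dataset as above
    formatting_func=format_dolly_dataset,  # Same formatting function
)
```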
### Running Training

**Compilation** (optional, but recommended: `neuron_parallel_compile` pre-populates the compilation cache so the first training run does not stall on graph compilation):

```bash
NEURON_CC_FLAGS="--model-type transformer" neuron_parallel_compile torchrun --nproc_per_node 32 sft_finetune_qwen3.py
```

**Training:**

```bash
NEURON_CC_FLAGS="--model-type transformer" torchrun --nproc_per_node 32 sft_finetune_qwen3.py
```
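These commands assume an instance with 32 Neuron cores, such as a trn1.32xlarge. On smaller hardware, match `--nproc_per_node` (and `tensor_parallel_size` in the script) to the available cores; a hypothetical launch for a trn1.2xlarge, which has 2 Neuron cores:

```bash
# Hypothetical: trn1.2xlarge exposes 2 Neuron cores, so launch 2 processes
# and set tensor_parallel_size=2 in NeuronTrainingArguments to match.
NEURON_CC_FLAGS="--model-type transformer" torchrun --nproc_per_node 2 sft_finetune_qwen3.py
```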
## Inference

Optimized inference requires two steps: **export** your model to Neuron format, then **run** it with the `NeuronModelForXXX` classes.

### 1. Export Your Model

```bash
optimum-cli export neuron \
  --model distilbert-base-uncased-finetuned-sst-2-english \
  --batch_size 1 \
  --sequence_length 32 \
  --auto_cast matmul \
  --auto_cast_type bf16 \
  distilbert_base_uncased_finetuned_sst2_english_neuron/
```

This exports the model with optimized settings: static shapes (`batch_size=1`, `sequence_length=32`) and BF16 precision for `matmul` operations. Check out the [exporter guide](https://huggingface.co/docs/optimum-neuron/guides/export_model) for more compilation options.
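If you prefer to stay in Python, the export can also be triggered at load time with `export=True`; a minimal sketch, assuming the shapes and compiler options are accepted as keyword arguments as in the CLI:

```python
from optimum.neuron import NeuronModelForSequenceClassification

# Compile at load time; batch_size/sequence_length become the static
# input shapes baked into the Neuron artifact.
model = NeuronModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english",
    export=True,
    batch_size=1,
    sequence_length=32,
    auto_cast="matmul",
    auto_cast_type="bf16",
)
model.save_pretrained("distilbert_base_uncased_finetuned_sst2_english_neuron/")
```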
### 2. Run Inference

```python
from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForSequenceClassification

# Load the compiled Neuron model
model = NeuronModelForSequenceClassification.from_pretrained(
    "distilbert_base_uncased_finetuned_sst2_english_neuron"
)

# Setup tokenizer (same as original model)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

# Run inference
inputs = tokenizer("Hamilton is considered to be the best musical of past years.", return_tensors="pt")
logits = model(**inputs).logits
print(model.config.id2label[logits.argmax().item()])
# 'POSITIVE'
```
The `NeuronModelForXXX` classes work as drop-in replacements for their `AutoModelForXXX` counterparts, making migration seamless.
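Optimum Neuron also mirrors the Transformers pipelines API, which can be an even quicker way to run a compiled model; a sketch, assuming a `pipeline` helper exported from `optimum.neuron`:

```python
from optimum.neuron import pipeline

# Build a text-classification pipeline directly from the compiled model
classifier = pipeline(
    "text-classification",
    model="distilbert_base_uncased_finetuned_sst2_english_neuron",
)
print(classifier("Hamilton is considered to be the best musical of past years."))
# Expected output shape: [{'label': 'POSITIVE', 'score': <float>}]
```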
## Next Steps

Ready to dive deeper? Check out our comprehensive guides:

- 📖 **[Getting Started](getting-started)** - Complete setup and installation
- 🏋️ **[Training Tutorials](training_tutorials/notebooks)** - End-to-end training examples
- 🔧 **[Export Guide](guides/export_model)** - Advanced model compilation options