"""Compile and load a fine-tuned Mistral checkpoint for token sampling on
AWS Neuron (Inferentia/Trainium) cores via ``transformers-neuronx``.

Running this module compiles the model graph and shards the weights across
NeuronCores; the loaded model is exposed as the module-level name ``model``.
"""

# NOTE(review): logging, typing, torch, F, AutoTokenizer, time, and math are
# not used in this file as shown — presumably used by code elsewhere (or in a
# later chunk of this script). Left in place; confirm before removing.
import logging
import math
import time
from typing import Union, List, Optional, Dict, Any, Literal

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from transformers_neuronx import MistralForSampling, GQA, NeuronConfig

# --- Compilation / sharding configuration ---------------------------------
model_name = './checkpoint-3000'  # local fine-tuned checkpoint directory
amp = 'bf16'                      # run weights/activations in bfloat16
batch_size = 1                    # fixed batch size the graph is compiled for
tp_degree = 8                     # tensor-parallel degree (NeuronCores to shard across)
n_positions = 8192                # maximum sequence length to compile for

# Grouped-query attention: shard the KV heads across the tensor-parallel
# ranks rather than replicating them on every core.
neuron_config = NeuronConfig(group_query_attention=GQA.SHARD_OVER_HEADS)

model = MistralForSampling.from_pretrained(
    model_name,
    amp=amp,
    batch_size=batch_size,
    tp_degree=tp_degree,
    n_positions=n_positions,
    neuron_config=neuron_config,
)
# Compile the graph and move the sharded weights onto the NeuronCores.
# The first call is slow (ahead-of-time compilation); later runs may hit
# the Neuron compiler cache.
model.to_neuron()