import torch
from transformers import AutoTokenizer
from transformers_neuronx import MistralForSampling, GQA, NeuronConfig

# Model and compilation settings
model_name = './checkpoint-3000'  # local fine-tuned Mistral checkpoint
amp = 'bf16'                      # run in bfloat16
batch_size = 1
tp_degree = 8                     # tensor parallelism across 8 NeuronCores
n_positions = 8192                # maximum sequence length

# Shard the grouped-query-attention KV heads across NeuronCores
neuron_config = NeuronConfig(group_query_attention=GQA.SHARD_OVER_HEADS)

# Load the checkpoint and compile it for the Neuron device
model = MistralForSampling.from_pretrained(
    model_name,
    amp=amp,
    batch_size=batch_size,
    tp_degree=tp_degree,
    n_positions=n_positions,
    neuron_config=neuron_config,
)
model.to_neuron()
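
# A minimal sketch of running generation with the compiled model. It assumes
# the checkpoint directory also contains tokenizer files; the prompt and the
# sequence_length/top_k values are illustrative, not taken from the original.
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt = "[INST] What is your favourite condiment? [/INST]"
input_ids = tokenizer(prompt, return_tensors='pt').input_ids

# Autoregressive sampling on the Neuron device
with torch.inference_mode():
    generated = model.sample(input_ids, sequence_length=256, top_k=50)

print(tokenizer.decode(generated[0], skip_special_tokens=True))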