# File size: 632 Bytes
# 7c5440e
import logging
from typing import Union, List, Optional, Dict, Any, Literal
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from transformers_neuronx import MistralForSampling, GQA, NeuronConfig
import time
import math
# --- Neuron deployment configuration --------------------------------------
# Load a locally fine-tuned Mistral checkpoint and compile it for AWS
# Neuron (Inferentia/Trainium) sampling via transformers-neuronx.
model_name = './checkpoint-3000'  # local fine-tuned checkpoint directory
amp = 'bf16'                      # automatic mixed precision dtype for compilation
batch_size = 1
tp_degree = 8                     # presumably tensor-parallel degree (NeuronCores to shard over) — confirm against Neuron setup
n_positions = 8192                # maximum sequence length the compiled graph supports

# Shard grouped-query-attention KV heads across cores
# (GQA.SHARD_OVER_HEADS per transformers-neuronx NeuronConfig).
neuron_config = NeuronConfig(group_query_attention=GQA.SHARD_OVER_HEADS)

model = MistralForSampling.from_pretrained(
    model_name,
    amp=amp,
    batch_size=batch_size,
    tp_degree=tp_degree,
    n_positions=n_positions,
    neuron_config=neuron_config,
)

# Compile the model and move its weights onto the NeuronCores.
# NOTE: this is slow on first run (graph compilation); subsequent runs may
# hit the Neuron persistent cache.
model.to_neuron()