File size: 632 Bytes
7c5440e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import logging
from typing import Union, List, Optional, Dict, Any, Literal
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from transformers_neuronx import MistralForSampling, GQA, NeuronConfig
import time
import math


# Settings for loading a local Mistral checkpoint through transformers-neuronx.
# They stay as module-level names so that code importing this script can read
# them, together with the compiled ``model`` defined at the bottom.
model_name = './checkpoint-3000'  # relative path; presumably resolved against the CWD
amp = 'bf16'                       # forwarded unchanged as the ``amp`` keyword
batch_size = 1
tp_degree = 8                      # forwarded as ``tp_degree`` (tensor-parallel degree per library docs)
n_positions = 8192                 # forwarded as ``n_positions``
neuron_config = NeuronConfig(
    group_query_attention=GQA.SHARD_OVER_HEADS,
)


def _load_compiled_model():
    """Load ``model_name`` with the settings above and compile it via ``to_neuron()``.

    Returns the ``MistralForSampling`` object produced by ``from_pretrained``
    after ``to_neuron()`` has been called on it.
    """
    sampler = MistralForSampling.from_pretrained(
        model_name,
        amp=amp,
        batch_size=batch_size,
        tp_degree=tp_degree,
        n_positions=n_positions,
        neuron_config=neuron_config,
    )
    # The original script discarded to_neuron()'s return value and kept the
    # loaded object itself; preserve that by returning ``sampler``.
    sampler.to_neuron()
    return sampler


model = _load_compiled_model()