# inf2_dir/app/tmp/mistral_standalone.py
# Author: root — commit 7c5440e ("feat: update")
import logging
from typing import Union, List, Optional, Dict, Any, Literal
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from transformers_neuronx import MistralForSampling, GQA, NeuronConfig
import time
import math
# --- Model configuration -------------------------------------------------
# Local fine-tuned checkpoint directory and Neuron compilation settings.
model_name = './checkpoint-3000'
amp = 'bf16'          # mixed-precision dtype for inference
batch_size = 1
tp_degree = 8         # tensor-parallel degree across NeuronCores
n_positions = 8192    # maximum sequence length to compile for

# Shard grouped-query-attention KV heads across NeuronCores.
neuron_config = NeuronConfig(group_query_attention=GQA.SHARD_OVER_HEADS)

# Assemble the loader kwargs in one place, then build and compile the model.
_load_kwargs = dict(
    amp=amp,
    batch_size=batch_size,
    tp_degree=tp_degree,
    n_positions=n_positions,
    neuron_config=neuron_config,
)
model = MistralForSampling.from_pretrained(model_name, **_load_kwargs)
model.to_neuron()  # trace/compile the graph onto the Neuron device