# llava_mistral_0531/conversions/neuron_projector_compile.py
# root
# feat
# 14bd491
import os
import torch
import torch_neuronx
# Trace the 768 -> 4096 linear projector with torch_neuronx (AWS Neuron) and
# save the traced TorchScript module next to the weights it was built from.
projector = torch.nn.Linear(768, 4096)
WEIGHT_ROOT = '/root/inf2_dir_0531/'
# Directory holding both the input weights and the compiled output.
WEIGHTS_DIR = os.path.join(WEIGHT_ROOT, "inf2_weights")
PROJECTOR_PATH = os.path.join(WEIGHTS_DIR, 'projector.pth')
# map_location="cpu": the checkpoint may have been saved from a CUDA device;
# loading onto CPU keeps this working on a host without CUDA.
projector_weight = torch.load(PROJECTOR_PATH, map_location="cpu")
projector.load_state_dict(projector_weight)
projector.eval()
# NOTE(review): the traced graph is specialized to this example input's
# shape/dtype. The last dim (768) must match the Linear's in_features; the
# meaning of 70 and 32 is not visible here (presumably batch/tokens) — confirm
# against the code that feeds features into the compiled projector.
example = torch.zeros((70, 32, 768), dtype=torch.float32)
neuron_projector = torch_neuronx.trace(projector, example)
filename = 'neuron_projector.pt'
# Output goes to <WEIGHT_ROOT>/inf2_weights/neuron_projector.pt, mirroring
# PROJECTOR_PATH. The previous argument order produced
# <WEIGHT_ROOT>/neuron_projector.pt/inf2_weights, under a directory that
# does not exist.
filepath = os.path.join(WEIGHTS_DIR, filename)
torch.jit.save(neuron_projector, filepath)