"""Compile the 768 -> 4096 projector with torch_neuronx and save it.

Loads the projector weights from ``<WEIGHT_ROOT>/inf2_weights/projector.pth``
into an ``nn.Linear(768, 4096)``, traces it with ``torch_neuronx.trace`` on a
fixed example input, and writes the resulting TorchScript module next to the
source weights as ``<WEIGHT_ROOT>/inf2_weights/neuron_projector.pt``.
"""
import os

import torch
import torch_neuronx


projector = torch.nn.Linear(768, 4096)
WEIGHT_ROOT = '/root/inf2_dir_0531/'
PROJECTOR_PATH = os.path.join(WEIGHT_ROOT, "inf2_weights", 'projector.pth')


# map_location="cpu": the checkpoint may have been saved from a GPU process;
# without remapping, torch.load tries to restore tensors onto a CUDA device,
# which fails on a CPU-only host.
projector_weight = torch.load(PROJECTOR_PATH, map_location="cpu")
# load_state_dict is strict by default, so the checkpoint must contain exactly
# the nn.Linear entries "weight" (4096 x 768) and "bias" (4096,).
projector.load_state_dict(projector_weight)


projector.eval()
# NOTE(review): torch_neuronx compiles for the example input's shape, so the
# saved model presumably only accepts (70, 32, 768) float32 inputs -- confirm
# this matches the inference-time caller.
example = torch.zeros((70, 32, 768), dtype=torch.float32)
neuron_projector = torch_neuronx.trace(projector, example)


filename = 'neuron_projector.pt'
# Save next to the source weights: <WEIGHT_ROOT>/inf2_weights/neuron_projector.pt
# (directory components first, file name last).
filepath = os.path.join(WEIGHT_ROOT, "inf2_weights", filename)
torch.jit.save(neuron_projector, filepath)