| |
| |
| import argparse |
| import logging |
| import time |
|
|
| import torch |
| from transformers import AutoImageProcessor, SwinForImageClassification |
| from datasets import load_dataset |
| import torch_neuronx |
|
|
# Configure the root logger at INFO so the timing and prediction messages
# emitted through ``logger.info`` below are not filtered out.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
def main():
    """Classify one image with Swin through ``torch.compile`` on the Neuron backend.

    Parses ``--model`` (a checkpoint name passed to ``from_pretrained``), loads
    the first image of the ``huggingface/cats-image`` test split, runs the
    model once eagerly, then wraps ``model.forward`` with
    ``torch.compile(backend="neuron", fullgraph=True)`` and runs it twice:
    a timed warm-up call and a timed measured call. The warm-up time, the
    run time and the predicted label are logged at INFO level.

    NOTE(review): the compiler log kept at the end of this file reports a
    ``torch.aten.fill.Tensor`` legalization failure inside
    ``modeling_swin.py``; presumably it is raised by the first compiled call
    below -- confirm.
    """
    parser = argparse.ArgumentParser(description="Swin on Neuron (full graph)")
    parser.add_argument("--model", default="microsoft/swin-tiny-patch4-window7-224")
    args = parser.parse_args()

    torch.manual_seed(42)
    torch.set_default_dtype(torch.float32)

    # Index the row before the column so only that single sample is fetched,
    # instead of selecting the whole "image" column and then taking item 0.
    dataset = load_dataset("huggingface/cats-image")
    image = dataset["test"][0]["image"]

    processor = AutoImageProcessor.from_pretrained(args.model)
    model = SwinForImageClassification.from_pretrained(
        args.model, torch_dtype=torch.float32, attn_implementation="eager"
    ).eval()

    inputs = processor(images=image, return_tensors="pt")

    # Eager forward pass before compilation; its result is discarded.
    with torch.no_grad():
        _ = model(**inputs).logits

    model.forward = torch.compile(model.forward, backend="neuron", fullgraph=True)

    # Intervals are measured with perf_counter(): it is monotonic, so the
    # reported durations cannot be skewed by system clock adjustments.
    warmup_start = time.perf_counter()
    with torch.no_grad():
        _ = model(**inputs)
    logger.info("Warmup: %.3f s", time.perf_counter() - warmup_start)

    # Second call through the compiled forward is the one reported as "Run".
    run_start = time.perf_counter()
    with torch.no_grad():
        logits = model(**inputs).logits
    run_time = time.perf_counter() - run_start

    # Map the highest-scoring class index to its label via the model config.
    predicted_class_idx = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_idx]

    logger.info("Run: %.3f s", run_time)
    logger.info("Predicted label: %s", predicted_label)
|
|
|
|
# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|
"""
Compiler output recorded for this script (presumably the failure it reproduces):

/usr/local/lib/python3.10/site-packages/transformers/models/swin/modeling_swin.py:611:0: error: failed to legalize operation 'torch.aten.fill.Tensor'
/usr/local/lib/python3.10/site-packages/transformers/models/swin/modeling_swin.py:662:0: note: called from
/usr/local/lib/python3.10/site-packages/transformers/models/swin/modeling_swin.py:736:0: note: called from
/usr/local/lib/python3.10/site-packages/transformers/modeling_layers.py:94:0: note: called from
/usr/local/lib/python3.10/site-packages/transformers/models/swin/modeling_swin.py:806:0: note: called from
/usr/local/lib/python3.10/site-packages/transformers/models/swin/modeling_swin.py:945:0: note: called from
/usr/local/lib/python3.10/site-packages/transformers/models/swin/modeling_swin.py:1139:0: note: called from
/usr/local/lib/python3.10/site-packages/transformers/models/swin/modeling_swin.py:611:0: note: see current operation: %1014 = "torch.aten.fill.Tensor"(%1013, %778) : (!torch.vtensor<[1,49,49,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[1,49,49,1],f32>
"""