| |
| |
| import argparse |
| import logging |
| import time |
|
|
| import torch |
| from torchvision import transforms |
| from transformers import AutoImageProcessor, MobileNetV2ForImageClassification |
| from datasets import load_dataset |
| import torch_neuronx |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Root logging configuration: emit INFO and above. basicConfig accepts the
# level by name as well as by numeric constant.
logging.basicConfig(level="INFO")
|
|
|
|
def _parse_args(argv=None):
    """Parse the command-line options for the MobileNetV2 Neuron demo.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with a ``model`` attribute (Hub model id).
    """
    parser = argparse.ArgumentParser(description="Run MobileNetV2 on Neuron")
    parser.add_argument(
        "--model",
        type=str,
        default="google/mobilenet_v2_1.0_224",
        help="MobileNetV2 model name on Hugging Face Hub",
    )
    return parser.parse_args(argv)


def _timed_forward(model, inputs):
    """Run one gradient-free forward pass and measure how long the call took.

    Args:
        model: Callable model; invoked as ``model(**inputs)``.
        inputs: Mapping of keyword arguments for the model call.

    Returns:
        A ``(outputs, seconds)`` tuple. ``outputs`` is whatever the model call
        returns; ``seconds`` is elapsed time from the monotonic
        ``time.perf_counter`` clock.
    """
    # NOTE(review): if the backend executes asynchronously, this measures the
    # call duration, not necessarily device completion -- confirm for Neuron.
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs, time.perf_counter() - start


def main(argv=None):
    """Classify a sample cat image with MobileNetV2 compiled for Neuron.

    Steps:
      1. Load the model and processor named by ``--model``.
      2. Run one eager (uncompiled) reference pass.
      3. Compile ``model.forward`` with ``torch.compile(backend="neuron")``.
      4. Time a warmup call and a second timed call.
      5. Log the timings, the predicted label, and the max absolute logit
         difference between the compiled and eager passes.

    Args:
        argv: Optional argument list forwarded to argparse. ``None`` (the
            default) reads ``sys.argv[1:]``, matching the original behavior.
    """
    args = _parse_args(argv)

    torch.set_default_dtype(torch.float32)
    torch.manual_seed(42)

    # Index the row first so only the single needed image is decoded,
    # rather than materializing the whole "image" column.
    dataset = load_dataset("huggingface/cats-image")
    image = dataset["test"][0]["image"]

    processor = AutoImageProcessor.from_pretrained(args.model)
    model = MobileNetV2ForImageClassification.from_pretrained(
        args.model, torch_dtype=torch.float32, attn_implementation="eager"
    ).eval()

    inputs = processor(images=image, return_tensors="pt")

    # Eager (uncompiled) pass. Kept as a reference to sanity-check the
    # compiled model's numerics below.
    with torch.no_grad():
        reference_logits = model(**inputs).logits

    model.forward = torch.compile(model.forward, backend="neuron", fullgraph=True)

    # torch.compile compiles lazily, so the first call also pays the
    # graph capture / compilation cost.
    _, warmup_time = _timed_forward(model, inputs)
    outputs, run_time = _timed_forward(model, inputs)
    logits = outputs.logits

    predicted_class_idx = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_idx]

    # .cpu() on both sides in case the compiled backend returns tensors on a
    # different device than the eager reference.
    max_abs_diff = (logits.cpu() - reference_logits.cpu()).abs().max().item()

    logger.info("Warmup: %.2f s, Run: %.4f s", warmup_time, run_time)
    logger.info("Predicted label: %s", predicted_label)
    logger.info("Max |compiled - eager| logit difference: %.3e", max_abs_diff)
|
|
if __name__ == "__main__":
    # Entry point when executed as a script (not on import). main() returns
    # None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())