#!/usr/bin/env python3
# SAM encoder on Neuron – constant-shape, no lambda
"""Benchmark the SAM image encoder compiled with ``torch.compile(backend="neuron")``.

The script loads a SAM checkpoint and builds encoder inputs from a dummy image.
It runs ``SamModel.get_image_embeddings`` once eagerly as a reference, then
compiles it with ``fullgraph=True``. Finally it logs:

* the warmup (first compiled call) time,
* the mean time of ``--iters`` further compiled calls,
* the embedding shape, and
* the largest absolute difference between compiled and eager outputs.
"""
import argparse
import logging
import time
import torch
from transformers import SamProcessor, SamModel
from PIL import Image
import torch_neuronx  # noqa: F401  imported for side effects; presumably makes the "neuron" compile backend available — confirm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _parse_args():
    """Parse and validate command-line options."""
    parser = argparse.ArgumentParser(description="SAM encoder on Neuron (full graph)")
    parser.add_argument("--model", default="facebook/sam-vit-base")
    parser.add_argument(
        "--iters",
        type=int,
        default=1,
        help="number of timed compiled runs after warmup; the mean is reported",
    )
    args = parser.parse_args()
    if args.iters < 1:
        parser.error("--iters must be >= 1")
    return args


def _timed(fn, *args, **kwargs):
    """Call ``fn`` under ``torch.no_grad()``; return ``(result, elapsed_seconds)``.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than
    ``time.time`` so wall-clock adjustments cannot distort the measurement.
    """
    start = time.perf_counter()
    with torch.no_grad():
        result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def main():
    """Load SAM, compile its image encoder, and log timing and accuracy figures."""
    args = _parse_args()

    torch.manual_seed(42)
    torch.set_default_dtype(torch.float32)

    # load processor & model
    processor = SamProcessor.from_pretrained(args.model)
    model = SamModel.from_pretrained(args.model, attn_implementation="eager").eval()

    # dummy 224×224 RGB image
    dummy_image = Image.new("RGB", (224, 224), color="red")

    # Constant-shape inputs (no points -> encoder only). The processor output
    # also contains size metadata (original/reshaped sizes) that
    # get_image_embeddings does not accept as keyword arguments, so forwarding
    # the whole dict with ** fails. Pass only the pixel tensor.
    inputs = processor(images=dummy_image, return_tensors="pt")
    pixel_values = inputs["pixel_values"]

    # Eager reference run: confirms the uncompiled path works and provides a
    # baseline against which the compiled output is checked below.
    reference, _ = _timed(model.get_image_embeddings, pixel_values=pixel_values)

    # compile encoder forward (full graph)
    model.get_image_embeddings = torch.compile(
        model.get_image_embeddings, backend="neuron", fullgraph=True
    )

    # torch.compile compiles lazily, so the first call includes compile time.
    _, warmup_s = _timed(model.get_image_embeddings, pixel_values=pixel_values)
    logger.info("Warmup: %.3f s", warmup_s)

    # benchmark: mean over args.iters compiled calls (defaults to a single run)
    total_s = 0.0
    embeddings = None
    for _ in range(args.iters):
        embeddings, elapsed = _timed(model.get_image_embeddings, pixel_values=pixel_values)
        total_s += elapsed
    logger.info("Run: %.3f s", total_s / args.iters)
    logger.info("Embedding shape: %s", embeddings.shape)  # expected [1, 256, 64, 64] for the default checkpoint

    # Compare on CPU so the check works even if the compiled output lives on a
    # different device than the eager reference.
    max_diff = (embeddings.cpu() - reference.cpu()).abs().max().item()
    logger.info("Max |compiled - eager|: %.3e", max_diff)


if __name__ == "__main__":
    main()