Ubuntu
tests
5ee43e9
#!/usr/bin/env python3
# SAM encoder on Neuron – constant-shape, no lambda
import argparse
import logging
import time
import torch
from transformers import SamProcessor, SamModel
from PIL import Image
import torch_neuronx # guarantees Neuron backend
# Configure the root logger at INFO so the timing messages emitted by this
# script are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _timed_embeddings(model, inputs):
    """Run ``model.get_image_embeddings(**inputs)`` once and time it.

    The call is made under ``torch.no_grad()``.

    Args:
        model: Object exposing ``get_image_embeddings``.
        inputs: Mapping of keyword arguments forwarded to that method.

    Returns:
        A ``(embeddings, elapsed_seconds)`` tuple.
    """
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    with torch.no_grad():
        embeddings = model.get_image_embeddings(**inputs)
    return embeddings, time.perf_counter() - start


def main():
    """Compile the SAM image encoder with the ``neuron`` backend and time it.

    Parses ``--model`` (default ``facebook/sam-vit-base``), loads the
    processor and model, runs the encoder once eagerly, wraps
    ``get_image_embeddings`` with ``torch.compile`` and then logs the
    duration of a warmup call and of a second (benchmark) call, followed by
    the shape of the resulting embeddings.
    """
    parser = argparse.ArgumentParser(description="SAM encoder on Neuron (full graph)")
    parser.add_argument("--model", default="facebook/sam-vit-base")
    args = parser.parse_args()

    torch.manual_seed(42)
    torch.set_default_dtype(torch.float32)

    # load processor & model
    processor = SamProcessor.from_pretrained(args.model)
    model = SamModel.from_pretrained(args.model, attn_implementation="eager").eval()

    # dummy 224×224 RGB image
    dummy_image = Image.new("RGB", (224, 224), color="red")

    # constant-shape inputs (no points → encoder only)
    inputs = processor(images=dummy_image, return_tensors="pt")

    # eager pre-run before compilation (original intent: "lock shapes")
    with torch.no_grad():
        _ = model.get_image_embeddings(**inputs)

    # compile encoder forward (full graph)
    model.get_image_embeddings = torch.compile(
        model.get_image_embeddings, backend="neuron", fullgraph=True
    )

    # warmup: first call through the compiled callable
    _, warmup_seconds = _timed_embeddings(model, inputs)
    logger.info("Warmup: %.3f s", warmup_seconds)

    # benchmark: second call through the compiled callable
    embeddings, run_seconds = _timed_embeddings(model, inputs)
    logger.info("Run: %.3f s", run_seconds)
    logger.info("Embedding shape: %s", embeddings.shape)  # [1, 256, 64, 64]
# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()