Upload 3 files
Browse files- example.sh +1 -0
- example_inference.py +105 -0
- nas_lcd_demo_mb_imagnet_zebra_torchscript.pt +3 -0
example.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python3 example_inference.py --torchscript nas_lcd_demo_mb_imagnet_zebra_torchscript.pt --image_path ENO_S2_C05_R1_IMAG1122.JPG
|
example_inference.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torchvision import datasets, transforms
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 9 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 10 |
+
HARDCODED_WNID_TO_IDX_101 = {'n01440764': 0, 'n01443537': 1, 'n01484850': 2, 'n01491361': 3, 'n01494475': 4, 'n01496331': 5, 'n01498041': 6, 'n01514668': 7, 'n01514859': 8, 'n01531178': 9, 'n01537544': 10, 'n01560419': 11, 'n01582220': 12, 'n01592084': 13, 'n01601694': 14, 'n01608432': 15, 'n01614925': 16, 'n01622779': 17, 'n01630670': 18, 'n01632458': 19, 'n01632777': 20, 'n01644900': 21, 'n01664065': 22, 'n01665541': 23, 'n01667114': 24, 'n01667778': 25, 'n01675722': 26, 'n01677366': 27, 'n01685808': 28, 'n01687978': 29, 'n01693334': 30, 'n01695060': 31, 'n01698640': 32, 'n01728572': 33, 'n01729322': 34, 'n01729977': 35, 'n01734418': 36, 'n01735189': 37, 'n01739381': 38, 'n01740131': 39, 'n01742172': 40, 'n01749939': 41, 'n01751748': 42, 'n01753488': 43, 'n01755581': 44, 'n01756291': 45, 'n01770081': 46, 'n01770393': 47, 'n01773157': 48, 'n01773549': 49, 'n01773797': 50, 'n01774384': 51, 'n01774750': 52, 'n01775062': 53, 'n01776313': 54, 'n01795545': 55, 'n01796340': 56, 'n01798484': 57, 'n01806143': 58, 'n01818515': 59, 'n01819313': 60, 'n01820546': 61, 'n01824575': 62, 'n01828970': 63, 'n01829413': 64, 'n01833805': 65, 'n01843383': 66, 'n01847000': 67, 'n01855672': 68, 'n01860187': 69, 'n01877812': 70, 'n01883070': 71, 'n01910747': 72, 'n01914609': 73, 'n01924916': 74, 'n01930112': 75, 'n01943899': 76, 'n01944390': 77, 'n01950731': 78, 'n01955084': 79, 'n01968897': 80, 'n01978287': 81, 'n01978455': 82, 'n01984695': 83, 'n01985128': 84, 'n01986214': 85, 'n02002556': 86, 'n02006656': 87, 'n02007558': 88, 'n02011460': 89, 'n02012849': 90, 'n02013706': 91, 'n02018207': 92, 'n02018795': 93, 'n02027492': 94, 'n02028035': 95, 'n02037110': 96, 'n02051845': 97, 'n02058221': 98, 'n02077923': 99, 'n02391049': 100}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def preprocess_image(image_path, input_image_size):
|
| 14 |
+
image = Image.open(image_path).convert("RGB")
|
| 15 |
+
transform_list = []
|
| 16 |
+
if image.size[0] != input_image_size or image.size[1] != input_image_size:
|
| 17 |
+
transform_list.extend([
|
| 18 |
+
transforms.Resize(input_image_size, interpolation=Image.BICUBIC),
|
| 19 |
+
transforms.CenterCrop(input_image_size),
|
| 20 |
+
])
|
| 21 |
+
transform_list.extend([
|
| 22 |
+
transforms.ToTensor(),
|
| 23 |
+
transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
|
| 24 |
+
])
|
| 25 |
+
transform = transforms.Compose(transform_list)
|
| 26 |
+
tensor = transform(image).unsqueeze(0)
|
| 27 |
+
return tensor
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_wnid_to_name(cls_map_path):
|
| 31 |
+
if not cls_map_path:
|
| 32 |
+
return None
|
| 33 |
+
wnid_to_name = {}
|
| 34 |
+
with open(cls_map_path, "r", encoding="utf-8") as handle:
|
| 35 |
+
for line in handle:
|
| 36 |
+
line = line.strip()
|
| 37 |
+
if not line:
|
| 38 |
+
continue
|
| 39 |
+
parts = line.split()
|
| 40 |
+
if len(parts) < 3:
|
| 41 |
+
continue
|
| 42 |
+
wnid = parts[0]
|
| 43 |
+
class_name = " ".join(parts[2:])
|
| 44 |
+
wnid_to_name[wnid] = class_name
|
| 45 |
+
return wnid_to_name if wnid_to_name else None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def load_idx_to_name_from_val_dir(val_dir, wnid_to_name):
|
| 49 |
+
if not val_dir:
|
| 50 |
+
idx_to_wnid = {idx: wnid for wnid, idx in HARDCODED_WNID_TO_IDX_101.items()}
|
| 51 |
+
else:
|
| 52 |
+
dataset = datasets.ImageFolder(val_dir)
|
| 53 |
+
idx_to_wnid = {v: k for k, v in dataset.class_to_idx.items()}
|
| 54 |
+
idx_to_name = {}
|
| 55 |
+
for idx, wnid in idx_to_wnid.items():
|
| 56 |
+
idx_to_name[idx] = wnid_to_name.get(wnid, wnid) if wnid_to_name else wnid
|
| 57 |
+
return idx_to_name if idx_to_name else None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def topk_from_logits(logits, idx_to_class, k=5):
|
| 61 |
+
probs = torch.softmax(logits, dim=1)
|
| 62 |
+
values, indices = torch.topk(probs, k=k, dim=1)
|
| 63 |
+
values = values.squeeze(0).tolist()
|
| 64 |
+
indices = indices.squeeze(0).tolist()
|
| 65 |
+
results = []
|
| 66 |
+
for score, idx in zip(values, indices):
|
| 67 |
+
cls_name = idx_to_class.get(idx, str(idx)) if idx_to_class else str(idx)
|
| 68 |
+
results.append((idx, cls_name, score))
|
| 69 |
+
return results
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def main():
|
| 73 |
+
parser = argparse.ArgumentParser(description="TorchScript single-image inference.")
|
| 74 |
+
parser.add_argument("--torchscript", type=str, required=True,
|
| 75 |
+
help="Path to TorchScript .pt file")
|
| 76 |
+
parser.add_argument("--image_path", type=str, required=True,
|
| 77 |
+
help="Path to input image")
|
| 78 |
+
parser.add_argument("--input_image_size", type=int, default=224)
|
| 79 |
+
parser.add_argument("--val_dir", type=str,
|
| 80 |
+
default=None,
|
| 81 |
+
help="Val dir to derive class index -> wnid mapping")
|
| 82 |
+
parser.add_argument("--cls_map_path", type=str,
|
| 83 |
+
default="/scratch/general/vast/j.yan/nas_tvm/cls_map.txt",
|
| 84 |
+
help="Path to cls_map.txt for wnid -> class name mapping")
|
| 85 |
+
parser.add_argument("--topk", type=int, default=1)
|
| 86 |
+
args = parser.parse_args()
|
| 87 |
+
|
| 88 |
+
model = torch.jit.load(args.torchscript, map_location="cpu")
|
| 89 |
+
model.eval()
|
| 90 |
+
|
| 91 |
+
input_tensor = preprocess_image(args.image_path, args.input_image_size)
|
| 92 |
+
wnid_to_name = load_wnid_to_name(args.cls_map_path)
|
| 93 |
+
idx_to_class = load_idx_to_name_from_val_dir(args.val_dir, wnid_to_name)
|
| 94 |
+
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
logits = model(input_tensor)
|
| 97 |
+
|
| 98 |
+
print("[torchscript] top-{}:".format(args.topk))
|
| 99 |
+
for _, cls_name, score in topk_from_logits(logits, idx_to_class, k=args.topk):
|
| 100 |
+
print(f" class={cls_name} prob={score:.6f}")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
main()
|
| 105 |
+
|
nas_lcd_demo_mb_imagnet_zebra_torchscript.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:92f4ef6205e1e440b725be9a4dc4fef36230a733a0a079627440c66c97cbbedb
|
| 3 |
+
size 5744755
|