# Homework2_NN / inference.py
# Author: kevinkyi — commit 0081e70 ("Add inference.py", verified)
import torch, json
import torchvision
from torchvision import transforms, models
from PIL import Image
def build_model(arch, dropout, width, freeze_backbone, num_classes=2):
    """Construct an untrained classifier whose layout matches the saved checkpoint.

    The module structure (attribute names and ``nn.Sequential`` indices) must
    stay exactly as written here. The weights are loaded afterwards by
    ``load_state_dict(..., strict=True)``, and the state-dict keys are derived
    from this structure.

    Args:
        arch: One of ``"smallcnn"``, ``"resnet18"`` or ``"mobilenet_v3_small"``.
        dropout: Dropout probability for the classification head. It is
            ignored for ``"mobilenet_v3_small"``, where only the final
            ``nn.Linear`` of the stock classifier is replaced.
        width: Base channel count for ``"smallcnn"``. It is ignored by the
            other architectures.
        freeze_backbone: Accepted so the signature matches the training
            config. It is not used by this function.
        num_classes: Number of output logits.

    Returns:
        An ``nn.Module`` with freshly initialised (non-pretrained) weights.

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """
    import torch.nn as nn

    if arch == "smallcnn":

        class SmallCNN(nn.Module):
            """Three conv blocks followed by global average pooling and a linear head."""

            def __init__(self, num_classes=2, dropout=0.2, width=32):
                super().__init__()
                c = width
                self.features = nn.Sequential(
                    nn.Conv2d(3, c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
                    nn.Conv2d(c, 2 * c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
                    nn.Conv2d(2 * c, 4 * c, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
                )
                self.head = nn.Sequential(
                    nn.Flatten(), nn.Dropout(dropout), nn.Linear(4 * c, num_classes)
                )

            def forward(self, x):
                return self.head(self.features(x))

        return SmallCNN(num_classes=num_classes, dropout=dropout, width=width)

    if arch == "resnet18":
        # No pretrained download: the weights are expected to come from the
        # state_dict loaded afterwards.
        m = models.resnet18(weights=None)
        in_features = m.fc.in_features
        m.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_features, num_classes))
        return m

    if arch == "mobilenet_v3_small":
        m = models.mobilenet_v3_small(weights=None)
        in_features = m.classifier[-1].in_features
        # Only the last classifier layer is swapped; `dropout` is not applied here.
        m.classifier[-1] = nn.Linear(in_features, num_classes)
        return m

    raise ValueError(f"Unknown arch: {arch!r}")
def load_model(model_path="model_state.pt", config_path="config.json", device="cpu"):
    """Rebuild the trained model from its config and weights, plus its eval transform.

    Args:
        model_path: Path to a file holding the model's ``state_dict``.
        config_path: Path to the JSON config. Keys read here: ``arch``,
            ``dropout``, ``width``, ``freeze_backbone``, ``num_classes``,
            ``img_size``, ``mean`` and ``std``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        A tuple ``(model, tfm, cfg)``:
        - ``model``: the network in eval mode on ``device``.
        - ``tfm``: the preprocessing transform for a single PIL image.
        - ``cfg``: the parsed config dict.

    Raises:
        KeyError: If a required config key is missing.
        RuntimeError: If the checkpoint's keys or shapes do not match the
            architecture (loading uses ``strict=True``).
    """
    # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    model = build_model(
        cfg["arch"], cfg["dropout"], cfg["width"], cfg["freeze_backbone"], cfg["num_classes"]
    )
    # SECURITY: torch.load unpickles the file, so only load checkpoints from a
    # trusted source. Where supported (PyTorch >= 1.13), weights_only=True
    # restricts loading to tensors/containers.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state, strict=True)
    model.to(device).eval()

    img_size = cfg["img_size"]
    tfm = transforms.Compose([
        # Resize the shorter edge to ~1.14x the crop size, then center-crop.
        # NOTE(review): presumably mirrors the training-time eval transform;
        # keep this ratio in sync with the training script.
        transforms.Resize(int(img_size * 1.14)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg["mean"], std=cfg["std"]),
    ])
    return model, tfm, cfg
def predict_image(image_path, model, tfm, device="cpu"):
    """Classify a single image.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        model: Classifier returning 2-D logits, presumably
            ``(batch, num_classes)``. It should already be in eval mode and
            on ``device`` (e.g. the model returned by ``load_model``).
        tfm: Transform mapping a PIL RGB image to a ``(C, H, W)`` tensor.
        device: Device the input batch is moved to.

    Returns:
        A tuple ``(pred, probs)``:
        - ``pred``: the argmax class index, as an ``int``.
        - ``probs``: the softmax probabilities, as a flat list of Python
          floats, one per class.
    """
    # Open via a context manager so the file handle is released promptly.
    # convert() returns a new, fully loaded image that stays valid after close.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    x = tfm(img).unsqueeze(0).to(device)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = model(x)
    probs = torch.softmax(logits, dim=1).flatten().cpu().tolist()
    pred = int(logits.argmax(dim=1).item())
    return pred, probs