| |
| import json |
| import torch |
| import torch.nn.functional as F |
| from torchvision import models, transforms |
| from PIL import Image |
|
|
| |
# Per-channel RGB statistics of ImageNet, used below to normalise network inputs.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Side length, in pixels, of the square crop fed to the network.
INPUT_SIZE = 224


# Inference preprocessing: resize so the shorter edge is 256 px, take the
# central INPUT_SIZE x INPUT_SIZE crop, convert to a float tensor, then
# normalise each channel with the ImageNet statistics above.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
|
|
def load_labels(path):
    """Read class labels from a UTF-8 text file, one label per line.

    Leading/trailing whitespace is stripped from each line and lines that are
    empty after stripping are skipped. Callers index the returned list by the
    model's output class index, so line order matters.
    """
    labels = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:
                labels.append(cleaned)
    return labels
|
|
def load_remedies(path):
    """Parse the UTF-8 encoded JSON document at *path* and return it."""
    with open(path, encoding="utf-8") as fp:
        raw_text = fp.read()
    return json.loads(raw_text)
|
|
def _extract_state_dict(checkpoint):
    """Return the parameter mapping held by an object returned from ``torch.load``.

    Accepts a whole pickled ``nn.Module``, a dict wrapping the weights under a
    ``"state_dict"`` key, or the weights mapping itself (returned unchanged).
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict):
        return checkpoint["state_dict"]
    return checkpoint


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from each key.

    Only a leading occurrence is removed; keys without the prefix are kept as-is.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def build_model(num_classes, device, checkpoint_path):
    """Build a MobileNetV2 classifier with *num_classes* outputs and load weights.

    Args:
        num_classes: size of the replaced final ``Linear`` layer.
        device: target device; also used as ``map_location`` for ``torch.load``.
        checkpoint_path: file holding either a bare state dict, a dict with a
            ``"state_dict"`` entry, or a pickled module.

    Returns:
        The model moved to *device* and switched to eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint's keys or shapes
            do not match the architecture.
    """
    model = models.mobilenet_v2(pretrained=False)
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, num_classes)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Models trained under nn.DataParallel / DistributedDataParallel save keys as
    # "module.<name>". Stripping unconditionally is safe here because
    # MobileNetV2's own parameter names start with "features." or "classifier.".
    # The old code only stripped on failure and missed wrapped
    # {"state_dict": {"module.*": ...}} checkpoints.
    state_dict = _strip_module_prefix(_extract_state_dict(checkpoint))
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
|
|
def load_model(checkpoint_path, labels_path, remedies_path, device):
    """Load every artefact needed for inference.

    The label file is read first because its length fixes the classifier size.

    Returns:
        tuple: ``(model, labels, remedies)`` as produced by ``build_model``,
        ``load_labels`` and ``load_remedies`` respectively.
    """
    class_names = load_labels(labels_path)
    remedy_table = load_remedies(remedies_path)
    net = build_model(len(class_names), device, checkpoint_path)
    return net, class_names, remedy_table
|
|
def predict(model, pil_image, labels, device, topk=3):
    """Return top-1 label, confidence, and topk list of (label, prob).

    Args:
        model: classifier (e.g. from ``build_model``) whose output for a batch of
            one image is indexed as ``(batch, class)`` logits.
        pil_image: PIL image of any mode; non-RGB images are converted to RGB.
        labels: class names indexed by the model's output class index.
        device: device the input tensor is moved to (should match the model).
        topk: number of ranked predictions to return; clamped to the number
            of classes the model outputs.

    Returns:
        ``(top_label, top_prob, [(label, prob), ...])`` where probabilities are
        softmax values as Python floats, in descending order.

    Raises:
        ValueError: if ``topk`` is less than 1.
    """
    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")

    # ToTensor keeps the image's native channel count (1 for "L"/"P", 4 for
    # "RGBA"/"CMYK"), but Normalize is configured with three-channel stats.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    img_t = transform(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_t)
        probs = F.softmax(logits, dim=1)

    # torch.topk raises when k exceeds the class dimension; clamp instead.
    k = min(topk, probs.size(1))
    top_probs, top_idxs = probs.topk(k, dim=1)
    prob_values = top_probs[0].tolist()
    index_values = top_idxs[0].tolist()
    top_labels = [labels[i] for i in index_values]
    return top_labels[0], prob_values[0], list(zip(top_labels, prob_values))
|
|