# NewSpace1 / model_utils.py
# Uploaded by Deba4 using huggingface_hub (commit 7bf97d3, verified)
# model_utils.py
import json
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
# Assumptions: ImageNet normalization, 224x224 input — change if yours differs
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
INPUT_SIZE = 224
# Shorter-side resize applied before the center crop. It is derived from
# INPUT_SIZE with the usual ImageNet 256/224 ratio so the resized image is
# always larger than the crop; a fixed 256 would make CenterCrop zero-pad if
# INPUT_SIZE were raised above 256. For the default INPUT_SIZE=224 this is 256.
RESIZE_SIZE = round(INPUT_SIZE * 256 / 224)

# Preprocessing transform (resize -> center crop -> to tensor -> normalize)
transform = transforms.Compose([
    transforms.Resize(RESIZE_SIZE),
    transforms.CenterCrop(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
def load_labels(path):
    """Read class labels from a text file, one label per line.

    Each line is stripped of surrounding whitespace and blank lines are
    skipped, so list position i corresponds to the i-th non-blank line.

    Args:
        path: Path to a UTF-8 text file (with or without a byte-order mark).

    Returns:
        List of label strings in file order.
    """
    # "utf-8-sig" drops a leading byte-order mark (e.g. files saved by Windows
    # Notepad). str.strip() does not remove U+FEFF, so with plain "utf-8" the
    # first label would silently carry it. Files without a BOM decode the same.
    with open(path, "r", encoding="utf-8-sig") as f:
        return [line.strip() for line in f if line.strip()]
def load_remedies(path):
    """Load the remedies data from a JSON file.

    Args:
        path: Path to a UTF-8 JSON file (with or without a byte-order mark).

    Returns:
        The decoded JSON value (presumably a dict keyed by label — confirm
        against the caller).
    """
    # "utf-8-sig" tolerates a leading byte-order mark; with plain "utf-8",
    # json.load rejects such files with "Unexpected UTF-8 BOM".
    with open(path, "r", encoding="utf-8-sig") as f:
        return json.load(f)
def _extract_state_dict(checkpoint):
    """Return the parameter mapping held by an object returned from torch.load.

    Accepts a bare state_dict, a dict wrapping one under "state_dict" or
    "model_state_dict", or a pickled nn.Module.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for key in ("state_dict", "model_state_dict"):
            inner = checkpoint.get(key)
            if isinstance(inner, dict):
                return inner
        return checkpoint
    raise TypeError(
        "Unsupported checkpoint type: %s" % type(checkpoint).__name__
    )


def _strip_module_prefix(state_dict, prefix="module."):
    """Remove the leading `prefix` that nn.DataParallel/DDP add to parameter keys."""
    if not any(k.startswith(prefix) for k in state_dict):
        # Nothing to strip: pass the original mapping through untouched.
        return state_dict
    # Only the leading prefix is removed; str.replace would also mangle any
    # later "module." occurrence inside a key.
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def build_model(num_classes, device, checkpoint_path):
    """Build a MobileNetV2 classifier and load trained weights into it.

    The architecture must match training: a stock torchvision MobileNetV2
    whose final Linear layer (classifier[1]) is replaced to emit
    `num_classes` logits.

    Args:
        num_classes: Number of output classes.
        device: Device to map the checkpoint onto and to move the model to.
        checkpoint_path: Path to a file saved with torch.save. It may hold a
            bare state_dict, a dict with a "state_dict"/"model_state_dict"
            entry, or a whole model. Keys prefixed with "module." (from
            DataParallel) are accepted.

    Returns:
        The model on `device`, in eval mode.

    Raises:
        TypeError: if the checkpoint holds an unsupported object.
        RuntimeError: from load_state_dict if the weights do not match the
            architecture.
    """
    try:
        model = models.mobilenet_v2(weights=None)
    except TypeError:
        # torchvision < 0.13 has no `weights` argument.
        model = models.mobilenet_v2(pretrained=False)
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, num_classes)

    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = _strip_module_prefix(_extract_state_dict(checkpoint))
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
def load_model(checkpoint_path, labels_path, remedies_path, device):
    """Load everything needed for inference.

    Reads the label list and the remedies file, then builds the classifier
    with one output per label and loads the checkpoint weights into it.

    Returns:
        Tuple (model, labels, remedies).
    """
    class_names = load_labels(labels_path)
    remedy_table = load_remedies(remedies_path)
    net = build_model(
        num_classes=len(class_names),
        device=device,
        checkpoint_path=checkpoint_path,
    )
    return net, class_names, remedy_table
def predict(model, pil_image, labels, device, topk=3):
    """Return top-1 label, confidence, and topk list of (label, prob).

    Args:
        model: Classifier that returns logits of shape (1, num_classes) for a
            single-image batch.
        pil_image: Input PIL image; converted to RGB if it is in another mode.
        labels: Class names, indexed by model output position.
        device: Device the model lives on.
        topk: Number of top predictions to return. It is capped at the number
            of classes the model outputs.

    Returns:
        (top1_label, top1_prob, [(label, prob), ...]). The probabilities are
        Python floats from a softmax over the logits, highest first.

    Raises:
        ValueError: if topk < 1.
    """
    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")
    # The 3-channel Normalize and the network expect RGB, but uploads may be
    # RGBA (PNG with alpha), grayscale ("L") or palette ("P") images.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    img_t = transform(pil_image).unsqueeze(0).to(device)  # shape 1x3xHxW
    with torch.no_grad():
        logits = model(img_t)
        probs = F.softmax(logits, dim=1)  # convert to probabilities
        # torch.topk raises when k exceeds the number of classes, so cap it.
        k = min(topk, probs.size(1))
        top_probs, top_idxs = probs.topk(k, dim=1)
    top_probs = top_probs[0].cpu().tolist()
    top_idxs = top_idxs[0].cpu().tolist()
    top_labels = [labels[i] for i in top_idxs]
    return top_labels[0], top_probs[0], list(zip(top_labels, top_probs))