# Nested-Co-teaching — inference.py
# Author: Saahil-doryu (commit 640a672)
import torch
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
from model import NetFeat, NetClassifier # Import feature extractor & classifier
# Define Clothing1M class labels.
# 14 entries — must stay aligned with the classifier head (nb_cls=14 in
# load_model); the list index is presumably the label id used during
# training, so the order must not be changed — TODO confirm against the
# training label map.
CLOTHING_CLASSES = [
    "T-shirt", "Shirt", "Shawl", "Dress", "Vest", "Underwear", "Cardigan", "Jacket",
    "Sweater", "Hoodie", "Knitwear", "Chiffon", "Downcoat", "Suit"
]
# Load the pre-trained model (locally, or from the Hugging Face Hub).
def load_model(model_filename="netBest.pth", repo_id=None):
    """Build the Nested-Co-teaching model and load its trained weights.

    Args:
        model_filename: checkpoint file name. Interpreted as a local path by
            default; when ``repo_id`` is given, it is the file name inside
            that Hub repository. Defaults to "netBest.pth".
        repo_id: optional Hugging Face Hub repository id. When provided, the
            checkpoint is fetched with ``hf_hub_download``. The original code
            imported ``hf_hub_download`` and claimed to download the weights
            but never called it — this parameter makes the download explicit
            while keeping the old local-file behavior as the default.

    Returns:
        tuple: ``(net_feat, net_cls)`` — feature extractor and classifier,
        both switched to evaluation mode.
    """
    if repo_id is not None:
        # Resolve the checkpoint through the Hub cache instead of the CWD.
        model_filename = hf_hub_download(repo_id=repo_id, filename=model_filename)

    # Initialize both feature extractor & classifier.
    net_feat = NetFeat(arch='resnet18', pretrained=False, dataset='Clothing1M')
    net_cls = NetClassifier(feat_dim=512, nb_cls=14)  # 512 = ResNet-18 feature dim

    # Load saved model weights on CPU so inference works without a GPU.
    state_dict = torch.load(model_filename, map_location=torch.device('cpu'))

    # Checkpoint is expected to hold separate sub-dicts for the two modules.
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # confirm the checkpoint layout actually matches these module names.
    if "feat" in state_dict:
        net_feat.load_state_dict(state_dict["feat"], strict=False)
    if "cls" in state_dict:
        net_cls.load_state_dict(state_dict["cls"], strict=False)

    # Set both models to evaluation mode (disables dropout/batch-norm updates).
    net_feat.eval()
    net_cls.eval()
    return net_feat, net_cls
# Preprocess image for model input.
def preprocess_image(image):
    """Convert an image into a normalized, batched tensor for the model.

    Args:
        image: anything ``PIL.Image.open`` accepts (path or file object), or
            an already-decoded ``PIL.Image.Image`` (a generalization — the
            original only accepted openable sources and would fail on an
            in-memory image).

    Returns:
        torch.Tensor of shape (1, 3, 224, 224), resized and normalized with
        the ImageNet mean/std used by pretrained ResNet backbones.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet statistics — standard for ResNet-18 feature extractors.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Accept an already-open PIL image as well as a path/file object.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    image = image.convert("RGB")  # ensure 3 channels (handles grayscale/RGBA)
    return transform(image).unsqueeze(0)  # add batch dimension
# Run inference on a single image.
def run_inference(image_path, net_feat, net_cls):
    """Classify one image and return its Clothing1M label string.

    Args:
        image_path: image source accepted by ``preprocess_image``.
        net_feat: feature-extractor module.
        net_cls: classifier head applied to the extracted features.

    Returns:
        str: the predicted class name from ``CLOTHING_CLASSES``.
    """
    batch = preprocess_image(image_path)
    # No gradients needed for a forward-only pass.
    with torch.no_grad():
        features = net_feat(batch)      # extract features
        logits = net_cls(features)      # apply classification head
    label_idx = int(logits.argmax(dim=1))
    return CLOTHING_CLASSES[label_idx]