File size: 2,108 Bytes
7a59163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
from model import NetFeat, NetClassifier  # Import feature extractor & classifier

# Clothing1M class labels: human-readable names, looked up by the class index
# that the classifier predicts (position in this list == class id).
# NOTE(review): the order must match the label indices the checkpoint was
# trained with. The commonly distributed Clothing1M category list appears to
# use a different order (and "Windbreaker" rather than "Cardigan") -- TODO
# confirm against the training label mapping before trusting these names.
CLOTHING_CLASSES = [
    "T-shirt", "Shirt", "Shawl", "Dress", "Vest", "Underwear", "Cardigan", "Jacket", 
    "Sweater", "Hoodie", "Knitwear", "Chiffon", "Downcoat", "Suit"
]

# Load the pre-trained feature extractor and classifier from a local checkpoint
def load_model(model_filename="CC_net.pt"):
    """Build the two networks and load their weights from a checkpoint file.

    The checkpoint is read from the local path ``model_filename``; nothing is
    downloaded here, so the file must already exist on disk.

    Args:
        model_filename: Path of the checkpoint. It is expected to hold a
            mapping with a ``"feat"`` entry (feature-extractor weights) and a
            ``"cls"`` entry (classifier weights).

    Returns:
        A ``(net_feat, net_cls)`` tuple, both switched to evaluation mode.

    Warns:
        RuntimeWarning: if a checkpoint entry is absent (that network keeps
            its random initialisation) or if some parameters did not match
            between the checkpoint and the network.
    """
    import warnings  # local import so this block does not depend on file-level imports

    # Initialize both feature extractor & classifier
    net_feat = NetFeat(arch='resnet18', pretrained=False, dataset='Clothing1M')
    net_cls = NetClassifier(feat_dim=512, nb_cls=14)  # Feature dim from ResNet-18

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints that come from a trusted source.
    state_dict = torch.load(model_filename, map_location=torch.device('cpu'))

    for key, network in (("feat", net_feat), ("cls", net_cls)):
        if key not in state_dict:
            # Keep the original best-effort behaviour (no exception), but do
            # not stay silent: predictions from random weights are meaningless.
            warnings.warn(
                f"Checkpoint {model_filename!r} has no {key!r} entry; "
                "the corresponding network keeps its random initialisation.",
                RuntimeWarning,
                stacklevel=2,
            )
            continue

        # strict=False tolerates key mismatches; surface them instead of
        # hiding them so partially loaded weights are noticed.
        result = network.load_state_dict(state_dict[key], strict=False)
        missing = list(getattr(result, "missing_keys", ()))
        unexpected = list(getattr(result, "unexpected_keys", ()))
        if missing or unexpected:
            warnings.warn(
                f"Partial load for {key!r}: missing keys {missing}, "
                f"unexpected keys {unexpected}.",
                RuntimeWarning,
                stacklevel=2,
            )

    # Set both models to evaluation mode
    net_feat.eval()
    net_cls.eval()

    return net_feat, net_cls

# Preprocess image for model input
def preprocess_image(image):
    """Load an image and convert it into a normalized model-input batch.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a filesystem path or a
            binary file object).

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the image converted to
        RGB, resized to 224x224, scaled by ``ToTensor`` and normalized
        per channel, with a leading batch dimension added.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Channel statistics conventionally used for ImageNet-pretrained backbones.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # The context manager releases the underlying file handle deterministically;
    # convert() returns an independent in-memory copy, so it stays usable
    # after the source image is closed.
    with Image.open(image) as source:
        rgb_image = source.convert("RGB")  # Ensure the image is in RGB mode
    return transform(rgb_image).unsqueeze(0)  # Add batch dimension

# Classify one image and report its label
def run_inference(image_path, net_feat, net_cls):
    """Predict the clothing class name for a single image.

    Args:
        image_path: Image location, forwarded unchanged to ``preprocess_image``.
        net_feat: Feature extractor, called on the preprocessed batch.
        net_cls: Classification head, called on the extracted features.

    Returns:
        The entry of ``CLOTHING_CLASSES`` at the highest-scoring class index.
    """
    batch = preprocess_image(image_path)

    # Gradients are not needed for prediction.
    with torch.no_grad():
        features = net_feat(batch)
        scores = net_cls(features)

    # The batch holds exactly one image, so the argmax is a single scalar.
    best_class = torch.argmax(scores, dim=1).item()
    label = CLOTHING_CLASSES[best_class]
    return label