File size: 2,146 Bytes
b702dea
 
cb229ed
b702dea
f5cf223
 
127caf5
cb229ed
b702dea
 
 
 
127caf5
b702dea
 
 
 
 
 
 
 
 
 
 
cb229ed
b702dea
 
 
cb229ed
b702dea
 
 
 
 
 
 
 
f5cf223
 
127caf5
b702dea
 
 
 
 
 
 
127caf5
f5cf223
b702dea
 
f5cf223
b702dea
 
 
f5cf223
 
127caf5
cb229ed
f5cf223
b702dea
127caf5
56a43e8
f5cf223
 
127caf5
 
 
cb229ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from model import NetFeat, NetClassifier  


# Human-readable labels for the 14 Clothing1M categories. Index order matters:
# run_inference() uses the classifier's argmax index directly into this list,
# and load_model() builds the head with nb_cls=14 to match its length.
CLOTHING_CLASSES = [
    "T-shirt", "Shirt", "Shawl", "Dress", "Vest", "Underwear", "Cardigan", "Jacket",
    "Sweater", "Hoodie", "Knitwear", "Chiffon", "Downcoat", "Suit"
]

# Load the model
def load_model(model_filename='netBest.pth'):
    """Build the Clothing1M feature extractor + classifier and load weights.

    Parameters
    ----------
    model_filename : str, optional
        Path to the checkpoint file (default ``'netBest.pth'``). The
        checkpoint is expected to be a dict that may contain ``'feat'``
        and/or ``'cls'`` state dicts.

    Returns
    -------
    tuple
        ``(net_feat, net_cls)``, both switched to eval mode.
    """
    net_feat = NetFeat(arch='resnet18', pretrained=False, dataset='Clothing1M')
    # feat_dim=512 and nb_cls=14 must match the checkpoint's head;
    # 14 mirrors len(CLOTHING_CLASSES).
    net_cls = NetClassifier(feat_dim=512, nb_cls=14)

    # map_location='cpu' lets the checkpoint load on machines without a GPU.
    state_dict = torch.load(model_filename, map_location=torch.device('cpu'))
    # strict=False tolerates minor key mismatches; a checkpoint missing the
    # 'feat' or 'cls' section is silently skipped (original behavior kept).
    if "feat" in state_dict:
        net_feat.load_state_dict(state_dict['feat'], strict=False)
    if "cls" in state_dict:
        net_cls.load_state_dict(state_dict['cls'], strict=False)

    # Inference only: freeze dropout / batch-norm running-stat updates.
    net_feat.eval()
    net_cls.eval()
    return net_feat, net_cls

# Preprocess image for model input
def preprocess_image(image):
    """Turn an image into a normalized ``1 x 3 x 224 x 224`` float tensor.

    Parameters
    ----------
    image : str or PIL.Image.Image
        A filepath (the original, gradio-supplied form) or an already-open
        PIL image. Either is converted to RGB before transforming.

    Returns
    -------
    torch.Tensor
        Batched tensor ready for the feature extractor.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Standard ImageNet normalization constants, as commonly paired with
        # a resnet18 backbone (see load_model).
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Backward-compatible generalization: accept an open PIL image as well as
    # a filepath. gradio's type="filepath" input still passes a path string.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    image = image.convert("RGB")
    # unsqueeze(0) adds the batch dimension the network expects.
    return transform(image).unsqueeze(0)


def run_inference(image, net_feat, net_cls):
    image_tensor = preprocess_image(image)
    with torch.no_grad():
        feature_vector = net_feat(image_tensor)
        output = net_cls(feature_vector)
    predicted_index = output.argmax(dim=1).item()
    return CLOTHING_CLASSES[predicted_index]


# Load the networks once at import time so every request reuses them.
net_feat, net_cls = load_model()


def classify_image(image):
    """Gradio callback: predict the clothing class for an uploaded image."""
    return run_inference(image, net_feat, net_cls)

# Sample inputs shown in the UI (paths resolved relative to the app).
example_images = ["example.jpeg", "example2.webp", "image2.jpg"]

# Wire the classifier into a minimal Gradio UI.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),  # callback receives a file path string
    outputs=gr.Textbox(label="Predicted Clothing1M Class"),
    title="Clothing1M Classifier",
    description="Upload an image of clothing to classify it into one of 14 categories.",
    examples=example_images,
)

if __name__ == "__main__":
    interface.launch()