# Clothpredict / app.py
# Last commit 56a43e8 by Saahil-doryu: "Update app.py" (verified).
import warnings

import gradio as gr
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

from model import NetFeat, NetClassifier
# Display names for the classifier's output indices: entry i is the label
# reported when logit i is the largest (run_inference takes the argmax).
# Order follows the published Clothing1M category list (Xiao et al., CVPR 2015;
# category_names_eng.txt in the dataset release), which includes "Windbreaker"
# and has no "Cardigan" class.
# NOTE(review): confirm this order matches the label indices that netBest.pth
# was trained with.
CLOTHING_CLASSES = [
    "T-shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Downcoat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]
# Load the model
def load_model(model_filename='netBest.pth'):
    """Build the feature extractor and classifier and load trained weights.

    Args:
        model_filename: Checkpoint path passed to ``torch.load``. Defaults to
            ``'netBest.pth'``, resolved relative to the working directory.

    Returns:
        A ``(net_feat, net_cls)`` tuple, both switched to eval mode.

    The checkpoint is read as a mapping with the backbone weights under
    ``'feat'`` and the classifier weights under ``'cls'``. Weights are loaded
    non-strictly, so an absent entry or mismatched parameter names would leave
    layers at their initial values without any error. Each such case is
    reported with a ``UserWarning`` rather than passing silently.
    """
    net_feat = NetFeat(arch='resnet18', pretrained=False, dataset='Clothing1M')
    # One output per entry in CLOTHING_CLASSES, so labels and logits line up.
    net_cls = NetClassifier(feat_dim=512, nb_cls=len(CLOTHING_CLASSES))
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state_dict = torch.load(model_filename, map_location=torch.device('cpu'))
    for key, net in (('feat', net_feat), ('cls', net_cls)):
        if key not in state_dict:
            warnings.warn(
                f"Checkpoint {model_filename!r} has no {key!r} entry; "
                f"{type(net).__name__} keeps its initial weights.",
                stacklevel=2,
            )
            continue
        # strict=False tolerates name mismatches, so surface what was skipped.
        incompatible = net.load_state_dict(state_dict[key], strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"Loading {key!r} from {model_filename!r} into "
                f"{type(net).__name__}: missing keys "
                f"{list(incompatible.missing_keys)}, unexpected keys "
                f"{list(incompatible.unexpected_keys)}.",
                stacklevel=2,
            )
    net_feat.eval()
    net_cls.eval()
    return net_feat, net_cls
# Preprocess image for model input
def preprocess_image(image):
    """Convert an image into a normalized ``(1, 3, 224, 224)`` float tensor.

    Args:
        image: A file path or file object accepted by ``PIL.Image.open``
            (``gr.Image(type="filepath")`` supplies a path), or an
            already-opened ``PIL.Image.Image``.

    Returns:
        A batch of one RGB image resized to 224x224 and normalized with the
        ImageNet per-channel mean/std values.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    if isinstance(image, Image.Image):
        rgb = image.convert("RGB")
    else:
        # convert() decodes the pixels into a new image, so the file can be
        # closed right away. Pillow otherwise keeps the handle open for
        # multi-frame formats (e.g. animated WebP/GIF).
        with Image.open(image) as opened:
            rgb = opened.convert("RGB")
    return transform(rgb).unsqueeze(0)
def run_inference(image, net_feat, net_cls):
    """Predict the clothing category name for a single image.

    Args:
        image: Any input ``preprocess_image`` accepts (a file path here).
        net_feat: Network applied first, to the preprocessed image tensor.
        net_cls: Network applied to ``net_feat``'s output to produce scores.

    Returns:
        The ``CLOTHING_CLASSES`` entry at the index of the highest score
        along dimension 1 (assumes a single-image batch, as built by
        ``preprocess_image``).
    """
    batch = preprocess_image(image)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        scores = net_cls(net_feat(batch))
        best = scores.argmax(dim=1).item()
    return CLOTHING_CLASSES[best]
# Load both networks once, at import time; classify_image passes these
# module-level objects to run_inference on every request.
net_feat, net_cls = load_model()
def classify_image(image):
    """Gradio callback: return the predicted clothing class for an image.

    Args:
        image: File path from ``gr.Image(type="filepath")``, or ``None`` when
            the form is submitted without an image.

    Returns:
        The predicted class name, or a prompt asking for an image when none
        was provided.
    """
    if image is None:
        # Without this guard, Image.open(None) raises and the UI shows only
        # a generic error.
        return "Please upload an image."
    return run_inference(image, net_feat, net_cls)
# Sample files offered beneath the upload widget as clickable examples.
example_images = [
    "example.jpeg",
    "example2.webp",
    "image2.jpg",
]

# Web UI: an image upload (delivered to the callback as a file path) and a
# text box that shows the predicted label.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.Textbox(label="Predicted Clothing1M Class"),
    title="Clothing1M Classifier",
    description="Upload an image of clothing to classify it into one of 14 categories.",
    examples=example_images,
)

if __name__ == "__main__":
    # Serve the UI only when executed as a script, not on import.
    interface.launch()