import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from model import NetFeat, NetClassifier
# Class labels for the Clothing1M dataset, in index order.
# NOTE(review): the order must match the training-time label indices of the
# checkpoint (nb_cls=14 below) — verify against the training code.
CLOTHING_CLASSES = [
    "T-shirt", "Shirt", "Shawl", "Dress", "Vest", "Underwear", "Cardigan", "Jacket",
    "Sweater", "Hoodie", "Knitwear", "Chiffon", "Downcoat", "Suit"
]
def load_model():
    """Build the feature extractor and classifier and load local weights.

    Returns:
        tuple: (net_feat, net_cls), both switched to eval mode.
    """
    checkpoint_path = 'netBest.pth'  # Adjust the path as necessary
    backbone = NetFeat(arch='resnet18', pretrained=False, dataset='Clothing1M')
    head = NetClassifier(feat_dim=512, nb_cls=14)
    # Force CPU so the app also runs on machines without a GPU.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    # The checkpoint may hold separate sub-dicts for each network; load
    # whichever are present (strict=False tolerates key mismatches).
    for key, module in (("feat", backbone), ("cls", head)):
        if key in checkpoint:
            module.load_state_dict(checkpoint[key], strict=False)
    backbone.eval()
    head.eval()
    return backbone, head
# Preprocessing pipeline for model input, built once at import time instead of
# on every request. Mean/std are the standard ImageNet normalization values
# expected by a resnet18 backbone.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert an image into a normalized model-input batch.

    Args:
        image: A file path (as produced by ``gr.Image(type="filepath")``)
            or an already-open ``PIL.Image.Image``.

    Returns:
        A float tensor of shape (1, 3, 224, 224).
    """
    # Accept an open PIL image directly; otherwise treat it as a path.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    image = image.convert("RGB")
    # unsqueeze adds the batch dimension the networks expect.
    return _PREPROCESS(image).unsqueeze(0)
def run_inference(image, net_feat, net_cls):
    """Classify a single image and return its human-readable label.

    Args:
        image: Input accepted by ``preprocess_image`` (a file path).
        net_feat: Feature-extraction network.
        net_cls: Classification head applied to the extracted features.

    Returns:
        str: The predicted entry from ``CLOTHING_CLASSES``.
    """
    batch = preprocess_image(image)
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        logits = net_cls(net_feat(batch))
    return CLOTHING_CLASSES[logits.argmax(dim=1).item()]
# Load the networks once at import time so every request reuses them.
net_feat, net_cls = load_model()
def classify_image(image):
    """Gradio callback: map an uploaded image (filepath) to a class label."""
    return run_inference(image, net_feat, net_cls)
# Sample inputs shown below the UI; files are expected alongside this script.
example_images = ["example.jpeg", "example2.webp","image2.jpg"]
# Gradio UI wiring: single image in, predicted class name out.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),  # Simple Image input
    outputs=gr.Textbox(label="Predicted Clothing1M Class"),
    title="Clothing1M Classifier",
    description="Upload an image of clothing to classify it into one of 14 categories.",
    examples=example_images
)
if __name__ == "__main__":
    # Launch the Gradio app when run as a script. (A stray trailing "|"
    # scraping artifact was removed here — it was a syntax error.)
    interface.launch()