Munazz's picture
Rename main.py to app.py
404bf24
import gradio as gr
import os
from PIL import Image
import torch
import torchvision
from torchvision import transforms, datasets # Correct import for torchvision
import sys
# Add the 'src' folder to the system path for module imports
sys.path.append('./src')
# Import from the 'src' folder
from model_jigsaw import mae_vit_small_patch16 # Import directly from src folder
# Static Variables
# Location of the trained checkpoint and of the folder holding example images.
MODEL_PATH = "model/netBest.pth"
TEST_IMAGE_FOLDER = "test_images/"

# Class Mapping from Clothing1M dataset
# Labels are listed in class-index order; enumerate() assigns 0..13 as keys.
class_names = dict(enumerate((
    "T-shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Down Coat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Nightwear",
)))
# Image Preprocessing
def preprocess_image(image):
    """Convert a PIL image into a normalized single-image batch tensor.

    The image is forced to RGB, resized to 224x224, converted to a tensor,
    normalized per channel with the mean/std values below, and finally given
    a leading batch dimension of size 1.

    Args:
        image: A PIL image (anything exposing ``convert('RGB')``).

    Returns:
        The transformed image tensor with an extra leading dimension.
    """
    # Build the individual stages first, then chain them in the same order.
    resize = transforms.Resize((224, 224))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    pipeline = transforms.Compose([resize, to_tensor, normalize])

    rgb_image = image.convert('RGB')
    tensor = pipeline(rgb_image)
    # The model is called on batches, so add a batch axis in front.
    return tensor.unsqueeze(0)
# Load Model
# Build the network with a 14-way classification head (class_names has 14
# entries) and restore the trained weights onto the CPU.
model = mae_vit_small_patch16(nb_cls=14)
# The checkpoint is indexed with 'net' to obtain the state dict.
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be loaded here — confirm whether weights_only=True can be used.
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu"))['net'])
# Put the model in evaluation mode before serving predictions.
model.eval()
# Load dataset mapping
# ImageFolder assigns an integer index to every class sub-folder found under
# TEST_IMAGE_FOLDER; the mapping is inverted so a predicted index can be
# translated back to its folder name in predict_single_image.
# NOTE(review): folder names are presumably numeric class ids, since
# predict_single_image passes them through int() — confirm against the data.
val_dataset = torchvision.datasets.ImageFolder(root=TEST_IMAGE_FOLDER)
idx_to_class = {v: k for k, v in val_dataset.class_to_idx.items()}
def predict_single_image(image):
    """Predict the clothing category of an image, following eval.py's logic.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` input. Gradio
            passes ``None`` when the user submits without providing an image.

    Returns:
        The human-readable category name looked up in ``class_names``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an opaque ``AttributeError``.
    """
    if image is None:
        raise gr.Error("Please upload an image before requesting a prediction.")

    batch = preprocess_image(image)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.forward_cls(batch)
        _, predicted_class = torch.max(outputs, 1)

    # Get final class name: the predicted index is first translated to the
    # ImageFolder folder name, which is then parsed as the class id used by
    # class_names.
    # NOTE(review): this lookup assumes TEST_IMAGE_FOLDER contains a numeric
    # folder for every index the model can predict — confirm.
    mapped_class_index = idx_to_class[predicted_class.item()]
    final_class_name = class_names[int(mapped_class_index)]
    return final_class_name
# Get all images from subfolders (recursively)
def get_image_paths(test_image_folder):
    """Recursively collect the ``.jpg``/``.png`` files below a folder.

    The extension check is case-insensitive, so files such as ``photo.JPG``
    are included, and the result is sorted so the order does not depend on
    the filesystem's directory listing order.

    Args:
        test_image_folder: Root directory to search.

    Returns:
        A sorted list of file paths (each joined onto the directory it was
        found in). Empty when the folder does not exist or holds no
        matching files.
    """
    image_extensions = (".jpg", ".png")
    image_paths = []
    for root, _dirs, files in os.walk(test_image_folder):
        image_paths.extend(
            os.path.join(root, file_name)
            for file_name in files
            if file_name.lower().endswith(image_extensions)
        )
    image_paths.sort()
    return image_paths
# Collect the example image paths once, when the module is executed; the list
# is handed to the Gradio interface below as its examples.
test_image_files = get_image_paths(TEST_IMAGE_FOLDER)
def load_test_image(selected_image):
    """Open the selected test image and return it.

    Args:
        selected_image: Path (or file object) accepted by ``Image.open``.

    Returns:
        The image object produced by ``Image.open``.
    """
    opened_image = Image.open(selected_image)
    return opened_image
# Create Gradio Interface with dynamic examples from test_images folder
# The uploaded image is delivered to predict_single_image as a PIL image
# (type="pil") and the returned category name is shown in a text box.
demo = gr.Interface(
fn=predict_single_image,
inputs=gr.Image(type="pil", label="Upload an image"),
outputs=gr.Textbox(label="Predicted Category"),
title="Clothes Category Classifier",
description="Upload an image to classify its clothing category.",
examples=test_image_files # Paths gathered by get_image_paths at startup
)
# Start the Gradio app as soon as this module is executed.
demo.launch()