"""Predict a Clothing1M clothing category for an image with a pre-trained model."""

import os
import sys

import gradio as gr
import torch
import torchvision
from PIL import Image
from torchvision import datasets, transforms

# Add the 'src' folder to the module search path; this must happen before
# the model_jigsaw import below.
sys.path.append('./src')

from model_jigsaw import mae_vit_small_patch16  # noqa: E402

# Static variables
MODEL_PATH = "model/netBest.pth"
TEST_IMAGE_FOLDER = "test_images/"

# Lower-case file extensions treated as images when collecting example files.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Class mapping from the Clothing1M dataset (numeric label -> category name).
class_names = {
    0: "T-shirt",
    1: "Shirt",
    2: "Knitwear",
    3: "Chiffon",
    4: "Sweater",
    5: "Hoodie",
    6: "Windbreaker",
    7: "Jacket",
    8: "Down Coat",
    9: "Suit",
    10: "Shawl",
    11: "Dress",
    12: "Vest",
    13: "Nightwear",
}

# The preprocessing pipeline never changes, so build it once at import time
# instead of on every prediction.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert a PIL image into a normalized single-image batch tensor.

    The image is converted to RGB, resized to 224x224, turned into a tensor,
    normalized with the mean/std configured in ``_TRANSFORM`` and given a
    leading batch dimension via ``unsqueeze(0)``.

    Args:
        image: A PIL image (anything exposing ``convert('RGB')``).

    Returns:
        The preprocessed tensor with a leading batch dimension of size 1.
    """
    return _TRANSFORM(image.convert('RGB')).unsqueeze(0)


# Load model
# NOTE: torch.load unpickles the checkpoint, so MODEL_PATH must point to a
# trusted file.
model = mae_vit_small_patch16(nb_cls=len(class_names))
_checkpoint = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
model.load_state_dict(_checkpoint['net'])
model.eval()

# Load dataset mapping: ImageFolder assigns an index to every class sub-folder
# of TEST_IMAGE_FOLDER; invert it to go from index back to folder name.
# NOTE(review): predictions are only mapped correctly if TEST_IMAGE_FOLDER has
# the same class sub-folders (named by numeric label) that were used when the
# model was trained/evaluated -- confirm against eval.py.
val_dataset = torchvision.datasets.ImageFolder(root=TEST_IMAGE_FOLDER)
idx_to_class = {v: k for k, v in val_dataset.class_to_idx.items()}


def predict_single_image(image):
    """Predict the clothing category of an image (same logic as eval.py).

    Args:
        image: A PIL image, or ``None`` when no image was provided.

    Returns:
        The predicted category name from ``class_names``, or a short prompt
        asking for an image when ``image`` is ``None``.

    Raises:
        KeyError: If the predicted index has no entry in ``idx_to_class`` or
            the mapped label is not in ``class_names``.
        ValueError: If the mapped folder name is not an integer string.
    """
    # Gradio passes None when the user submits without uploading an image;
    # answer with a message instead of failing inside preprocessing.
    if image is None:
        return "Please upload an image."

    batch = preprocess_image(image)

    with torch.no_grad():
        outputs = model.forward_cls(batch)
        _, predicted_class = torch.max(outputs, 1)

    # Model output index -> dataset folder name -> numeric label -> name.
    mapped_class_index = idx_to_class[predicted_class.item()]
    return class_names[int(mapped_class_index)]


def get_image_paths(test_image_folder, extensions=IMAGE_EXTENSIONS):
    """Collect the image files below a folder, including all sub-folders.

    Args:
        test_image_folder: Root directory to walk.
        extensions: File extensions to accept; compared case-insensitively so
            that e.g. ``photo.JPG`` is found as well.

    Returns:
        A sorted list of file paths (sorted so the order is the same on every
        run, since ``os.walk`` does not guarantee one).
    """
    wanted = tuple(ext.lower() for ext in extensions)
    image_paths = []
    for root, _dirs, files in os.walk(test_image_folder):
        image_paths.extend(
            os.path.join(root, name)
            for name in files
            if name.lower().endswith(wanted)
        )
    return sorted(image_paths)
# Example images offered in the UI, collected from the test image folder.
test_image_files = get_image_paths(TEST_IMAGE_FOLDER)


def load_test_image(selected_image):
    """Open the selected test image and return it as a PIL image.

    Args:
        selected_image: Path of the image file to open.

    Returns:
        The ``PIL.Image.Image`` returned by ``Image.open``.
    """
    return Image.open(selected_image)


# Create the Gradio interface with dynamic examples from the test image folder.
demo = gr.Interface(
    fn=predict_single_image,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=gr.Textbox(label="Predicted Category"),
    title="Clothes Category Classifier",
    description="Upload an image to classify its clothing category.",
    examples=test_image_files,  # paths of the example images
)

# Only start the (blocking) web server when run as a script, so importing this
# module does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()