| | import gradio as gr |
| | import os |
| | from PIL import Image |
| | import torch |
| | import torchvision |
| | from torchvision import transforms, datasets |
| | import sys |
| |
|
| |
|
| | |
# Make the project's ``src`` directory (which holds ``model_jigsaw``)
# importable. The directory is resolved relative to this file rather than via
# a bare './src', which would depend on the current working directory the
# app happens to be launched from.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
| |
|
| | |
| | from model_jigsaw import mae_vit_small_patch16 |
| |
|
| |
|
| | |
# Resolve data locations relative to this file instead of the current working
# directory, so the checkpoint and test images are found no matter where the
# app is launched from. When launched from the project root these point at
# the same files as the previous relative paths.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(_APP_DIR, "model", "netBest.pth")
TEST_IMAGE_FOLDER = os.path.join(_APP_DIR, "test_images")
| |
|
| | |
# Human-readable clothing categories, keyed by integer label 0..13. The
# prediction code looks these up with int(<class-folder name>), so the
# position in this list must match the numeric folder label.
class_names = dict(enumerate([
    "T-shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Down Coat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Nightwear",
]))
| |
|
| | |
# Built once at import time instead of on every request. The pipeline is
# stateless, so reusing one instance is equivalent and avoids per-call rebuilds.
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Per-channel mean/std normalization (the commonly used ImageNet values);
    # presumably matches the training-time preprocessing — confirm.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a normalized single-image batch tensor.

    Args:
        image: A ``PIL.Image.Image`` in any mode; it is converted to RGB first.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the image resized to
        224x224, scaled by ``ToTensor`` and mean/std-normalized per channel.
    """
    rgb_image = image.convert('RGB')
    # unsqueeze(0) adds the leading batch dimension.
    return _PREPROCESS_TRANSFORM(rgb_image).unsqueeze(0)
| |
|
| | |
# Instantiate the classifier from model_jigsaw. nb_cls=14 presumably is the
# number of output classes, matching the 14 entries of class_names.
model = mae_vit_small_patch16(nb_cls=14)

# Restore the trained weights, which the checkpoint stores under its 'net' key;
# tensors are mapped onto the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
_checkpoint = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
model.load_state_dict(_checkpoint['net'])
del _checkpoint  # the weights now live in the model; drop the extra reference

# Switch layers such as dropout/batch-norm to inference behavior.
model.eval()
| |
|
| | |
# ImageFolder assigns an integer index to each class sub-folder of
# TEST_IMAGE_FOLDER. Invert that mapping so a predicted index can be turned
# back into its class-folder name.
val_dataset = torchvision.datasets.ImageFolder(root=TEST_IMAGE_FOLDER)
_folder_to_index = val_dataset.class_to_idx
idx_to_class = dict(zip(_folder_to_index.values(), _folder_to_index.keys()))
| |
|
def predict_single_image(image):
    """Predict the clothing category of one image (logic mirrors eval.py).

    Args:
        image: A ``PIL.Image.Image`` from the Gradio ``gr.Image(type="pil")``
            input, or ``None`` when the user submits without an image.

    Returns:
        The human-readable category name from ``class_names``. If ``image``
        is ``None``, returns a short message asking for an upload.
    """
    # Gradio hands over None for an empty image box. Without this guard,
    # preprocess_image would crash with AttributeError on None.convert().
    if image is None:
        return "Please upload an image."

    batch = preprocess_image(image)

    with torch.no_grad():
        logits = model.forward_cls(batch)

    # Highest-scoring class for the single image in the batch.
    predicted_index = int(logits.argmax(dim=1).item())

    # The output index is translated through ImageFolder's index -> folder-name
    # mapping, presumably because training/eval indexed classes the same way
    # (confirm against eval.py). The folder names are numeric labels that
    # key into class_names.
    folder_name = idx_to_class[predicted_index]
    return class_names[int(folder_name)]
| |
|
| | |
def get_image_paths(test_image_folder, extensions=(".jpg", ".jpeg", ".png")):
    """Collect paths of image files anywhere under ``test_image_folder``.

    Args:
        test_image_folder: Root directory to search recursively. A missing
            directory yields an empty list (``os.walk`` produces nothing).
        extensions: File suffixes to accept. Matching ignores case, so
            ``photo.JPG`` is found as well as ``photo.jpg``.

    Returns:
        A sorted list of matching file paths, each joined onto the directory
        it was found in.
    """
    wanted = tuple(ext.lower() for ext in extensions)
    image_paths = []
    for root, _dirs, files in os.walk(test_image_folder):
        for file_name in files:
            if file_name.lower().endswith(wanted):
                image_paths.append(os.path.join(root, file_name))
    # os.walk order depends on the filesystem; sort so the example list shown
    # in the UI is stable between runs and machines.
    image_paths.sort()
    return image_paths
| |
|
# Image files found under TEST_IMAGE_FOLDER; offered as clickable examples
# in the Gradio interface.
test_image_files = get_image_paths(test_image_folder=TEST_IMAGE_FOLDER)
| |
|
def load_test_image(selected_image):
    """Load an image file (e.g. an entry of ``test_image_files``) from disk.

    Args:
        selected_image: Path of the image file to open.

    Returns:
        The opened PIL image with its pixel data already read.

    Raises:
        Whatever ``PIL.Image.open`` / ``Image.load`` raise for a missing or
        undecodable file; the file handle is closed before re-raising.
    """
    image = Image.open(selected_image)
    try:
        # Image.open is lazy and keeps the file open until pixels are read.
        # Loading now lets Pillow release the handle (for single-frame files)
        # instead of leaking it for the lifetime of the returned object.
        image.load()
    except Exception:
        image.close()
        raise
    return image
| |
|
| | |
# Gradio UI: one PIL image in, the predicted category name out, with the
# collected test images offered as clickable examples.
_image_input = gr.Image(type="pil", label="Upload an image")
_category_output = gr.Textbox(label="Predicted Category")

demo = gr.Interface(
    fn=predict_single_image,
    inputs=_image_input,
    outputs=_category_output,
    examples=test_image_files,
    title="Clothes Category Classifier",
    description="Upload an image to classify its clothing category.",
)

demo.launch()
| |
|