File size: 4,101 Bytes
0141596
 
 
35bc32f
0141596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
946f89e
0141596
35bc32f
0141596
 
35bc32f
0141596
 
35bc32f
0141596
 
 
35bc32f
5b6f9be
0141596
35bc32f
0141596
5b6f9be
0141596
35bc32f
0141596
 
5d098b6
0141596
 
5d098b6
0141596
 
 
5d098b6
0141596
 
 
 
 
 
 
2d9baba
0141596
 
2d9baba
0141596
 
 
 
2d9baba
0141596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b6f9be
0141596
 
5b6f9be
0141596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120



import torch
import gradio as gr
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import requests
from io import BytesIO
from resnet import SupCEResNet

# Human-readable clothing categories. classify_image maps the model's
# predicted class index straight into this list, so position i is the name
# of class i (presumably the label order used in training — verify).
class_labels = [
    "T-shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Down Coat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]

# Load model from Hugging Face Hub
def load_model_from_huggingface(repo_id="tfarhan10/Clothing1M-Pretrained-ResNet50", filename="model.pth"):
    """Download a SupCEResNet-50 checkpoint from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.

    Returns:
        The model in evaluation mode with weights on the CPU, or ``None`` if
        downloading or loading failed (the error is printed, not raised).
    """
    try:
        print("Downloading model from Hugging Face...")
        checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

        # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
        # objects, which can execute code from the file. Only acceptable because
        # repo_id is a trusted repository; prefer weights_only=True if the
        # checkpoint turns out to contain only tensors and plain containers.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)

        # Training scripts often wrap the weights in a dict. Accept the "model"
        # key (checked first, as before) and the common "state_dict" key;
        # otherwise treat the checkpoint itself as the state dict.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ("model", "state_dict"):
                if wrapper_key in checkpoint:
                    state_dict = checkpoint[wrapper_key]
                    break

        # Weights saved from a wrapped model (e.g. nn.DataParallel) carry a
        # leading "module." on every key. Strip only that leading prefix so
        # any "module." appearing inside a parameter name is left intact.
        prefix = "module."
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

        # num_classes must equal len(class_labels) (14 categories).
        model = SupCEResNet(name='resnet50', num_classes=14, pool=True)

        # strict=False tolerates key mismatches, but any parameter that is not
        # matched keeps its random initialisation, so report what didn't line up.
        incompatible = model.load_state_dict(new_state_dict, strict=False)
        if incompatible.missing_keys:
            print(f"Warning: keys missing from checkpoint: {incompatible.missing_keys}")
        if incompatible.unexpected_keys:
            print(f"Warning: unexpected keys in checkpoint: {incompatible.unexpected_keys}")
        model.eval()  # Set model to evaluation mode

        print("Model loaded successfully from Hugging Face!")
        return model

    except Exception as e:
        # Deliberate best-effort: the app still starts, and callers must
        # handle a None model.
        print(f"Error loading model: {e}")
        return None

# Load the model once at import time so every request reuses the same instance.
# load_model_from_huggingface() returns None if download/loading failed (the
# error is printed), and classify_image relies on this global being a model.
model = load_model_from_huggingface()

# Inference preprocessing: the same pipeline the original built per call,
# now built once at import instead of on every request.
_TRANSFORM_TEST = transforms.Compose([
    transforms.Resize(256),  # Resize the shorter side to 256
    transforms.CenterCrop(224),  # Center crop to 224x224
    transforms.ToTensor(),  # PIL image -> float tensor, channels first
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet mean/std
])


def classify_image(image):
    """Classify an uploaded PIL image into one of the clothing categories.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A string of the form ``"Predicted Category: <label>"``.

    Raises:
        gr.Error: If no image was provided or the model failed to load, so the
            user sees a readable message in the UI instead of an AttributeError.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")
    if model is None:
        raise gr.Error("The model failed to load; see the server logs for details.")

    # Normalize uses 3-channel statistics, so grayscale/RGBA/palette images
    # must be converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    image_tensor = _TRANSFORM_TEST(image).unsqueeze(0)  # Add batch dimension

    # Keep the model and input on the same device; the model's weights are
    # moved on the first call and are already in place afterwards.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        output = model(image_tensor)
        _, pred = torch.max(output, 1)  # Index of the highest score per row

    # Map predicted class index to label
    return f"Predicted Category: {class_labels[pred.item()]}"

# URL of an example image hosted in the model's Hugging Face repository;
# downloaded once at start-up by load_example_image().
example_url = "https://huggingface.co/tfarhan10/Clothing1M-Pretrained-ResNet50/resolve/main/content/drive/MyDrive/CS5930/download.jpeg"

def load_example_image(timeout=10):
    """Download the example image at ``example_url`` and return it as RGB.

    Args:
        timeout: Seconds to wait on the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely, and because this
            runs at import time it would hang app start-up.

    Returns:
        A PIL image in RGB mode, or ``None`` if the request failed, returned a
        non-200 status, or the payload could not be decoded (the reason is
        printed).
    """
    try:
        response = requests.get(example_url, timeout=timeout)
        if response.status_code != 200:
            print("Failed to fetch example image.")
            return None
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        # Deliberate best-effort at start-up: a missing example must never
        # prevent the app from launching.
        print(f"Error loading example image: {e}")
        return None

# Example image offered in the UI; None if the download failed, in which case
# the interface below is built without examples.
example_image = load_example_image()

# Create Gradio Interface
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),  # Deliver uploads to classify_image as PIL images
    outputs="text",
    title="Clothing Image Classifier",
    description="Upload an image or use the example below. The model will classify it into one of 14 clothing categories.",
    # NOTE(review): newer Gradio releases rename this to flagging_mode="never";
    # confirm against the Gradio version this app pins.
    allow_flagging="never",  # Disable flagging feature
    # load_example_image() signals failure with None, so test for exactly that
    # rather than relying on the image object's truthiness.
    examples=[[example_image]] if example_image is not None else None,
)

# Launch the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    interface.launch()