PrachiY commited on
Commit
0f2b2af
·
verified ·
1 Parent(s): 27eeb7b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.models as models
3
+ import gradio as gr
4
+ from huggingface_hub import hf_hub_download
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+
8
+ # ✅ Download model checkpoint from Hugging Face Hub
9
+ model_path = hf_hub_download(
10
+ repo_id="PrachiY/image-classification-model",
11
+ filename="clothing1m.pth.tar"
12
+ )
13
+
14
+ # ✅ Load the Model
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ model = models.resnet50(pretrained=False)
17
+
18
+ checkpoint = torch.load(model_path, map_location=device)
19
+
20
+ if "model" in checkpoint:
21
+ model.load_state_dict(checkpoint["model"], strict=False)
22
+ elif "state_dict" in checkpoint:
23
+ model.load_state_dict(checkpoint["state_dict"], strict=False)
24
+ else:
25
+ model.load_state_dict(checkpoint, strict=False)
26
+
27
+ model.fc = torch.nn.Linear(2048, 21)
28
+ model.to(device)
29
+ model.eval()
30
+
31
+ # ✅ Define Clothing1M Class Labels
32
+ class_labels = [
33
+ "T-shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie", "Windbreaker",
34
+ "Jacket", "Downcoat", "Suits", "Shawl", "Dress", "Vest", "Underwear",
35
+ "Hat", "Sock", "Jeans", "Sweatpants", "Trousers", "Shorts", "Skirt"
36
+ ]
37
+
38
+ # ✅ Image Preprocessing
39
+ def preprocess_image(image):
40
+ transform = transforms.Compose([
41
+ transforms.Resize((224, 224)),
42
+ transforms.ToTensor(),
43
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
44
+ ])
45
+ return transform(image).unsqueeze(0).to(device)
46
+
47
+ # ✅ Prediction Function
48
+ def predict(image):
49
+ image_tensor = preprocess_image(image)
50
+ with torch.no_grad():
51
+ output = model(image_tensor)
52
+ predicted_class_idx = output.argmax(dim=1).item()
53
+
54
+ if predicted_class_idx >= len(class_labels):
55
+ return f"Predicted Class: Unknown (Index {predicted_class_idx} out of range)"
56
+
57
+ return f"Predicted Class: {class_labels[predicted_class_idx]}"
58
+
59
+ # ✅ Gradio Interface
60
+ interface = gr.Interface(
61
+ fn=predict,
62
+ inputs=gr.Image(type="pil"),
63
+ outputs="text",
64
+ title="Clothing1M Image Classifier",
65
+ description="Upload an image to classify it into one of 21 clothing categories."
66
+ )
67
+
68
+ if __name__ == "__main__":
69
+ interface.launch()