PrachiY commited on
Commit
dbc197c
Β·
verified Β·
1 Parent(s): ca464a2

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +25 -44
app.py CHANGED
@@ -1,71 +1,52 @@
1
  import torch
2
- import torchvision.models as models
3
  import gradio as gr
4
- from huggingface_hub import hf_hub_download
5
  from PIL import Image
6
- from torchvision import transforms
7
 
8
- # βœ… Download model checkpoint from Hugging Face Hub
9
- model_path = hf_hub_download(
10
- repo_id="PrachiY/image-classification-model",
11
- filename="clothing1m.pth.tar"
12
- )
 
13
 
14
- # βœ… Load the Model
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- print(f"βœ… Using device: {device}")
17
-
18
- model = models.resnet50(pretrained=False)
19
- checkpoint = torch.load(model_path, map_location=device)
20
-
21
- if "model" in checkpoint:
22
- model.load_state_dict(checkpoint["model"], strict=False)
23
- elif "state_dict" in checkpoint:
24
- model.load_state_dict(checkpoint["state_dict"], strict=False)
25
- else:
26
- model.load_state_dict(checkpoint, strict=False)
27
 
28
- # Ensure correct output layer (21 classes for Clothing1M)
29
- model.fc = torch.nn.Linear(2048, 21)
 
 
 
30
  model.to(device)
31
  model.eval()
32
 
33
- # βœ… Define Clothing1M Class Labels
34
- class_labels = [
35
- "T-shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie", "Windbreaker",
36
- "Jacket", "Downcoat", "Suits", "Shawl", "Dress", "Vest", "Underwear",
37
- "Hat", "Sock", "Jeans", "Sweatpants", "Trousers", "Shorts", "Skirt"
38
- ]
39
-
40
- # βœ… Image Preprocessing
41
  def preprocess_image(image):
42
  transform = transforms.Compose([
43
  transforms.Resize((224, 224)),
44
  transforms.ToTensor(),
45
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
46
  ])
47
- image = transform(image).unsqueeze(0).to(device) # Ensure tensor is on GPU if available
48
- return image
49
 
50
- # βœ… Prediction Function
51
- def predict(image):
52
  image_tensor = preprocess_image(image)
53
  with torch.no_grad():
54
  output = model(image_tensor)
55
  predicted_class_idx = output.argmax(dim=1).item()
 
 
56
 
57
- if predicted_class_idx >= len(class_labels):
58
- return f"Predicted Class: Unknown (Index {predicted_class_idx} out of range)"
59
-
60
- return f"Predicted Class: {class_labels[predicted_class_idx]}"
61
-
62
- # βœ… Gradio Interface
63
  interface = gr.Interface(
64
- fn=predict,
65
  inputs=gr.Image(type="pil"),
66
  outputs="text",
67
- title="Clothing1M Image Classifier",
68
- description="Upload an image to classify it into one of 21 clothing categories."
69
  )
70
 
71
  if __name__ == "__main__":
 
1
  import torch
2
+ import torchvision.transforms as transforms
3
  import gradio as gr
4
+ from torchvision import models
5
  from PIL import Image
 
6
 
7
+ # Define Clothing1M class labels
8
+ clothing1m_classes = [
9
+ "T-shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie", "Windbreaker",
10
+ "Jacket", "Down Coat", "Suits", "Shawl", "Dress", "Vest", "Underwear", "Shorts",
11
+ "Trousers", "Jeans", "Leather Shoes", "Casual Shoes", "Sport Shoes", "Sandals"
12
+ ]
13
 
14
+ # βœ… Set device
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # βœ… Load model
18
+ model = models.resnet50(weights=None) # Ensure correct architecture
19
+ num_ftrs = model.fc.in_features
20
+ model.fc = torch.nn.Linear(num_ftrs, 21) # Match Clothing1M class count
21
+ model.load_state_dict(torch.load("model.pth", map_location=device)) # Load weights
22
  model.to(device)
23
  model.eval()
24
 
25
+ # βœ… Define image preprocessing
 
 
 
 
 
 
 
26
  def preprocess_image(image):
27
  transform = transforms.Compose([
28
  transforms.Resize((224, 224)),
29
  transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
31
  ])
32
+ return transform(image).unsqueeze(0).to(device)
 
33
 
34
+ # βœ… Define inference function
35
+ def classify_image(image):
36
  image_tensor = preprocess_image(image)
37
  with torch.no_grad():
38
  output = model(image_tensor)
39
  predicted_class_idx = output.argmax(dim=1).item()
40
+ predicted_class_name = clothing1m_classes[predicted_class_idx] if predicted_class_idx < len(clothing1m_classes) else "Unknown"
41
+ return f"Predicted Class: {predicted_class_name}"
42
 
43
+ # βœ… Create Gradio Interface
 
 
 
 
 
44
  interface = gr.Interface(
45
+ fn=classify_image,
46
  inputs=gr.Image(type="pil"),
47
  outputs="text",
48
+ title="Clothing1M Classifier",
49
+ description="Upload an image of clothing and get the predicted category."
50
  )
51
 
52
  if __name__ == "__main__":