Saahil-doryu commited on
Commit
b702dea
·
verified ·
1 Parent(s): cb229ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -20
app.py CHANGED
@@ -1,35 +1,67 @@
1
- import gradio as gr
 
2
  from PIL import Image
3
- from inference import load_model, run_inference
4
 
 
5
 
6
- net_feat, net_cls = load_model()
 
 
 
 
7
 
8
- def classify_image(image):
9
- """Classify the image using the loaded model and return the predicted class."""
10
- try:
11
- prediction = run_inference(image, net_feat, net_cls)
12
- return prediction
13
- except Exception as e:
14
- return f"An error occurred: {str(e)}"
 
 
 
 
15
 
 
 
 
16
 
17
- example_images = [
18
- "sample1.jpg", "sample2.jpg", "sample3.jpg",
19
- "sample4.jpg", "sample5.jpg", "sample6.jpg"
20
- ]
 
 
 
 
 
21
 
 
 
 
 
 
 
 
 
22
 
 
 
 
 
 
 
 
 
23
  interface = gr.Interface(
24
  fn=classify_image,
25
- inputs=gr.inputs.Image(type="pil", label="Upload an Image"),
26
- outputs=gr.outputs.Textbox(label="Predicted Clothing1M Class"),
27
  title="Clothing1M Classifier",
28
- description="Upload an image of clothing, and the classifier will identify its category among 14 different types, such as T-shirts, Dresses, Jackets, etc.",
29
- examples=example_images,
30
- theme="default"
31
  )
32
 
33
-
34
  if __name__ == "__main__":
35
  interface.launch()
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
  from PIL import Image
4
+ import gradio as gr
5
 
6
+ from model import NetFeat, NetClassifier # Ensure this imports correctly based on your setup
7
 
8
+ # Define class labels for clothing classification
9
+ CLOTHING_CLASSES = [
10
+ "T-shirt", "Shirt", "Shawl", "Dress", "Vest", "Underwear", "Cardigan", "Jacket",
11
+ "Sweater", "Hoodie", "Knitwear", "Chiffon", "Downcoat", "Suit"
12
+ ]
13
 
14
+ # Load the model
15
+ def load_model():
16
+ model_filename = 'netBest.pth' # Adjust the path as necessary
17
+ net_feat = NetFeat(arch='resnet18', pretrained=False, dataset='Clothing1M')
18
+ net_cls = NetClassifier(feat_dim=512, nb_cls=14)
19
+
20
+ state_dict = torch.load(model_filename, map_location=torch.device('cpu'))
21
+ if "feat" in state_dict:
22
+ net_feat.load_state_dict(state_dict['feat'], strict=False)
23
+ if "cls" in state_dict:
24
+ net_cls.load_state_dict(state_dict['cls'], strict=False)
25
 
26
+ net_feat.eval()
27
+ net_cls.eval()
28
+ return net_feat, net_cls
29
 
30
+ # Preprocess image for model input
31
+ def preprocess_image(image):
32
+ transform = transforms.Compose([
33
+ transforms.Resize((224, 224)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
36
+ ])
37
+ image = Image.open(image).convert("RGB")
38
+ return transform(image).unsqueeze(0) # Add batch dimension
39
 
40
+ # Run inference on the image
41
+ def run_inference(image, net_feat, net_cls):
42
+ image_tensor = preprocess_image(image)
43
+ with torch.no_grad():
44
+ feature_vector = net_feat(image_tensor)
45
+ output = net_cls(feature_vector)
46
+ predicted_index = output.argmax(dim=1).item()
47
+ return CLOTHING_CLASSES[predicted_index]
48
 
49
+ # Load models
50
+ net_feat, net_cls = load_model()
51
+
52
+ # Define the Gradio interface function
53
+ def classify_image(image):
54
+ return run_inference(image, net_feat, net_cls)
55
+
56
+ # Set up Gradio interface
57
  interface = gr.Interface(
58
  fn=classify_image,
59
+ inputs=gr.Image(shape=(224, 224), label="Upload an Image"),
60
+ outputs=gr.Textbox(label="Predicted Clothing1M Class"),
61
  title="Clothing1M Classifier",
62
+ description="Upload an image of clothing to classify it into one of 14 categories."
 
 
63
  )
64
 
65
+ # Launch the application
66
  if __name__ == "__main__":
67
  interface.launch()