JohnJoelMota commited on
Commit
e245366
·
verified ·
1 Parent(s): c1ed2c2

Created App.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ from ResNet_for_CC import CC_model
6
+
7
+ # Initialize the model
8
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+ model = CC_model()
10
+
11
+ # Load the pre-trained weights, adjusting for DataParallel if necessary
12
+ model_path = 'CC_net.pt'
13
+ checkpoint = torch.load(model_path, map_location=device)
14
+ if any(key.startswith('module.') for key in checkpoint.keys()):
15
+ checkpoint = {k.replace('module.', ''): v for k, v in checkpoint.items()}
16
+ model.load_state_dict(checkpoint)
17
+ model.eval()
18
+ model.to(device)
19
+
20
+ # Image preprocessing
21
+ preprocess = transforms.Compose([
22
+ transforms.Resize(256),
23
+ transforms.CenterCrop(224),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26
+ ])
27
+
28
+ # Define class names from category_names_eng.txt
29
+ class_names = [
30
+ 'T-Shirt', 'Shirt', 'Knitwear', 'Chiffon', 'Sweater', 'Hoodie',
31
+ 'Windbreaker', 'Jacket', 'Downcoat', 'Suit', 'Shawl', 'Dress',
32
+ 'Vest', 'Underwear'
33
+ ]
34
+
35
+ def predict(image):
36
+ # Convert Gradio Image to PIL and preprocess
37
+ img = Image.fromarray(image.astype('uint8'), 'RGB')
38
+ img = preprocess(img).unsqueeze(0).to(device)
39
+
40
+ # Generate predictions
41
+ with torch.no_grad():
42
+ dr_feature, output_mean = model(img)
43
+
44
+ # Get the predicted class
45
+ _, predicted = torch.max(output_mean, 1)
46
+ predicted_class = class_names[predicted.item()]
47
+
48
+ # Format output
49
+ return f"Predicted class: {predicted_class} (class number: {predicted.item()})"
50
+
51
+ # Example images from Hugging Face
52
+ examples = [
53
+ ["example_image(1).JPG"],
54
+ ["example_image(2).jpg"],
55
+ ["example_image(3).jpg"],
56
+ ["example_image(4).webp"],
57
+ ["example_image(5).webp"],
58
+ ["example_image(6).webp"]
59
+ ]
60
+
61
+ # Gradio Interface
62
+ interface = gr.Interface(
63
+ fn=predict,
64
+ inputs=gr.Image(label="Upload Clothing Image"),
65
+ outputs=gr.Textbox(label="Prediction"),
66
+ title="Clothing1M Class Predictor",
67
+ description="This model predicts the class of clothing images using CC_net.",
68
+ examples=examples
69
+ )
70
+
71
+ interface.launch()