tfarhan10 commited on
Commit
35bc32f
·
verified ·
1 Parent(s): 32eaa90

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +217 -0
app.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import torch
2
+ # import gradio as gr
3
+ # import torchvision.transforms as transforms
4
+ # from PIL import Image
5
+ # from huggingface_hub import hf_hub_download
6
+ # from C2D.models.resnet import SupCEResNet
7
+
8
+ # # Define class labels
9
+ # class_labels = [
10
+ # "T-shirt", "Shirt", "Knitwear", "Chiffon", "Sweater",
11
+ # "Hoodie", "Windbreaker", "Jacket", "Down Coat", "Suit",
12
+ # "Shawl", "Dress", "Vest", "Underwear"
13
+ # ]
14
+
15
+ # # Load model from Hugging Face Hub
16
+ # def load_model_from_huggingface(repo_id="tfarhan10/Clothing1M-Pretrained-ResNet50", filename="model.pth"):
17
+ # try:
18
+ # print("Downloading model from Hugging Face...")
19
+ # checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
20
+
21
+ # # Load checkpoint
22
+ # checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'),weights_only=False)
23
+
24
+ # # Extract state_dict if stored in a dictionary
25
+ # if isinstance(checkpoint, dict) and "model" in checkpoint:
26
+ # state_dict = checkpoint["model"]
27
+ # else:
28
+ # state_dict = checkpoint
29
+
30
+ # # Fix "module." prefix issue
31
+ # new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
32
+
33
+ # # Initialize model
34
+ # model = SupCEResNet(name='resnet50', num_classes=14, pool=True)
35
+
36
+ # # Load weights
37
+ # model.load_state_dict(new_state_dict, strict=False) # `strict=False` allows minor mismatches
38
+ # model.eval() # Set model to evaluation mode
39
+
40
+ # print("✅ Model loaded successfully from Hugging Face!")
41
+ # return model
42
+
43
+ # except Exception as e:
44
+ # print(f"❌ Error loading model: {e}")
45
+ # return None
46
+
47
+
48
+ # # Load the model
49
+ # model = load_model_from_huggingface()
50
+
51
+ # def classify_image(image):
52
+ # """Process and classify an uploaded PIL image accurately."""
53
+
54
+ # # Convert image to RGB to avoid grayscale or RGBA issues
55
+ # if image.mode != "RGB":
56
+ # image = image.convert("RGB")
57
+
58
+ # # Define the same preprocessing pipeline as training
59
+ # transform_test = transforms.Compose([
60
+ # transforms.Resize(256), # Resize the shorter side to 256
61
+ # transforms.CenterCrop(224), # Center crop to 224x224 (expected input size)
62
+ # transforms.ToTensor(), # Convert to Tensor
63
+ # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # Normalize
64
+ # ])
65
+
66
+ # # Apply transformations
67
+ # image_tensor = transform_test(image).unsqueeze(0) # Add batch dimension
68
+
69
+ # # Ensure tensor is on the same device as model
70
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+ # model.to(device)
72
+ # image_tensor = image_tensor.to(device)
73
+
74
+ # # Run inference
75
+ # with torch.no_grad():
76
+ # output = model(image_tensor)
77
+ # _, pred = torch.max(output, 1) # Get predicted class index
78
+
79
+ # # Map predicted class index to label
80
+ # predicted_label = class_labels[pred.item()]
81
+ # print(pred.item())
82
+ # return f"Predicted Category: {predicted_label}"
83
+
84
+
85
+ # # Create Gradio Interface
86
+ # example = "https://huggingface.co/tfarhan10/Clothing1M-Pretrained-ResNet50/blob/main/content/drive/MyDrive/CS5930/download.jpeg"
87
+ # interface = gr.Interface(
88
+ # fn=classify_image,
89
+ # inputs=gr.Image(type="pil"), # Accept image input
90
+ # outputs="text",
91
+ # title="Clothing Image Classifier",
92
+ # description="Upload an image and the model will classify it into one of 14 clothing categories.",
93
+ # allow_flagging="never", # Disable flagging feature
94
+ # examples = [[example]]
95
+ # )
96
+
97
+ # # Launch the app
98
+ # if __name__ == "__main__":
99
+ # interface.launch()
100
+
101
+
102
+ import torch
103
+ import gradio as gr
104
+ import torchvision.transforms as transforms
105
+ from PIL import Image
106
+ from huggingface_hub import hf_hub_download
107
+ import requests
108
+ from io import BytesIO
109
+ from C2D.models.resnet import SupCEResNet
110
+
111
+ # Define class labels
112
+ class_labels = [
113
+ "T-shirt", "Shirt", "Knitwear", "Chiffon", "Sweater",
114
+ "Hoodie", "Windbreaker", "Jacket", "Down Coat", "Suit",
115
+ "Shawl", "Dress", "Vest", "Underwear"
116
+ ]
117
+
118
+ # Load model from Hugging Face Hub
119
+ def load_model_from_huggingface(repo_id="tfarhan10/Clothing1M-Pretrained-ResNet50", filename="model.pth"):
120
+ try:
121
+ print("Downloading model from Hugging Face...")
122
+ checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
123
+
124
+ # Load checkpoint
125
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'),weights_only=False)
126
+
127
+ # Extract state_dict if stored in a dictionary
128
+ if isinstance(checkpoint, dict) and "model" in checkpoint:
129
+ state_dict = checkpoint["model"]
130
+ else:
131
+ state_dict = checkpoint
132
+
133
+ # Fix "module." prefix issue
134
+ new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
135
+
136
+ # Initialize model
137
+ model = SupCEResNet(name='resnet50', num_classes=14, pool=True)
138
+
139
+ # Load weights
140
+ model.load_state_dict(new_state_dict, strict=False) # `strict=False` allows minor mismatches
141
+ model.eval() # Set model to evaluation mode
142
+
143
+ print("✅ Model loaded successfully from Hugging Face!")
144
+ return model
145
+
146
+ except Exception as e:
147
+ print(f"❌ Error loading model: {e}")
148
+ return None
149
+
150
+ # Load the model
151
+ model = load_model_from_huggingface()
152
+
153
+ def classify_image(image):
154
+ """Process and classify an uploaded PIL image accurately."""
155
+
156
+ # Ensure image is in RGB format
157
+ if image.mode != "RGB":
158
+ image = image.convert("RGB")
159
+
160
+ # Define preprocessing transformations (same as training)
161
+ transform_test = transforms.Compose([
162
+ transforms.Resize(256), # Resize the shorter side to 256
163
+ transforms.CenterCrop(224), # Center crop to 224x224 (expected input size)
164
+ transforms.ToTensor(), # Convert to Tensor
165
+ transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # Normalize
166
+ ])
167
+
168
+ # Apply transformations
169
+ image_tensor = transform_test(image).unsqueeze(0) # Add batch dimension
170
+
171
+ # Ensure tensor is on the same device as model
172
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
173
+ model.to(device)
174
+ image_tensor = image_tensor.to(device)
175
+
176
+ # Run inference
177
+ with torch.no_grad():
178
+ output = model(image_tensor)
179
+ _, pred = torch.max(output, 1) # Get predicted class index
180
+
181
+ # Map predicted class index to label
182
+ predicted_label = class_labels[pred.item()]
183
+ return f"Predicted Category: {predicted_label}"
184
+
185
+ # Load example image from Hugging Face repository
186
+ example_url = "https://huggingface.co/tfarhan10/Clothing1M-Pretrained-ResNet50/resolve/main/content/drive/MyDrive/CS5930/download.jpeg"
187
+
188
+ def load_example_image():
189
+ """Download and return an example image from Hugging Face"""
190
+ try:
191
+ response = requests.get(example_url)
192
+ if response.status_code == 200:
193
+ return Image.open(BytesIO(response.content)).convert("RGB")
194
+ else:
195
+ print("⚠️ Failed to fetch example image.")
196
+ return None
197
+ except Exception as e:
198
+ print(f"⚠️ Error loading example image: {e}")
199
+ return None
200
+
201
+ # Example image
202
+ example_image = load_example_image()
203
+
204
+ # Create Gradio Interface
205
+ interface = gr.Interface(
206
+ fn=classify_image,
207
+ inputs=gr.Image(type="pil"), # Accept image input
208
+ outputs="text",
209
+ title="Clothing Image Classifier",
210
+ description="Upload an image or use the example below. The model will classify it into one of 14 clothing categories.",
211
+ allow_flagging="never", # Disable flagging feature
212
+ examples=[[example_image]] if example_image else None # Use example image if available
213
+ )
214
+
215
+ # Launch the app
216
+ if __name__ == "__main__":
217
+ interface.launch()