nachi1326 commited on
Commit
bf2a6a5
·
verified ·
1 Parent(s): 9dafe60

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +300 -0
app.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from PIL import Image
4
+ import open_clip
5
+ from pathlib import Path
6
+ import json
7
+ import torch
8
+ import gradio as gr
9
+ from PIL import Image
10
+
11
# Load the category -> attribute mapping shipped alongside the app.
def load_category_mapping():
    """Read and parse cat_attr_map.json from the working directory."""
    mapping_file = Path("cat_attr_map.json")
    return json.loads(mapping_file.read_text(encoding="utf-8"))

CATEGORY_MAPPING = load_category_mapping()
17
+
18
class CategoryAwareAttributePredictor(nn.Module):
    """Per-(category, attribute) classification heads over CLIP features.

    For every ``f"{category}_{attribute}"`` key present in ``attribute_dims``,
    a small MLP head maps a CLIP image embedding of size ``clip_dim`` to
    ``attribute_dims[key]`` class logits.

    Args:
        clip_dim: Dimensionality of the input CLIP embedding.
        category_attributes: Mapping ``{category: {attr_name: ...}}``; only
            the keys are used to enumerate the heads.
        attribute_dims: Mapping ``{f"{category}_{attr}": num_classes}``.
        hidden_dim: Width of the first hidden layer of every head.
        dropout_rate: Dropout probability after each hidden activation.
        num_hidden_layers: Total hidden layers per head; each layer beyond
            the first halves the width.
    """

    def __init__(
        self,
        clip_dim=512,
        category_attributes=None,
        attribute_dims=None,
        hidden_dim=512,
        dropout_rate=0.2,
        num_hidden_layers=1,
    ):
        super(CategoryAwareAttributePredictor, self).__init__()

        self.category_attributes = category_attributes

        # One prediction head per category-attribute combination.
        self.attribute_predictors = nn.ModuleDict()

        for category, attributes in category_attributes.items():
            for attr_name in attributes.keys():
                key = f"{category}_{attr_name}"
                if key not in attribute_dims:
                    continue

                # BUG FIX: the original mutated `hidden_dim` in place while
                # stacking extra hidden layers, so every *subsequent* head was
                # built progressively narrower. Use a per-head local width.
                # (With num_hidden_layers == 1 the halving loop never runs,
                # so the architecture — and checkpoint compatibility — is
                # unchanged in that case.)
                width = hidden_dim

                # Input layer
                layers = [
                    nn.Linear(clip_dim, width),
                    nn.LayerNorm(width),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                ]

                # Additional hidden layers, each halving the width.
                for _ in range(num_hidden_layers - 1):
                    layers.append(nn.Linear(width, width // 2))
                    layers.append(nn.ReLU())
                    layers.append(nn.Dropout(dropout_rate))
                    width = width // 2

                # Output layer: one logit per attribute value.
                layers.append(nn.Linear(width, attribute_dims[key]))

                self.attribute_predictors[key] = nn.Sequential(*layers)

    def forward(self, clip_features, category):
        """Return ``{f"{category}_{attr}": logits}`` for the given category.

        Only attributes that actually have a registered head contribute a key.
        """
        results = {}
        category_attrs = self.category_attributes[category]

        # Heads are built in fp32; make the features match.
        clip_features = clip_features.float()

        for attr_name in category_attrs.keys():
            key = f"{category}_{attr_name}"
            if key in self.attribute_predictors:
                results[key] = self.attribute_predictors[key](clip_features)

        return results
72
+
73
+
74
class SingleImageInference:
    """Two-backbone ensemble inference wrapper.

    Loads a ViT-H-14-quickgelu and a convnext_xxlarge open_clip model (each
    with its own CategoryAwareAttributePredictor heads) from two checkpoint
    files, then averages the two heads' logits per attribute in
    predict_single_image().
    """

    def __init__(self, model_path_gelu, model_path_convnext, device="cuda", cache_dir=None):
        self.device = device

        # Load models.
        # NOTE: self.checkpoint_gelu / self.checkpoint_convnext hold the
        # checkpoints' "dataset_info" dicts (see load_models' return value),
        # not the full checkpoint objects.
        (
            self.model_gelu,
            self.clip_model_gelu,
            self.clip_preprocess_gelu,
            self.checkpoint_gelu,
            self.model_convnext,
            self.clip_model_convnext,
            self.clip_preprocess_convnext,
            self.checkpoint_convnext,
        ) = self.load_models(model_path_gelu, model_path_convnext, self.device, cache_dir)

    def clean_state_dict(self, state_dict):
        """Strip the torch.compile "_orig_mod." prefix from checkpoint keys."""
        new_state_dict = {}
        for k, v in state_dict.items():
            name = k.replace("_orig_mod.", "")
            new_state_dict[name] = v
        return new_state_dict

    def create_clip_model_convnext(self, device, cache_dir=None):
        """Build the pretrained convnext_xxlarge CLIP model and its train-time preprocess."""
        model, preprocess_train, _ = open_clip.create_model_and_transforms(
            "convnext_xxlarge",
            device=device,
            pretrained="laion2b_s34b_b82k_augreg_soup",
            precision="fp32",
            cache_dir=cache_dir,
        )
        model = model.float()
        return model, preprocess_train

    def create_clip_model_gelu(self, device, cache_dir=None):
        """Build the pretrained ViT-H-14-quickgelu CLIP model and its train-time preprocess."""
        model, preprocess_train, _ = open_clip.create_model_and_transforms(
            "ViT-H-14-quickgelu",
            device=device,
            pretrained="dfn5b",
            precision="fp32",  # Explicitly set precision to fp32
            cache_dir=cache_dir,
        )
        model = model.float()
        return model, preprocess_train

    def load_models(self, model_path_gelu, model_path_convnext, device, cache_dir=None):
        """Load both CLIP backbones and their attribute heads from checkpoints.

        Returns an 8-tuple: (model_gelu, clip_model_gelu, clip_preprocess_gelu,
        gelu dataset_info, model_convnext, clip_model_convnext,
        clip_preprocess_convnext, convnext dataset_info).
        """
        # Load the CLIP model gelu.
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only ever load trusted checkpoint files here.
        checkpoint_gelu = torch.load(model_path_gelu, map_location="cpu",weights_only = False)
        clean_clip_checkpoint_gelu = self.clean_state_dict(
            checkpoint_gelu["clip_model_state_dict"]
        )

        # Materialize on CPU first, load weights, then move to the target device.
        clip_model_gelu, clip_preprocess_gelu = self.create_clip_model_gelu("cpu", cache_dir)
        clip_model_gelu.load_state_dict(clean_clip_checkpoint_gelu)
        clip_model_gelu = clip_model_gelu.to(device)
        del clean_clip_checkpoint_gelu
        torch.cuda.empty_cache()

        # Load the CLIP model convnext (same CPU-first pattern as above).
        checkpoint_convnext = torch.load(model_path_convnext, map_location="cpu",weights_only = False)
        clean_clip_checkpoint_convnext = self.clean_state_dict(
            checkpoint_convnext["clip_model_state_dict"]
        )

        clip_model_convnext, clip_preprocess_convnext = self.create_clip_model_convnext(
            "cpu", cache_dir
        )
        clip_model_convnext.load_state_dict(clean_clip_checkpoint_convnext)
        clip_model_convnext = clip_model_convnext.to(device)
        del clean_clip_checkpoint_convnext
        torch.cuda.empty_cache()

        # Load the attribute predictor models; their architecture is rebuilt
        # from the hyperparameters stored in each checkpoint's "model_config"
        # and the class lists in "dataset_info"["attribute_classes"].
        model_gelu = CategoryAwareAttributePredictor(
            clip_dim=checkpoint_gelu["model_config"]["clip_dim"],
            category_attributes=checkpoint_gelu["dataset_info"]["category_mapping"],
            attribute_dims={
                key: len(values)
                for key, values in checkpoint_gelu["dataset_info"][
                    "attribute_classes"
                ].items()
            },
            hidden_dim=checkpoint_gelu["model_config"]["hidden_dim"],
            dropout_rate=checkpoint_gelu["model_config"]["dropout_rate"],
            num_hidden_layers=checkpoint_gelu["model_config"]["num_hidden_layers"],
        ).to(device)

        model_convnext = CategoryAwareAttributePredictor(
            clip_dim=checkpoint_convnext["model_config"]["clip_dim"],
            category_attributes=checkpoint_convnext["dataset_info"]["category_mapping"],
            attribute_dims={
                key: len(values)
                for key, values in checkpoint_convnext["dataset_info"][
                    "attribute_classes"
                ].items()
            },
            hidden_dim=checkpoint_convnext["model_config"]["hidden_dim"],
            dropout_rate=checkpoint_convnext["model_config"]["dropout_rate"],
            num_hidden_layers=checkpoint_convnext["model_config"]["num_hidden_layers"],
        ).to(device)

        clean_cat_checkpoint_gelu = self.clean_state_dict(checkpoint_gelu["model_state_dict"])
        model_gelu.load_state_dict(clean_cat_checkpoint_gelu)
        del clean_cat_checkpoint_gelu

        clean_cat_checkpoint_convnext = self.clean_state_dict(
            checkpoint_convnext["model_state_dict"]
        )
        model_convnext.load_state_dict(clean_cat_checkpoint_convnext)
        del clean_cat_checkpoint_convnext

        # Compile when available (PyTorch 2.x) for faster repeated inference.
        if hasattr(torch, "compile"):
            model_gelu = torch.compile(model_gelu)
            clip_model_gelu = torch.compile(clip_model_gelu)
            model_convnext = torch.compile(model_convnext)
            clip_model_convnext = torch.compile(clip_model_convnext)

        model_gelu.eval()
        clip_model_gelu.eval()
        model_convnext.eval()
        clip_model_convnext.eval()

        return (
            model_gelu,
            clip_model_gelu,
            clip_preprocess_gelu,
            checkpoint_gelu["dataset_info"],
            model_convnext,
            clip_model_convnext,
            clip_preprocess_convnext,
            checkpoint_convnext["dataset_info"],
        )

    def predict_single_image(self, image_path, category):
        """Perform inference on a single image.

        Args:
            image_path: Path to an image file on disk.
            category: Category key into the checkpoints' category mapping.

        Returns:
            Dict mapping attribute name to the predicted attribute value.

        Raises:
            FileNotFoundError: If image_path does not exist.
        """
        if not Path(image_path).exists():
            raise FileNotFoundError(f"Image {image_path} does not exist!")

        # Preprocess image once, with each backbone's own transform.
        image = Image.open(image_path).convert("RGB")
        image_gelu = self.clip_preprocess_gelu(image).unsqueeze(0).to(self.device)
        image_convnext = self.clip_preprocess_convnext(image).unsqueeze(0).to(self.device)

        # Extract CLIP features
        with torch.no_grad():
            clip_features_gelu = self.clip_model_gelu.encode_image(image_gelu).float()
            clip_features_convnext = self.clip_model_convnext.encode_image(image_convnext).float()

        # Predict attributes
        # NOTE(review): the heads run outside the no_grad block above —
        # consider widening its scope to avoid building autograd graphs.
        predictions_gelu = self.model_gelu(clip_features_gelu, category)
        predictions_convnext = self.model_convnext(clip_features_convnext, category)

        # Ensemble predictions: simple 50/50 logit average of the two heads.
        ensemble_predictions = {}
        for key, pred_gelu in predictions_gelu.items():
            pred_convnext = predictions_convnext[key].to(self.device)
            ensemble_predictions[key] = 0.5 * pred_gelu + 0.5 * pred_convnext

        # Convert predictions to attribute values via argmax over class logits.
        predicted_attributes = {}
        for key, pred in ensemble_predictions.items():
            _, predicted_idx = torch.max(pred, 1)
            predicted_idx = predicted_idx.item()

            # key is "{category}_{attr}"; split on the first "_" to recover
            # the attribute name (assumes category names contain no "_").
            attr_name = key.split("_", 1)[1]
            # self.checkpoint_gelu is the gelu checkpoint's dataset_info; its
            # class lists are presumed to match the convnext ones — TODO confirm.
            attr_values = self.checkpoint_gelu["attribute_classes"][key]
            if predicted_idx < len(attr_values):
                predicted_attributes[attr_name] = attr_values[predicted_idx]

        return predicted_attributes
246
+
247
# Function to make predictions using the provided image and category
def predict_attributes(image, category):
    """Run ensemble attribute prediction for a PIL image and category.

    Args:
        image: PIL image uploaded through the Gradio UI.
        category: Category name chosen in the dropdown.

    Returns:
        A Markdown string — either a table of predicted attributes, or a
        Markdown-formatted error message.
    """
    try:
        # Save the uploaded image temporarily for processing
        image_path = "temp_image.jpg"
        image.save(image_path)

        # Call the inference method (module-level `inference` set in __main__)
        predictions = inference.predict_single_image(image_path, category)
        # Format predictions as a markdown table
        markdown_output = "### Predicted Attributes\n\n| Attribute | Value |\n|-----------|-------|\n"
        for attr, value in predictions.items():
            markdown_output += f"| {attr} | {value} |\n"
        return markdown_output

    except Exception as e:
        # BUG FIX: the output component is gr.Markdown, so the error path must
        # return a Markdown string too — the original returned a dict, which
        # the Markdown component cannot render properly.
        return f"### Error\n\n{e}"
264
+
265
# Define Gradio interface
def gradio_interface():
    """Build and return the Gradio Interface for attribute prediction."""
    # Inputs: an uploaded image plus one of the supported garment categories.
    uploaded_image = gr.Image(label="Upload an Image", type="pil")
    supported_categories = ['Men Tshirts', 'Women Tshirts', 'Sarees', 'Kurtis', 'Women Tops & Tunics']
    category_dropdown = gr.Dropdown(label="Choose Category", choices=supported_categories)

    # Output: predictions rendered as a Markdown table.
    prediction_markdown = gr.Markdown(label="Predicted Attributes")

    # Assemble and return the interface; flagging is disabled.
    return gr.Interface(
        fn=predict_attributes,
        inputs=[uploaded_image, category_dropdown],
        outputs=prediction_markdown,
        title="Attribute Prediction",
        description="Upload an image and specify its category to get the predicted attributes.",
        theme="default",
        flagging_mode="never",
    )
287
+
288
# Launch the Gradio app
if __name__ == "__main__":
    # Run on GPU when one is present, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Checkpoint files for the two ensemble members; `inference` is the
    # module-level handle that predict_attributes() reads.
    inference = SingleImageInference(
        model_path_gelu="vith14_gelu_highest_f1.pth",
        model_path_convnext="Final_clip_convnext_xxlarge_laion3_4_train_032301.pth",
        device=device,
    )

    demo = gradio_interface()
    demo.launch()