Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| import open_clip | |
| from pathlib import Path | |
| import json | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
def load_category_mapping(path="cat_attr_map.json"):
    """Load the category -> attribute mapping from a UTF-8 encoded JSON file.

    Args:
        path: Location of the JSON file. Defaults to ``"cat_attr_map.json"``,
            resolved relative to the current working directory (the file the
            app has always read).

    Returns:
        The decoded JSON document, exactly as produced by ``json.load``.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
# Category -> attribute mapping, read from cat_attr_map.json as soon as this
# module is imported (so import fails if that file is missing or invalid).
# NOTE(review): not referenced anywhere else in the visible code; the category
# choices in gradio_interface() are hard-coded separately — confirm whether
# this global is still needed.
CATEGORY_MAPPING = load_category_mapping()
class CategoryAwareAttributePredictor(nn.Module):
    """Per-(category, attribute) MLP classifiers over precomputed CLIP features.

    A separate head is registered under the key ``"{category}_{attribute}"``
    for every attribute listed in ``category_attributes`` whose key is also
    present in ``attribute_dims``; that head outputs ``attribute_dims[key]``
    values.

    Args:
        clip_dim: Width of the incoming CLIP feature vectors.
        category_attributes: Mapping ``category -> mapping keyed by attribute
            name`` (only the attribute-name keys are used).
        attribute_dims: Mapping ``"{category}_{attribute}" -> number of classes``.
        hidden_dim: Width of a head's first hidden layer.
        dropout_rate: Dropout probability after every hidden layer.
        num_hidden_layers: Hidden layers per head; each layer after the first
            halves the width.
    """

    def __init__(
        self,
        clip_dim=512,
        category_attributes=None,
        attribute_dims=None,
        hidden_dim=512,
        dropout_rate=0.2,
        num_hidden_layers=1,
    ):
        super().__init__()
        self.category_attributes = category_attributes
        self.attribute_predictors = nn.ModuleDict()

        for cat_name, attr_map in category_attributes.items():
            for attr in attr_map.keys():
                head_key = f"{cat_name}_{attr}"
                if head_key not in attribute_dims:
                    continue  # no class list for this pair -> no head

                # First hidden layer: project the CLIP features to hidden_dim.
                modules = [
                    nn.Linear(clip_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                ]
                # Each additional hidden layer halves the width.
                # NOTE(review): `hidden_dim` is reassigned here and never reset,
                # so with num_hidden_layers > 1 every later head starts from the
                # width the previous head ended at (heads shrink in dict
                # iteration order). Kept unchanged because the tensor shapes
                # that load_state_dict expects depend on it — confirm against
                # the training code before fixing.
                for _ in range(num_hidden_layers - 1):
                    narrower = hidden_dim // 2
                    modules += [
                        nn.Linear(hidden_dim, narrower),
                        nn.ReLU(),
                        nn.Dropout(dropout_rate),
                    ]
                    hidden_dim = narrower

                modules.append(nn.Linear(hidden_dim, attribute_dims[head_key]))
                self.attribute_predictors[head_key] = nn.Sequential(*modules)

    def forward(self, clip_features, category):
        """Run every head registered for ``category``.

        Args:
            clip_features: Feature tensor whose last dimension equals
                ``clip_dim``; it is cast to float32 before use.
            category: Key into ``category_attributes``.

        Returns:
            Dict mapping ``"{category}_{attribute}"`` to that head's output, for
            each of the category's attributes that has a head.

        Raises:
            KeyError: if ``category`` is not in ``category_attributes``.
        """
        attr_map = self.category_attributes[category]
        feats = clip_features.float()
        candidate_keys = [f"{category}_{attr}" for attr in attr_map.keys()]
        return {
            key: self.attribute_predictors[key](feats)
            for key in candidate_keys
            if key in self.attribute_predictors
        }
class SingleImageInference:
    """Ensemble of two CLIP-backbone attribute predictors for single images.

    Two pipelines are loaded from training checkpoints:

    * "gelu": open_clip ``ViT-H-14-quickgelu`` (``dfn5b`` weights) followed by
      a ``CategoryAwareAttributePredictor``;
    * "convnext": open_clip ``convnext_xxlarge``
      (``laion2b_s34b_b82k_augreg_soup`` weights) followed by a
      ``CategoryAwareAttributePredictor``.

    The per-attribute outputs of both predictors are averaged with equal weight.

    Attributes (names kept for backward compatibility):
        device: Device every model and input tensor is moved to.
        model_gelu, model_convnext: Attribute predictors (wrapped by
            ``torch.compile`` when it is available).
        clip_model_gelu, clip_model_convnext: CLIP image backbones.
        clip_preprocess_gelu, clip_preprocess_convnext: Image transforms for
            the respective backbone.
        checkpoint_gelu, checkpoint_convnext: The ``"dataset_info"`` entry of
            each checkpoint (not the whole checkpoint); it holds at least
            ``"category_mapping"`` and ``"attribute_classes"``.
    """

    def __init__(self, model_path_gelu, model_path_convnext, device="cuda", cache_dir=None):
        """Load both pipelines onto ``device``.

        Args:
            model_path_gelu: Checkpoint file for the ViT-H-14 pipeline.
            model_path_convnext: Checkpoint file for the ConvNeXt pipeline.
            device: Target device (string or ``torch.device``).
            cache_dir: Optional open_clip cache directory for pretrained weights.
        """
        self.device = device
        (
            self.model_gelu,
            self.clip_model_gelu,
            self.clip_preprocess_gelu,
            self.checkpoint_gelu,
            self.model_convnext,
            self.clip_model_convnext,
            self.clip_preprocess_convnext,
            self.checkpoint_convnext,
        ) = self.load_models(model_path_gelu, model_path_convnext, self.device, cache_dir)

    def clean_state_dict(self, state_dict):
        """Return a copy of ``state_dict`` with every ``"_orig_mod."`` removed from its keys.

        The prefix presumably comes from saving ``torch.compile``-wrapped
        modules (which keep the real module as ``_orig_mod``); stripping it
        lets the weights load into plain, uncompiled modules.
        """
        return {key.replace("_orig_mod.", ""): value for key, value in state_dict.items()}

    def _create_clip_model(self, model_name, pretrained, device, cache_dir=None, train_transform=False):
        """Create an fp32 open_clip model and select one of its image transforms."""
        model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
            model_name,
            device=device,
            pretrained=pretrained,
            precision="fp32",
            cache_dir=cache_dir,
        )
        # create_model_and_transforms returns (model, train transform, eval
        # transform). The training transform is randomly augmented (e.g. a
        # random resized crop), which made repeated predictions on the same
        # image differ; inference therefore uses the deterministic eval one.
        preprocess = preprocess_train if train_transform else preprocess_val
        return model.float(), preprocess

    def create_clip_model_convnext(self, device, cache_dir=None, train_transform=False):
        """Create the ``convnext_xxlarge`` CLIP model and its image transform.

        Args:
            device: Device to create the model on.
            cache_dir: Optional download cache for the pretrained weights.
            train_transform: If True, return open_clip's (augmenting) training
                transform instead of the evaluation transform.

        Returns:
            ``(model, preprocess)``.
        """
        return self._create_clip_model(
            "convnext_xxlarge",
            "laion2b_s34b_b82k_augreg_soup",
            device,
            cache_dir,
            train_transform,
        )

    def create_clip_model_gelu(self, device, cache_dir=None, train_transform=False):
        """Create the ``ViT-H-14-quickgelu`` CLIP model and its image transform.

        Args:
            device: Device to create the model on.
            cache_dir: Optional download cache for the pretrained weights.
            train_transform: If True, return open_clip's (augmenting) training
                transform instead of the evaluation transform.

        Returns:
            ``(model, preprocess)``.
        """
        return self._create_clip_model(
            "ViT-H-14-quickgelu",
            "dfn5b",
            device,
            cache_dir,
            train_transform,
        )

    def _load_branch(self, model_path, create_clip_model, device, cache_dir=None):
        """Build one (attribute predictor, CLIP backbone) pipeline from a checkpoint.

        Returns:
            ``(predictor, clip_model, clip_preprocess, dataset_info)``.
        """
        # NOTE: weights_only=False unpickles arbitrary Python objects; only
        # load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

        # Build the CLIP backbone on CPU, load its weights, then move it.
        clip_model, clip_preprocess = create_clip_model("cpu", cache_dir)
        clip_model.load_state_dict(self.clean_state_dict(checkpoint["clip_model_state_dict"]))
        clip_model = clip_model.to(device)
        torch.cuda.empty_cache()

        config = checkpoint["model_config"]
        dataset_info = checkpoint["dataset_info"]
        predictor = CategoryAwareAttributePredictor(
            clip_dim=config["clip_dim"],
            category_attributes=dataset_info["category_mapping"],
            attribute_dims={
                key: len(values) for key, values in dataset_info["attribute_classes"].items()
            },
            hidden_dim=config["hidden_dim"],
            dropout_rate=config["dropout_rate"],
            num_hidden_layers=config["num_hidden_layers"],
        ).to(device)
        predictor.load_state_dict(self.clean_state_dict(checkpoint["model_state_dict"]))
        # Only dataset_info is returned, so the rest of the checkpoint (including
        # the raw state dicts) can be freed once this function returns.
        return predictor, clip_model, clip_preprocess, dataset_info

    def load_models(self, model_path_gelu, model_path_convnext, device, cache_dir=None):
        """Load both pipelines and return them as one flat tuple.

        Each checkpoint must contain ``"clip_model_state_dict"``,
        ``"model_state_dict"``, ``"model_config"`` and ``"dataset_info"``.

        Returns:
            ``(model_gelu, clip_model_gelu, clip_preprocess_gelu,
            dataset_info_gelu, model_convnext, clip_model_convnext,
            clip_preprocess_convnext, dataset_info_convnext)``, with every
            model on ``device`` and in eval mode.
        """
        # Load one branch at a time so the first checkpoint is released before
        # the second one is read, lowering peak host memory.
        model_gelu, clip_model_gelu, clip_preprocess_gelu, info_gelu = self._load_branch(
            model_path_gelu, self.create_clip_model_gelu, device, cache_dir
        )
        model_convnext, clip_model_convnext, clip_preprocess_convnext, info_convnext = (
            self._load_branch(model_path_convnext, self.create_clip_model_convnext, device, cache_dir)
        )

        if hasattr(torch, "compile"):
            # NOTE(review): torch.compile compiles forward(); encode_image() is
            # presumably reached through attribute forwarding on the wrapper
            # and runs uncompiled — confirm whether compiling the CLIP models
            # actually helps.
            model_gelu = torch.compile(model_gelu)
            clip_model_gelu = torch.compile(clip_model_gelu)
            model_convnext = torch.compile(model_convnext)
            clip_model_convnext = torch.compile(clip_model_convnext)

        for module in (model_gelu, clip_model_gelu, model_convnext, clip_model_convnext):
            module.eval()

        return (
            model_gelu,
            clip_model_gelu,
            clip_preprocess_gelu,
            info_gelu,
            model_convnext,
            clip_model_convnext,
            clip_preprocess_convnext,
            info_convnext,
        )

    def predict_single_image(self, image_path, category):
        """Predict attribute values for one image of a known category.

        Args:
            image_path: Path to the image file.
            category: Category name; must be known to both predictors.

        Returns:
            Dict mapping attribute name to predicted class value. Attributes
            whose predicted index falls outside the class list are omitted.

        Raises:
            FileNotFoundError: if ``image_path`` does not exist.
            KeyError: if ``category`` (or one of its heads) is unknown to
                either model.
        """
        if not Path(image_path).exists():
            raise FileNotFoundError(f"Image {image_path} does not exist!")

        # Close the file handle as soon as the pixels are decoded; convert()
        # returns an independent in-memory image.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")
        image_gelu = self.clip_preprocess_gelu(image).unsqueeze(0).to(self.device)
        image_convnext = self.clip_preprocess_convnext(image).unsqueeze(0).to(self.device)

        # Pure inference: keep autograd off for the attribute heads as well as
        # for the CLIP encoders, so no computation graph is recorded.
        with torch.no_grad():
            clip_features_gelu = self.clip_model_gelu.encode_image(image_gelu).float()
            clip_features_convnext = self.clip_model_convnext.encode_image(image_convnext).float()
            predictions_gelu = self.model_gelu(clip_features_gelu, category)
            predictions_convnext = self.model_convnext(clip_features_convnext, category)

        attribute_classes = self.checkpoint_gelu["attribute_classes"]
        prefix = f"{category}_"
        predicted_attributes = {}
        for key, pred_gelu in predictions_gelu.items():
            # Equal-weight ensemble of the two backbones' outputs.
            ensemble = 0.5 * pred_gelu + 0.5 * predictions_convnext[key]
            _, predicted_idx = torch.max(ensemble, 1)
            predicted_idx = predicted_idx.item()  # a batch of exactly one image
            # Keys are "{category}_{attribute}": strip the known category rather
            # than splitting on the first "_", which would break for category
            # names containing an underscore.
            if key.startswith(prefix):
                attr_name = key[len(prefix):]
            else:
                attr_name = key.split("_", 1)[1]
            attr_values = attribute_classes[key]
            if predicted_idx < len(attr_values):
                predicted_attributes[attr_name] = attr_values[predicted_idx]
        return predicted_attributes
def predict_attributes(image, category):
    """Gradio callback: predict attributes for an uploaded image.

    Uses the module-level ``inference`` object (created in the ``__main__``
    block) to run the prediction.

    Args:
        image: The uploaded image as a PIL image (the input component uses
            ``type="pil"``), or ``None`` when nothing was uploaded.
        category: Category name selected in the dropdown.

    Returns:
        A Markdown string: a two-column table of predicted attributes on
        success, or an error message otherwise. The output component is
        ``gr.Markdown``, so errors are reported as Markdown text as well.
    """
    import contextlib
    import os
    import tempfile

    if image is None:
        return "**Error:** please upload an image."
    if not category:
        return "**Error:** please choose a category."

    image_path = None
    try:
        # A unique temporary file per request: a fixed file name would be shared
        # (and overwritten) by concurrent requests and never cleaned up.
        fd, image_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        # The predictor converts to RGB after loading anyway; doing it here and
        # saving as lossless PNG keeps the pixels unchanged and works for any
        # input mode (JPEG cannot store e.g. RGBA).
        image.convert("RGB").save(image_path)
        predictions = inference.predict_single_image(image_path, category)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return f"**Error:** {e}"
    finally:
        if image_path is not None:
            with contextlib.suppress(OSError):
                os.remove(image_path)

    rows = "".join(f"| {attr} | {value} |\n" for attr, value in predictions.items())
    return "### Predicted Attributes\n\n| Attribute | Value |\n|-----------|-------|\n" + rows
def gradio_interface():
    """Build (without launching) the Gradio UI for attribute prediction.

    The UI takes an image (delivered as a PIL image) plus one of five fixed
    category names, passes both to ``predict_attributes``, and renders the
    returned value as Markdown.
    """
    categories = ['Men Tshirts', 'Women Tshirts', 'Sarees', 'Kurtis', 'Women Tops & Tunics']
    inputs = [
        gr.Image(label="Upload an Image", type="pil"),
        gr.Dropdown(label="Choose Category", choices=categories),
    ]
    return gr.Interface(
        fn=predict_attributes,
        inputs=inputs,
        outputs=gr.Markdown(label="Predicted Attributes"),
        title="Attribute Prediction",
        description="Upload an image and specify its category to get the predicted attributes.",
        theme="default",
        flagging_mode="never",
    )
if __name__ == "__main__":
    # Entry point: choose a device, load both checkpoints once, serve the UI.
    # `inference` must remain a module-level global: predict_attributes() reads it.
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    model_path_gelu = "vith14_gelu_highest_f1.pth"
    model_path_convnext = "Final_clip_convnext_xxlarge_laion3_4_train_032301.pth"
    inference = SingleImageInference(model_path_gelu, model_path_convnext, device=device)
    demo = gradio_interface()
    demo.launch()