"""Gradio app for single-image fashion-attribute prediction.

Two CLIP backbones (ViT-H-14-quickgelu and ConvNeXt-XXLarge) each feed a
category-aware MLP head; their logits are averaged (50/50 ensemble) and the
argmax per attribute is mapped back to a human-readable attribute value.
"""

import json
from pathlib import Path

import gradio as gr
import open_clip
import torch
import torch.nn as nn
from PIL import Image


def load_category_mapping():
    """Load the category -> attributes mapping from cat_attr_map.json (CWD-relative)."""
    with open("cat_attr_map.json", "r", encoding="utf-8") as f:
        return json.load(f)


CATEGORY_MAPPING = load_category_mapping()


class CategoryAwareAttributePredictor(nn.Module):
    """One small MLP classification head per (category, attribute) pair.

    Each head maps a CLIP image embedding of size ``clip_dim`` to logits over
    that attribute's classes. Heads are stored in an ``nn.ModuleDict`` keyed
    by ``"{category}_{attribute}"``.
    """

    def __init__(
        self,
        clip_dim=512,
        category_attributes=None,
        attribute_dims=None,
        hidden_dim=512,
        dropout_rate=0.2,
        num_hidden_layers=1,
    ):
        """
        Args:
            clip_dim: Dimensionality of the CLIP image embedding.
            category_attributes: Mapping {category: {attr_name: ...}} defining
                which heads to build.
            attribute_dims: Mapping {"{category}_{attr}": num_classes}.
            hidden_dim: Width of the first hidden layer of each head.
            dropout_rate: Dropout applied after each hidden layer.
            num_hidden_layers: Total hidden layers per head (1 = single layer).
        """
        super().__init__()
        self.category_attributes = category_attributes

        # Create prediction heads for each category-attribute combination.
        self.attribute_predictors = nn.ModuleDict()
        for category, attributes in category_attributes.items():
            for attr_name in attributes.keys():
                key = f"{category}_{attr_name}"
                if key not in attribute_dims:
                    continue
                layers = [
                    # Input layer
                    nn.Linear(clip_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                ]
                # Additional hidden layers, halving the width each time.
                # NOTE(review): `hidden_dim` is mutated across loop iterations,
                # so when num_hidden_layers > 1 every *subsequent* head is built
                # narrower than the previous one. This looks like a bug, but the
                # shipped checkpoints were presumably trained with this exact
                # topology, so "fixing" it would break load_state_dict — left
                # as-is deliberately. (Harmless when num_hidden_layers == 1.)
                for _ in range(num_hidden_layers - 1):
                    layers.append(nn.Linear(hidden_dim, hidden_dim // 2))
                    layers.append(nn.ReLU())
                    layers.append(nn.Dropout(dropout_rate))
                    hidden_dim = hidden_dim // 2
                # Output layer: logits over this attribute's classes.
                layers.append(nn.Linear(hidden_dim, attribute_dims[key]))
                self.attribute_predictors[key] = nn.Sequential(*layers)

    def forward(self, clip_features, category):
        """Run every head registered for `category`.

        Returns:
            dict mapping "{category}_{attr}" -> logits tensor.
        """
        results = {}
        category_attrs = self.category_attributes[category]
        # Heads are fp32; make sure the features match regardless of backbone precision.
        clip_features = clip_features.float()
        for attr_name in category_attrs.keys():
            key = f"{category}_{attr_name}"
            if key in self.attribute_predictors:
                results[key] = self.attribute_predictors[key](clip_features)
        return results


class SingleImageInference:
    """Loads both backbone+head pairs once and serves per-image predictions."""

    def __init__(self, model_path_gelu, model_path_convnext, device="cuda", cache_dir=None):
        self.device = device
        # NOTE: checkpoint_gelu / checkpoint_convnext hold the checkpoints'
        # "dataset_info" dicts (see load_models return), not the full checkpoints.
        (
            self.model_gelu,
            self.clip_model_gelu,
            self.clip_preprocess_gelu,
            self.checkpoint_gelu,
            self.model_convnext,
            self.clip_model_convnext,
            self.clip_preprocess_convnext,
            self.checkpoint_convnext,
        ) = self.load_models(model_path_gelu, model_path_convnext, self.device, cache_dir)

    def clean_state_dict(self, state_dict):
        """Strip the `_orig_mod.` prefix torch.compile adds to parameter names."""
        return {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    def create_clip_model_convnext(self, device, cache_dir=None):
        """Instantiate the ConvNeXt-XXLarge CLIP backbone (fp32) and its train transform."""
        model, preprocess_train, _ = open_clip.create_model_and_transforms(
            "convnext_xxlarge",
            device=device,
            pretrained="laion2b_s34b_b82k_augreg_soup",
            precision="fp32",
            cache_dir=cache_dir,
        )
        model = model.float()
        return model, preprocess_train

    def create_clip_model_gelu(self, device, cache_dir=None):
        """Instantiate the ViT-H-14-quickgelu CLIP backbone (fp32) and its train transform."""
        model, preprocess_train, _ = open_clip.create_model_and_transforms(
            "ViT-H-14-quickgelu",
            device=device,
            pretrained="dfn5b",
            precision="fp32",  # Explicitly set precision to fp32
            cache_dir=cache_dir,
        )
        model = model.float()
        return model, preprocess_train

    def load_models(self, model_path_gelu, model_path_convnext, device, cache_dir=None):
        """Load both checkpoints, rebuild backbones + attribute heads, compile, eval.

        Returns an 8-tuple:
            (gelu head, gelu CLIP, gelu preprocess, gelu dataset_info,
             convnext head, convnext CLIP, convnext preprocess, convnext dataset_info)
        """
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects — only
        # load checkpoints from a trusted source.
        checkpoint_gelu = torch.load(model_path_gelu, map_location="cpu", weights_only=False)
        clean_clip_checkpoint_gelu = self.clean_state_dict(
            checkpoint_gelu["clip_model_state_dict"]
        )
        # Build on CPU first, then move — keeps peak GPU memory low during load.
        clip_model_gelu, clip_preprocess_gelu = self.create_clip_model_gelu("cpu", cache_dir)
        clip_model_gelu.load_state_dict(clean_clip_checkpoint_gelu)
        clip_model_gelu = clip_model_gelu.to(device)
        del clean_clip_checkpoint_gelu
        torch.cuda.empty_cache()

        checkpoint_convnext = torch.load(
            model_path_convnext, map_location="cpu", weights_only=False
        )
        clean_clip_checkpoint_convnext = self.clean_state_dict(
            checkpoint_convnext["clip_model_state_dict"]
        )
        clip_model_convnext, clip_preprocess_convnext = self.create_clip_model_convnext(
            "cpu", cache_dir
        )
        clip_model_convnext.load_state_dict(clean_clip_checkpoint_convnext)
        clip_model_convnext = clip_model_convnext.to(device)
        del clean_clip_checkpoint_convnext
        torch.cuda.empty_cache()

        # Rebuild each attribute-head model from its checkpoint's stored config.
        def _build_predictor(ckpt):
            # Helper: construct a CategoryAwareAttributePredictor matching `ckpt`.
            return CategoryAwareAttributePredictor(
                clip_dim=ckpt["model_config"]["clip_dim"],
                category_attributes=ckpt["dataset_info"]["category_mapping"],
                attribute_dims={
                    key: len(values)
                    for key, values in ckpt["dataset_info"]["attribute_classes"].items()
                },
                hidden_dim=ckpt["model_config"]["hidden_dim"],
                dropout_rate=ckpt["model_config"]["dropout_rate"],
                num_hidden_layers=ckpt["model_config"]["num_hidden_layers"],
            ).to(device)

        model_gelu = _build_predictor(checkpoint_gelu)
        model_convnext = _build_predictor(checkpoint_convnext)

        clean_cat_checkpoint_gelu = self.clean_state_dict(checkpoint_gelu["model_state_dict"])
        model_gelu.load_state_dict(clean_cat_checkpoint_gelu)
        del clean_cat_checkpoint_gelu

        clean_cat_checkpoint_convnext = self.clean_state_dict(
            checkpoint_convnext["model_state_dict"]
        )
        model_convnext.load_state_dict(clean_cat_checkpoint_convnext)
        del clean_cat_checkpoint_convnext

        # torch.compile (PyTorch 2.x) for faster inference where available.
        if hasattr(torch, "compile"):
            model_gelu = torch.compile(model_gelu)
            clip_model_gelu = torch.compile(clip_model_gelu)
            model_convnext = torch.compile(model_convnext)
            clip_model_convnext = torch.compile(clip_model_convnext)

        model_gelu.eval()
        clip_model_gelu.eval()
        model_convnext.eval()
        clip_model_convnext.eval()

        return (
            model_gelu,
            clip_model_gelu,
            clip_preprocess_gelu,
            checkpoint_gelu["dataset_info"],
            model_convnext,
            clip_model_convnext,
            clip_preprocess_convnext,
            checkpoint_convnext["dataset_info"],
        )

    def predict_single_image(self, image_path, category):
        """Perform inference on a single image.

        Args:
            image_path: Path to the image file.
            category: Category name; must be a key of the checkpoints'
                category mapping.

        Returns:
            dict mapping attribute name -> predicted attribute value.

        Raises:
            FileNotFoundError: If `image_path` does not exist.
        """
        if not Path(image_path).exists():
            raise FileNotFoundError(f"Image {image_path} does not exist!")

        # Preprocess image once per backbone (each has its own transform).
        image = Image.open(image_path).convert("RGB")
        image_gelu = self.clip_preprocess_gelu(image).unsqueeze(0).to(self.device)
        image_convnext = self.clip_preprocess_convnext(image).unsqueeze(0).to(self.device)

        # Extract CLIP features and run the heads without building a graph.
        with torch.no_grad():
            clip_features_gelu = self.clip_model_gelu.encode_image(image_gelu).float()
            clip_features_convnext = self.clip_model_convnext.encode_image(
                image_convnext
            ).float()

            predictions_gelu = self.model_gelu(clip_features_gelu, category)
            predictions_convnext = self.model_convnext(clip_features_convnext, category)

        # Equal-weight logit ensemble of the two backbones.
        ensemble_predictions = {}
        for key, pred_gelu in predictions_gelu.items():
            pred_convnext = predictions_convnext[key].to(self.device)
            ensemble_predictions[key] = 0.5 * pred_gelu + 0.5 * pred_convnext

        # Map argmax indices back to attribute value strings.
        predicted_attributes = {}
        for key, pred in ensemble_predictions.items():
            _, predicted_idx = torch.max(pred, 1)
            predicted_idx = predicted_idx.item()
            attr_name = key.split("_", 1)[1]
            attr_values = self.checkpoint_gelu["attribute_classes"][key]
            # Guard against an index outside the class list (padded heads).
            if predicted_idx < len(attr_values):
                predicted_attributes[attr_name] = attr_values[predicted_idx]

        return predicted_attributes


def predict_attributes(image, category):
    """Gradio callback: predict attributes for an uploaded image.

    Always returns a markdown string (the output component is gr.Markdown);
    the original returned a dict on error, which the Markdown component
    cannot render.
    """
    try:
        # Persist the upload for the path-based inference API.
        # .convert("RGB") is required: RGBA/palette uploads cannot be
        # encoded as JPEG and would raise inside save().
        image_path = "temp_image.jpg"
        image.convert("RGB").save(image_path)

        predictions = inference.predict_single_image(image_path, category)

        # Format predictions as a markdown table.
        markdown_output = (
            "### Predicted Attributes\n\n| Attribute | Value |\n|-----------|-------|\n"
        )
        for attr, value in predictions.items():
            markdown_output += f"| {attr} | {value} |\n"
        return markdown_output
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"**Error:** {e}"


def gradio_interface():
    """Build the Gradio Interface (image + category in, markdown table out)."""
    image_input = gr.Image(label="Upload an Image", type="pil")
    category_input = gr.Dropdown(
        label="Choose Category",
        choices=[
            "Men Tshirts",
            "Women Tshirts",
            "Sarees",
            "Kurtis",
            "Women Tops & Tunics",
        ],
    )
    output = gr.Markdown(label="Predicted Attributes")

    interface = gr.Interface(
        fn=predict_attributes,
        inputs=[image_input, category_input],
        outputs=output,
        title="Attribute Prediction",
        description="Upload an image and specify its category to get the predicted attributes.",
        theme="default",
        flagging_mode="never",
    )
    return interface


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path_gelu = "vith14_gelu_highest_f1.pth"
    model_path_convnext = "Final_clip_convnext_xxlarge_laion3_4_train_032301.pth"
    # Module-level so the Gradio callback can reach it.
    inference = SingleImageInference(
        model_path_gelu=model_path_gelu,
        model_path_convnext=model_path_convnext,
        device=device,
    )
    gradio_interface().launch()