# app.py — fashion-attribute prediction Gradio Space (nachi1326, commit bf2a6a5)
import torch
import torch.nn as nn
from PIL import Image
import open_clip
from pathlib import Path
import json
import torch
import gradio as gr
from PIL import Image
# Category → attribute mapping, shipped next to this script as JSON.
def load_category_mapping():
    """Read and return the category/attribute mapping from ``cat_attr_map.json``."""
    mapping_path = Path("cat_attr_map.json")
    with mapping_path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
CATEGORY_MAPPING = load_category_mapping()
class CategoryAwareAttributePredictor(nn.Module):
    """Predict per-attribute class logits from CLIP image features.

    One small MLP head is created for every ``(category, attribute)`` pair in
    ``category_attributes`` that also appears in ``attribute_dims`` (which
    supplies the number of output classes for that head).

    Parameters
    ----------
    clip_dim : int
        Dimensionality of the incoming CLIP feature vector.
    category_attributes : dict
        ``{category: {attribute_name: ...}}`` mapping defining which heads exist.
    attribute_dims : dict
        ``{f"{category}_{attribute}": num_classes}`` output sizes per head.
    hidden_dim : int
        Width of the first hidden layer of every head.
    dropout_rate : float
        Dropout probability used after each hidden layer.
    num_hidden_layers : int
        Total hidden layers per head; each extra layer halves the width.
    """

    def __init__(
        self,
        clip_dim=512,
        category_attributes=None,
        attribute_dims=None,
        hidden_dim=512,
        dropout_rate=0.2,
        num_hidden_layers=1,
    ):
        super(CategoryAwareAttributePredictor, self).__init__()
        self.category_attributes = category_attributes
        # One prediction head per category-attribute combination.
        self.attribute_predictors = nn.ModuleDict()
        for category, attributes in category_attributes.items():
            for attr_name in attributes.keys():
                key = f"{category}_{attr_name}"
                if key not in attribute_dims:
                    continue
                # BUG FIX: the original mutated ``hidden_dim`` inside this
                # loop (hidden_dim = hidden_dim // 2), so whenever
                # num_hidden_layers > 1 every head after the first was built
                # narrower than intended. Track the current width in a local
                # so each head is constructed identically.
                # NOTE(review): checkpoints trained with num_hidden_layers > 1
                # under the old code would no longer load — confirm shipped
                # checkpoints use num_hidden_layers == 1 (default), where the
                # architecture is unchanged.
                width = hidden_dim
                # Input layer.
                layers = [
                    nn.Linear(clip_dim, width),
                    nn.LayerNorm(width),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                ]
                # Additional hidden layers, each halving the width.
                for _ in range(num_hidden_layers - 1):
                    layers.append(nn.Linear(width, width // 2))
                    layers.append(nn.ReLU())
                    layers.append(nn.Dropout(dropout_rate))
                    width = width // 2
                # Output layer: per-attribute class logits.
                layers.append(nn.Linear(width, attribute_dims[key]))
                self.attribute_predictors[key] = nn.Sequential(*layers)

    def forward(self, clip_features, category):
        """Return ``{f"{category}_{attr}": logits}`` for one category.

        Only attributes of *category* that have a registered head are scored.
        """
        results = {}
        category_attrs = self.category_attributes[category]
        # Heads are fp32; make the incoming features match.
        clip_features = clip_features.float()
        for attr_name in category_attrs.keys():
            key = f"{category}_{attr_name}"
            if key in self.attribute_predictors:
                results[key] = self.attribute_predictors[key](clip_features)
        return results
class SingleImageInference:
    """Two-backbone CLIP ensemble for single-image attribute prediction.

    Loads a fine-tuned ViT-H-14-quickgelu backbone and a fine-tuned
    convnext_xxlarge backbone (each paired with its own
    CategoryAwareAttributePredictor head) from checkpoint files, then
    averages the two heads' logits at prediction time.
    """

    def __init__(self, model_path_gelu, model_path_convnext, device="cuda", cache_dir=None):
        # Target device for all four models after CPU-side loading.
        self.device = device
        # Load models
        # NOTE: despite the names, checkpoint_gelu / checkpoint_convnext hold
        # the checkpoints' "dataset_info" dicts (see load_models' return
        # value), not the full checkpoint objects.
        (
            self.model_gelu,
            self.clip_model_gelu,
            self.clip_preprocess_gelu,
            self.checkpoint_gelu,
            self.model_convnext,
            self.clip_model_convnext,
            self.clip_preprocess_convnext,
            self.checkpoint_convnext,
        ) = self.load_models(model_path_gelu, model_path_convnext, self.device, cache_dir)

    def clean_state_dict(self, state_dict):
        """Clean checkpoint state dict.

        Strips the "_orig_mod." prefix that torch.compile adds to parameter
        names so the weights load into an uncompiled module.
        """
        new_state_dict = {}
        for k, v in state_dict.items():
            name = k.replace("_orig_mod.", "")
            new_state_dict[name] = v
        return new_state_dict

    def create_clip_model_convnext(self, device, cache_dir=None):
        """Build the convnext_xxlarge open_clip model and its train transform."""
        model, preprocess_train, _ = open_clip.create_model_and_transforms(
            "convnext_xxlarge",
            device=device,
            pretrained="laion2b_s34b_b82k_augreg_soup",
            precision="fp32",
            cache_dir=cache_dir,
        )
        # Ensure every parameter is fp32.
        model = model.float()
        return model, preprocess_train

    def create_clip_model_gelu(self, device, cache_dir=None):
        """Build the ViT-H-14-quickgelu open_clip model and its train transform."""
        model, preprocess_train, _ = open_clip.create_model_and_transforms(
            "ViT-H-14-quickgelu",
            device=device,
            pretrained="dfn5b",
            precision="fp32",  # Explicitly set precision to fp32
            cache_dir=cache_dir,
        )
        # Ensure every parameter is fp32.
        model = model.float()
        return model, preprocess_train

    def load_models(self, model_path_gelu, model_path_convnext, device, cache_dir=None):
        """Load both CLIP backbones and both attribute heads from checkpoints.

        Each checkpoint is loaded on CPU first, the cleaned state dicts are
        applied, and only then is the model moved to *device*; the CPU-side
        weight copies are deleted between loads to limit peak memory.

        Returns an 8-tuple:
        (model_gelu, clip_model_gelu, clip_preprocess_gelu, dataset_info_gelu,
         model_convnext, clip_model_convnext, clip_preprocess_convnext,
         dataset_info_convnext)
        """
        # Load the CLIP model gelu
        # SECURITY NOTE(review): weights_only=False allows arbitrary code
        # execution during unpickling — only load trusted checkpoint files.
        checkpoint_gelu = torch.load(model_path_gelu, map_location="cpu",weights_only = False)
        clean_clip_checkpoint_gelu = self.clean_state_dict(
            checkpoint_gelu["clip_model_state_dict"]
        )
        clip_model_gelu, clip_preprocess_gelu = self.create_clip_model_gelu("cpu", cache_dir)
        clip_model_gelu.load_state_dict(clean_clip_checkpoint_gelu)
        clip_model_gelu = clip_model_gelu.to(device)
        # Drop the CPU copy of the weights before loading the next model.
        del clean_clip_checkpoint_gelu
        torch.cuda.empty_cache()
        # Load the CLIP model convnext
        checkpoint_convnext = torch.load(model_path_convnext, map_location="cpu",weights_only = False)
        clean_clip_checkpoint_convnext = self.clean_state_dict(
            checkpoint_convnext["clip_model_state_dict"]
        )
        clip_model_convnext, clip_preprocess_convnext = self.create_clip_model_convnext(
            "cpu", cache_dir
        )
        clip_model_convnext.load_state_dict(clean_clip_checkpoint_convnext)
        clip_model_convnext = clip_model_convnext.to(device)
        del clean_clip_checkpoint_convnext
        torch.cuda.empty_cache()
        # Load the attribute predictor models.
        # Architecture hyper-parameters come from each checkpoint's
        # "model_config"; head output sizes come from "dataset_info".
        model_gelu = CategoryAwareAttributePredictor(
            clip_dim=checkpoint_gelu["model_config"]["clip_dim"],
            category_attributes=checkpoint_gelu["dataset_info"]["category_mapping"],
            attribute_dims={
                key: len(values)
                for key, values in checkpoint_gelu["dataset_info"][
                    "attribute_classes"
                ].items()
            },
            hidden_dim=checkpoint_gelu["model_config"]["hidden_dim"],
            dropout_rate=checkpoint_gelu["model_config"]["dropout_rate"],
            num_hidden_layers=checkpoint_gelu["model_config"]["num_hidden_layers"],
        ).to(device)
        model_convnext = CategoryAwareAttributePredictor(
            clip_dim=checkpoint_convnext["model_config"]["clip_dim"],
            category_attributes=checkpoint_convnext["dataset_info"]["category_mapping"],
            attribute_dims={
                key: len(values)
                for key, values in checkpoint_convnext["dataset_info"][
                    "attribute_classes"
                ].items()
            },
            hidden_dim=checkpoint_convnext["model_config"]["hidden_dim"],
            dropout_rate=checkpoint_convnext["model_config"]["dropout_rate"],
            num_hidden_layers=checkpoint_convnext["model_config"]["num_hidden_layers"],
        ).to(device)
        clean_cat_checkpoint_gelu = self.clean_state_dict(checkpoint_gelu["model_state_dict"])
        model_gelu.load_state_dict(clean_cat_checkpoint_gelu)
        del clean_cat_checkpoint_gelu
        clean_cat_checkpoint_convnext = self.clean_state_dict(
            checkpoint_convnext["model_state_dict"]
        )
        model_convnext.load_state_dict(clean_cat_checkpoint_convnext)
        del clean_cat_checkpoint_convnext
        # Compile for faster inference where this torch version supports it.
        if hasattr(torch, "compile"):
            model_gelu = torch.compile(model_gelu)
            clip_model_gelu = torch.compile(clip_model_gelu)
            model_convnext = torch.compile(model_convnext)
            clip_model_convnext = torch.compile(clip_model_convnext)
        # Inference only: switch off dropout / train-time behavior.
        model_gelu.eval()
        clip_model_gelu.eval()
        model_convnext.eval()
        clip_model_convnext.eval()
        return (
            model_gelu,
            clip_model_gelu,
            clip_preprocess_gelu,
            checkpoint_gelu["dataset_info"],
            model_convnext,
            clip_model_convnext,
            clip_preprocess_convnext,
            checkpoint_convnext["dataset_info"],
        )

    def predict_single_image(self, image_path, category):
        """Perform inference on a single image.

        Returns a dict mapping attribute name -> predicted value string for
        the given *category*, using an equal-weight average of the two
        backbones' logits.

        Raises FileNotFoundError if *image_path* does not exist.
        """
        if not Path(image_path).exists():
            raise FileNotFoundError(f"Image {image_path} does not exist!")
        # Preprocess image
        image = Image.open(image_path).convert("RGB")
        image_gelu = self.clip_preprocess_gelu(image).unsqueeze(0).to(self.device)
        image_convnext = self.clip_preprocess_convnext(image).unsqueeze(0).to(self.device)
        # Extract CLIP features
        with torch.no_grad():
            clip_features_gelu = self.clip_model_gelu.encode_image(image_gelu).float()
            clip_features_convnext = self.clip_model_convnext.encode_image(image_convnext).float()
        # Predict attributes
        predictions_gelu = self.model_gelu(clip_features_gelu, category)
        predictions_convnext = self.model_convnext(clip_features_convnext, category)
        # Ensemble predictions
        ensemble_predictions = {}
        for key, pred_gelu in predictions_gelu.items():
            pred_convnext = predictions_convnext[key].to(self.device)
            # Simple equal-weight logit average of the two heads.
            ensemble_predictions[key] = 0.5 * pred_gelu + 0.5 * pred_convnext
        # Convert predictions to attributes
        predicted_attributes = {}
        for key, pred in ensemble_predictions.items():
            _, predicted_idx = torch.max(pred, 1)
            predicted_idx = predicted_idx.item()
            # Key format is "{category}_{attr_name}": split off the category.
            # NOTE(review): assumes the category name itself contains no "_" —
            # confirm against the dataset's category names.
            attr_name = key.split("_", 1)[1]
            # self.checkpoint_gelu is the checkpoint's dataset_info dict.
            attr_values = self.checkpoint_gelu["attribute_classes"][key]
            if predicted_idx < len(attr_values):
                predicted_attributes[attr_name] = attr_values[predicted_idx]
        return predicted_attributes
# Gradio callback: predict attributes for an uploaded image and category.
def predict_attributes(image, category):
    """Run the ensemble on *image* for *category* and format the result.

    Parameters
    ----------
    image : PIL.Image.Image
        Image uploaded through the Gradio UI.
    category : str
        One of the categories the models were trained on.

    Returns
    -------
    str
        A markdown table of attribute/value pairs, or a markdown error
        message when anything goes wrong.
    """
    import os
    import tempfile

    try:
        # BUG FIX: use a unique temp file instead of a fixed "temp_image.jpg"
        # so concurrent Gradio requests cannot overwrite each other's upload.
        fd, image_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        try:
            image.save(image_path)
            # Call the inference method (module-level ``inference`` global).
            predictions = inference.predict_single_image(image_path, category)
        finally:
            os.remove(image_path)
        # Format predictions as a markdown table
        markdown_output = "### Predicted Attributes\n\n| Attribute | Value |\n|-----------|-------|\n"
        for attr, value in predictions.items():
            markdown_output += f"| {attr} | {value} |\n"
        return markdown_output
    except Exception as e:
        # BUG FIX: the original returned {"error": str(e)} here, but the
        # Gradio output component is gr.Markdown, which expects a string.
        return f"### Error\n\n{e}"
# Build the Gradio UI for the attribute predictor.
def gradio_interface():
    """Assemble and return the Gradio Interface for attribute prediction."""
    categories = ['Men Tshirts', 'Women Tshirts', 'Sarees', 'Kurtis', 'Women Tops & Tunics']
    # Inputs: the image to classify and the category whose heads to use.
    # Output: a single markdown block rendered from the prediction table.
    return gr.Interface(
        fn=predict_attributes,
        inputs=[
            gr.Image(label="Upload an Image", type="pil"),
            gr.Dropdown(label="Choose Category", choices=categories),
        ],
        outputs=gr.Markdown(label="Predicted Attributes"),
        title="Attribute Prediction",
        description="Upload an image and specify its category to get the predicted attributes.",
        theme="default",
        flagging_mode="never",
    )
# Launch the Gradio app
if __name__ == "__main__":
    # Prefer GPU when available; every model is moved to this device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Checkpoint files are expected alongside this script.
    model_path_gelu = "vith14_gelu_highest_f1.pth"
    model_path_convnext = "Final_clip_convnext_xxlarge_laion3_4_train_032301.pth"
    # ``inference`` is intentionally a module-global: predict_attributes
    # (the Gradio callback) reads it.
    inference = SingleImageInference(
        model_path_gelu=model_path_gelu,
        model_path_convnext=model_path_convnext,
        device=device
    )
    gradio_interface().launch()