Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from ResNet_for_CC import CC_model # Import updated model | |
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint holding the trained CC_model weights (path is relative to the
# working directory the app is started from).
model_path = "CC_net.pt"

# NOTE: torch.load unpickles the checkpoint, so only point model_path at a
# file from a trusted source.
_weights = torch.load(model_path, map_location=device)

# num_classes1=14 matches the 14 entries of class_labels used for decoding.
model = CC_model(num_classes1=14)
model.load_state_dict(_weights)

# Move the network to the selected device and switch it to inference mode.
model.to(device).eval()
# Clothing1M class names; the position of each name is the class index that
# classify_image looks up after taking the argmax of the model output.
class_labels = [
    "T-Shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Downcoat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]
# Per-channel normalization statistics.
# NOTE(review): these look like the commonly used ImageNet mean/std values;
# confirm they match what the model was trained with.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image, in order: resize to 256,
# center-crop to 224, convert to a tensor, then normalize each channel.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
# Function for image classification
def classify_image(image):
    """Classify a clothing image and return the prediction as display text.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when the form is submitted without an image.

    Returns:
        ``"Predicted Class: <label>"`` for the highest-scoring entry of
        ``class_labels``, or a short prompt asking for an upload when
        ``image`` is ``None``.
    """
    # Gradio hands the callback None when the image field is empty; answer
    # with a readable message instead of failing inside the transform.
    if image is None:
        return "Please upload an image to classify."

    # Normalize() is configured with three channel statistics, so force RGB in
    # case the image is grayscale, palette-based, or carries an alpha channel.
    rgb_image = image.convert("RGB")

    # Preprocess, add the batch dimension, and move to the model's device.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # The model returns a pair; only the second element (output_mean) is
        # used for the prediction.
        _, output = model(batch)

    predicted_class = torch.argmax(output, dim=1).item()  # class index
    return f"Predicted Class: {class_labels[predicted_class]}"
# Gradio UI: a single image input (delivered to the callback as a PIL image)
# wired to classify_image, with the prediction shown as plain text.
_image_input = gr.Image(type="pil")

interface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs="text",
    title="Clothing1M Image Classifier",
    description="Upload a clothing image, and the model will classify it into one of the 14 categories.",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()