Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from ResNet_for_CC import CC_model # Import updated model | |
# Run on the default CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Rebuild the network and restore its trained weights from the checkpoint.
# The path is relative, so it is resolved against the current working directory.
model_path = "CC_net.pt"
# NOTE(review): num_classes1 is presumably the size of the classification head;
# 14 matches the number of entries in class_labels.
model = CC_model(num_classes1=14)
# map_location moves the stored tensors onto `device` while loading.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# Put the module in evaluation mode (affects layers such as dropout or
# batch-norm, if CC_model contains any).
model.eval()
# The 14 Clothing1M category names. classify_image maps the model's argmax
# index i to class_labels[i]. The order therefore has to match the label order
# used when the model was trained (TODO confirm against the training setup).
class_labels = (
    "T-Shirt Shirt Knitwear Chiffon Sweater Hoodie "
    "Windbreaker Jacket Downcoat Suit Shawl Dress "
    "Vest Underwear"
).split()
# Preprocessing pipeline, built once at import time instead of on every call:
# - resize the shorter side to 256 px, then center-crop to 224x224;
# - convert to a float tensor in [0, 1];
# - normalize each channel with the ImageNet mean/std values.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a single-image batch ready for the model.

    The image is converted to 3-channel RGB first. Without that step, RGBA
    (e.g. PNGs with transparency), grayscale ("L") or palette ("P") images
    would reach ``Normalize`` with a channel count that does not match its
    3-element mean/std, and it would raise.

    Args:
        image: A ``PIL.Image.Image``.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)`` on the module-level
        ``device``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    # unsqueeze(0) adds the batch dimension the model expects.
    return _PREPROCESS(image).unsqueeze(0).to(device)
def classify_image(image):
    """Classify a clothing image and return a human-readable label.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio image input. It is
            ``None`` when the form is submitted without an image.

    Returns:
        ``"Predicted Class: <label>"``, where ``<label>`` is the entry of
        ``class_labels`` with the highest model score. If *image* is
        ``None``, a prompt asking for an image is returned instead.
    """
    # Without this guard, a submission with no image would fail inside
    # preprocess_image with an unhelpful error.
    if image is None:
        return "Please upload an image."

    batch = preprocess_image(image)
    with torch.no_grad():
        output = model(batch)

    # CC_model may return several tensors. As in the original code, the second
    # one is taken to be the classification logits.
    # NOTE(review): confirm this index against CC_model.forward.
    if isinstance(output, tuple):
        output = output[1]

    # The batch holds a single image, so argmax over the class axis yields
    # exactly one index.
    predicted_class = output.argmax(dim=1).item()
    return f"Predicted Class: {class_labels[predicted_class]}"
# Example inputs offered below the Gradio form. These are local file paths,
# relative to the working directory (not URLs); replace them with your own
# images as needed.
sample_images = [f"img{idx}.png" for idx in range(1, 6)]
# Web UI: a PIL image goes in and a text label comes out. The sample images
# are shown as clickable examples for quick testing.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    examples=sample_images,
    title="Clothing1M Image Classifier",
    description=(
        "Upload a clothing image or select a sample below. "
        "The model will classify it into one of the 14 categories."
    ),
)

if __name__ == "__main__":
    # Start the Gradio server. According to the Gradio launch() docs,
    # debug=True blocks the main thread and prints errors to the console.
    interface.launch(debug=True)