"""Gradio demo that classifies clothing images into the 14 Clothing1M categories."""

import gradio as gr
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image

from ResNet_for_CC import CC_model  # Ensure correct model import

# Set device (CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Clothing1M class labels; index i is the label for model output column i.
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie", "Windbreaker",
    "Jacket", "Downcoat", "Suit", "Shawl", "Dress", "Vest", "Underwear",
]

# Load the trained CC_model
model_path = "CC_net.pt"  # Update path if necessary
# Derive the class count from the label list so the two cannot drift apart (14).
model = CC_model(num_classes1=len(class_labels))
# NOTE(review): torch.load unpickles the checkpoint -- only load files you trust.
# Since this is a plain state_dict, weights_only=True can be passed on torch >= 1.13.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Preprocess images: ImageNet-style resize/crop and normalization (3-channel RGB).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classify_image(image: Image.Image) -> str:
    """Classify a PIL image and return a human-readable prediction string.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input. May be
            ``None`` when the user submits without uploading anything.

    Returns:
        The predicted class label and its softmax probability as a percentage,
        or a prompt to upload an image when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image to classify."

    # PNG uploads are often RGBA, palette or grayscale; the 3-channel Normalize
    # above fails on anything other than RGB, so convert first.
    image = image.convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        # Assumes the model returns a (1, num_classes) logits tensor rather than
        # a tuple -- TODO confirm against CC_model.forward.
        output = model(image_tensor)
        probabilities = F.softmax(output, dim=1).cpu().numpy()[0]

    predicted_idx = probabilities.argmax()
    predicted_label = class_labels[predicted_idx]
    confidence = probabilities[predicted_idx]

    # Top-1 softmax probability as a percentage with two decimals
    # (a point estimate, not a statistical confidence interval).
    confidence_pct = round(confidence * 100, 2)
    result = f"Predicted Class: {predicted_label}\nConfidence: {confidence_pct}%"
    return result


# Example images for the Gradio Interface (upload these images to your Hugging Face Space)
example_images = [
    "img1.png",
    "img2.png",
    "img3.png",
    "img4.png",
    "img5.png",
]

# Gradio Interface showing the predicted class and its softmax confidence
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Prediction and Confidence"),
    title="Clothing1M Image Classifier with Confidence Interval",
    description="Upload an image or select from examples to classify it and view the confidence percentage.",
    examples=example_images,
    cache_examples=False,
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()