"""Gradio demo: classify a clothing image with a ResNet-50 checkpoint from the Hugging Face Hub."""

from io import BytesIO

import gradio as gr
import requests
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

from resnet import SupCEResNet

# Class labels; the model's predicted class index is used to index this list.
class_labels = [
    "T-shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie",
    "Windbreaker", "Jacket", "Down Coat", "Suit", "Shawl", "Dress",
    "Vest", "Underwear",
]

# Seconds to wait for the example-image download before giving up, so a
# stalled connection cannot hang application start-up.
_EXAMPLE_TIMEOUT_S = 10

# Device used for inference; chosen once instead of on every request.
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing pipeline, built once instead of on every request.
# NOTE(review): the original comment says this matches the training-time
# preprocessing — confirm against the training code.
_TRANSFORM_TEST = transforms.Compose([
    transforms.Resize(256),      # Resize the shorter side to 256
    transforms.CenterCrop(224),  # Center crop to 224x224
    transforms.ToTensor(),       # Convert to a float tensor
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])


def load_model_from_huggingface(repo_id="tfarhan10/Clothing1M-Pretrained-ResNet50", filename="model.pth"):
    """Download a checkpoint from the Hugging Face Hub and build the classifier.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside the repository.

    Returns:
        The ``SupCEResNet`` model in evaluation mode, or ``None`` if any step
        (download, deserialization, weight loading) raised an exception.
    """
    try:
        print("Downloading model from Hugging Face...")
        checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. This is acceptable only if
        # the Hub repository is trusted; prefer weights_only=True if the
        # checkpoint turns out to contain nothing but tensors.
        checkpoint = torch.load(
            checkpoint_path,
            map_location=torch.device("cpu"),
            weights_only=False,
        )

        # The weights may be stored directly or nested under a "model" key.
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        # Remove every "module." segment from the parameter names. str.replace
        # drops it wherever it occurs in the key, not only as a leading prefix
        # (presumably left by a DataParallel wrapper — confirm).
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

        model = SupCEResNet(name='resnet50', num_classes=14, pool=True)

        # strict=False tolerates key mismatches, so report them explicitly:
        # otherwise a checkpoint that matches nothing would "load" silently
        # and leave the network with its initial weights.
        load_result = model.load_state_dict(new_state_dict, strict=False)
        missing_keys = getattr(load_result, "missing_keys", None)
        unexpected_keys = getattr(load_result, "unexpected_keys", None)
        if missing_keys:
            print(f"Warning: {len(missing_keys)} model keys missing from checkpoint: {missing_keys}")
        if unexpected_keys:
            print(f"Warning: {len(unexpected_keys)} unexpected checkpoint keys ignored: {unexpected_keys}")

        model.eval()  # Set model to evaluation mode
        print("Model loaded successfully from Hugging Face!")
        return model

    except Exception as e:
        # Best-effort load: the app still starts and reports the failure
        # to the user at classification time.
        print(f"Error loading model: {e}")
        return None


# Load the model once at import time and move it to the inference device.
model = load_model_from_huggingface()
if model is not None:
    model.to(_DEVICE)


def classify_image(image):
    """Classify an uploaded PIL image into one of the clothing categories.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when no image was provided.

    Returns:
        A string of the form ``"Predicted Category: <label>"``, or a short
        explanatory message when there is no image or the model is unavailable.
    """
    # Guard against a submission without an image.
    if image is None:
        return "Please upload an image to classify."

    # The loader returns None on failure; fail gracefully instead of raising
    # AttributeError on every request.
    if model is None:
        return "Model is unavailable: it failed to load at startup. Check the server logs."

    # The normalization above expects three channels.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess, add the batch dimension, and move to the model's device.
    image_tensor = _TRANSFORM_TEST(image).unsqueeze(0).to(_DEVICE)

    # Run inference without tracking gradients.
    with torch.no_grad():
        output = model(image_tensor)
        _, pred = torch.max(output, 1)  # Index of the highest-scoring class

    predicted_label = class_labels[pred.item()]
    return f"Predicted Category: {predicted_label}"


# Example image hosted in the Hugging Face repository
example_url = "https://huggingface.co/tfarhan10/Clothing1M-Pretrained-ResNet50/resolve/main/content/drive/MyDrive/CS5930/download.jpeg"


def load_example_image():
    """Download the example image and return it as an RGB PIL image.

    Returns:
        The decoded image, or ``None`` if the request did not return HTTP 200
        or any exception (network error, timeout, undecodable data) occurred.
    """
    try:
        # The timeout keeps a stalled connection from blocking start-up.
        response = requests.get(example_url, timeout=_EXAMPLE_TIMEOUT_S)
        if response.status_code == 200:
            return Image.open(BytesIO(response.content)).convert("RGB")
        print("Failed to fetch example image.")
        return None
    except Exception as e:
        # Best-effort: the example is optional, so never let it crash the app.
        print(f"Error loading example image: {e}")
        return None


# Example image (None when the download failed)
example_image = load_example_image()

# Create Gradio Interface
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),  # Accept image input as a PIL image
    outputs="text",
    title="Clothing Image Classifier",
    description="Upload an image or use the example below. The model will classify it into one of 14 clothing categories.",
    allow_flagging="never",  # Disable flagging feature
    # Offer the example only when it was downloaded successfully.
    examples=[[example_image]] if example_image is not None else None,
)

# Launch the app
if __name__ == "__main__":
    interface.launch()