"""Gradio app that classifies uploaded images into the 100 CIFAR-100 classes.

A ResNet-18 checkpoint is downloaded from the project's GitHub release when the
module is imported; if that fails, the app falls back to randomly initialised
weights so the interface can still be exercised.
"""

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO
import os
import gc  # For garbage collection

# Location of the trained checkpoint (downloaded at startup and shown in the UI).
MODEL_URL = "https://github.com/santhoshv6/era_v4_s8_assignment/releases/download/v1.0/model_best.pth"

# CIFAR-100 class names; position i is the label for output logit i.
CIFAR100_CLASSES = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
    'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
    'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
    'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
    'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
    'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
    'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
]


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut connection.

    NOTE: the attribute names (conv1, bn1, conv2, bn2, shortcut) determine the
    state_dict keys, so they must match the downloaded checkpoint.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut by default; when the stride or channel count changes,
        # a strided 1x1 conv + BatchNorm projects the input to the output shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet18(nn.Module):
    """ResNet-18 variant with a 3x3 stride-1 stem (no max-pool) and four stages of two BasicBlocks."""

    def __init__(self, num_classes=100):
        super(ResNet18, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * BasicBlock.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build a stage of ``num_blocks`` blocks; only the first uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


# Initialize model (served by predict(); rebound when a checkpoint loads)
model = ResNet18(num_classes=100)
model_loaded = False
model_status = "Not loaded"


def _extract_state_dict(checkpoint):
    """Return the model weights contained in a ``torch.load`` result.

    Accepts a training checkpoint dict that nests the weights under
    ``'state_dict'`` or ``'model_state_dict'``, or a bare state dict whose
    values are all tensors.

    Raises:
        ValueError: if no weights can be located in ``checkpoint``.
    """
    if isinstance(checkpoint, dict):
        for key in ('state_dict', 'model_state_dict'):
            if key in checkpoint:
                return checkpoint[key]
        if checkpoint and all(isinstance(value, torch.Tensor) for value in checkpoint.values()):
            return checkpoint
    raise ValueError("No 'state_dict' found in checkpoint")


def load_model_with_fallbacks():
    """Load the trained weights from GitHub, falling back to random weights.

    Updates the module globals ``model``, ``model_loaded`` and ``model_status``.

    Returns:
        True if either the checkpoint or the random-weight demo model is ready,
        False if even the demo fallback failed.
    """
    global model, model_loaded, model_status

    # Method 1: download the checkpoint from the GitHub release
    try:
        print("🔄 Attempting to load model from GitHub releases...")

        # Add headers to avoid being blocked
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }

        # The whole file ends up in memory either way; streaming into our own
        # buffer just avoids requests holding a second copy in response.content.
        # Both context managers release the connection/buffer even on failure.
        with BytesIO() as model_data:
            with requests.get(MODEL_URL, headers=headers, timeout=60, stream=True) as response:
                response.raise_for_status()
                for chunk in response.iter_content(chunk_size=8192):
                    model_data.write(chunk)
            print(f"✅ Downloaded model: {model_data.tell()} bytes")
            model_data.seek(0)
            # NOTE(review): torch.load unpickles the file; only acceptable because
            # MODEL_URL is the project's own release. Consider weights_only=True
            # once the checkpoint is confirmed to contain only tensors/primitives.
            checkpoint = torch.load(model_data, map_location='cpu')

        state_dict = _extract_state_dict(checkpoint)
        accuracy = float(checkpoint.get('test_acc', 77.45))

        # Load into a fresh network first: load_state_dict copies the matching
        # tensors before raising on mismatched keys, so loading straight into
        # the served model could leave it half-overwritten.
        candidate = ResNet18(num_classes=100)
        candidate.load_state_dict(state_dict)
        candidate.eval()
        model = candidate

        # Clear checkpoint data to free memory
        del checkpoint, state_dict
        gc.collect()

        model_loaded = True
        model_status = f"✅ Loaded from GitHub (Accuracy: {accuracy:.2f}%)"
        print(f"✅ Model loaded successfully! Accuracy: {accuracy:.2f}%")
        return True

    except Exception as e:
        print(f"❌ GitHub method failed: {e}")
        load_error = str(e)[:100]
        model_status = f"❌ Failed to load: {load_error}..."

    # Method 2: serve the untouched, randomly initialised model (demo purposes)
    try:
        print("🔄 Initializing model with random weights for demo...")
        model.eval()  # Set to eval mode
        model_loaded = True
        # Keep "random weights" in the text: create_interface() checks for it.
        # The load error is appended so the UI debugging section shows the cause.
        model_status = (
            "⚠️ Demo mode (random weights) - Upload any image to test interface"
            f" | Load error: {load_error}..."
        )
        return True
    except Exception as e:
        print(f"❌ Demo initialization failed: {e}")
        model_status = f"❌ Complete failure: {str(e)}"
        return False


# Define image preprocessing: resize to the 32x32 CIFAR resolution and
# normalise with the per-channel CIFAR-100 mean/std constants.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])


def predict(image):
    """Predict the class of an input image using the trained ResNet-18 model.

    Args:
        image: a ``PIL.Image.Image`` (what ``gr.Image(type="pil")`` supplies),
            a ``numpy.ndarray`` accepted by ``Image.fromarray``, or ``None``
            when the user submits without uploading.

    Returns:
        A dict mapping the top-5 class names to their softmax probabilities.
        On failure, a plain status string instead: ``gr.Label`` renders a
        string as a bare label but cannot render non-numeric dict values.
    """
    if not model_loaded:
        return f"❌ Model Error: Model not loaded: {model_status}"
    if image is None:
        return "⚠️ No image provided - please upload an image"

    try:
        # Convert to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Convert to RGB if needed (e.g. RGBA or grayscale uploads)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess and add a batch dimension
        input_tensor = transform(image).unsqueeze(0)

        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = F.softmax(outputs, dim=1)
            top5_prob, top5_idx = torch.topk(probabilities, 5, dim=1)

        return {
            CIFAR100_CLASSES[class_idx]: confidence
            for confidence, class_idx in zip(top5_prob[0].tolist(), top5_idx[0].tolist())
        }

    except Exception as e:  # UI boundary: report the error instead of crashing the handler
        print(f"❌ Prediction failed: {e}")
        return f"❌ Prediction Error: Failed: {str(e)[:100]}..."


# Try to load model on startup
print("🚀 Starting model loading...")
load_success = load_model_with_fallbacks()


def create_interface():
    """Build the Gradio Interface, with title/description reflecting the load status."""
    title = "🏆 CIFAR-100 ResNet-18 Classifier"
    if "random weights" in model_status:
        title += " (Demo Mode)"
    elif not model_loaded:
        title = "❌ CIFAR-100 Model - Loading Error"

    description = f"""
**Model Status:** {model_status}

**Upload an image to classify it into one of 100 CIFAR-100 categories!**

🎯 **Model Performance:** 77.45% test accuracy achieved
🏗️ **Architecture:** ResNet-18 with 11.22M parameters
📊 **Training:** 100 epochs on Tesla P100, reached target at epoch 58

**Best performing classes:** wardrobe (97%), motorcycle (93%), bicycle (93%), aquarium_fish (92%)

*This model excels at furniture, vehicles, and distinctive objects. For best results, upload clear images similar to CIFAR-100 style.*
"""

    if not model_loaded:
        description = f"""
**❌ Model Loading Error:** {model_status}

**Possible solutions:**
1. Check if GitHub release exists: https://github.com/santhoshv6/era_v4_s8_assignment/releases
2. Verify model file is accessible
3. Try refreshing the space

**Expected model URL:** {MODEL_URL}
"""

    article = f"""
### 📚 About This Model

This ResNet-18 model was trained on CIFAR-100 dataset achieving **77.45% accuracy**, exceeding the 73% target by 4.45%.

**Key Features:**
- 🏗️ **Optimized Architecture:** ResNet-18 with BasicBlocks
- 🎨 **Advanced Augmentation:** Albumentations + Mixup + CutMix
- ⚡ **Fast Training:** OneCycle learning rate scheduler
- 🔍 **Interpretable:** GradCAM visualizations available

**CIFAR-100 Categories:** 100 fine-grained classes across 20 superclasses including animals, vehicles, household items, and natural objects.

📖 **Full Documentation:** [GitHub Repository](https://github.com/santhoshv6/era_v4_s8_assignment)

---

**🐛 Debugging Info:**
- Model Status: {model_status}
- Expected Model URL: {MODEL_URL}
"""

    return gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil", label="Upload an Image"),
        outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
        title=title,
        description=description,
        examples=[],
        article=article,
        theme=gr.themes.Soft(),
        allow_flagging="never",
    )


# Create and launch the interface
demo = create_interface()

if __name__ == "__main__":
    demo.launch()