Spaces:
Sleeping
Sleeping
santhoshv6
Optimize model loading to fix storage limit issues - use streaming and memory cleanup
6f86b6f
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import numpy as np | |
| import requests | |
| from io import BytesIO | |
| import os | |
| import gc # For garbage collection | |
# CIFAR-100 class names.
# The list position is the class index: predict() looks labels up by the
# index returned from the model, so the order must not change.
CIFAR100_CLASSES = [
    "apple", "aquarium_fish", "baby", "bear", "beaver",
    "bed", "bee", "beetle", "bicycle", "bottle",
    "bowl", "boy", "bridge", "bus", "butterfly",
    "camel", "can", "castle", "caterpillar", "cattle",
    "chair", "chimpanzee", "clock", "cloud", "cockroach",
    "couch", "crab", "crocodile", "cup", "dinosaur",
    "dolphin", "elephant", "flatfish", "forest", "fox",
    "girl", "hamster", "house", "kangaroo", "keyboard",
    "lamp", "lawn_mower", "leopard", "lion", "lizard",
    "lobster", "man", "maple_tree", "motorcycle", "mountain",
    "mouse", "mushroom", "oak_tree", "orange", "orchid",
    "otter", "palm_tree", "pear", "pickup_truck", "pine_tree",
    "plain", "plate", "poppy", "porcupine", "possum",
    "rabbit", "raccoon", "ray", "road", "rocket",
    "rose", "sea", "seal", "shark", "shrew",
    "skunk", "skyscraper", "snail", "snake", "spider",
    "squirrel", "streetcar", "sunflower", "sweet_pepper", "table",
    "tank", "telephone", "television", "tiger", "tractor",
    "train", "trout", "tulip", "turtle", "wardrobe",
    "whale", "willow_tree", "wolf", "woman", "worm",
]
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers and a shortcut.

    Args:
        in_planes: Number of channels in the block input.
        planes: Number of channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # The input can be added to the output unchanged only when neither the
        # spatial size nor the channel count changes; otherwise project it
        # with a strided 1x1 convolution.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        return F.relu(main_path + self.shortcut(x))
class ResNet18(nn.Module):
    """ResNet-18 built from ``BasicBlock`` with a stride-1 3x3 stem.

    The network is a 3-channel stem, four stages of two blocks each
    (64, 128, 256 and 512 channels), adaptive average pooling to 1x1 and a
    final linear classifier.

    Args:
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        # Running channel count; _make_layer reads and updates it.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (channels, stride of the first block) for layer1..layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for number, (width, first_stride) in enumerate(stage_specs, start=1):
            setattr(
                self,
                f"layer{number}",
                self._make_layer(BasicBlock, width, 2, stride=first_stride),
            )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * BasicBlock.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: ``num_blocks`` blocks, only the first one strided."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the raw class scores (logits) for a batch of images."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = self.avgpool(out)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)
# Initialize model
# Module-level state shared by load_model_with_fallbacks(), predict() and
# create_interface().
model = ResNet18(num_classes=100)  # freshly constructed; weights are loaded later
model_loaded = False  # set to True by load_model_with_fallbacks() once usable
model_status = "Not loaded"  # human-readable status string shown in the UI
def load_model_with_fallbacks(
    model_url="https://github.com/santhoshv6/era_v4_s8_assignment/releases/download/v1.0/model_best.pth",
    timeout=60,
):
    """Load weights into the global ``model``, falling back to demo mode.

    First tries to download a checkpoint from ``model_url`` and load its
    ``'state_dict'`` entry. If anything in that path fails, the model is left
    with whatever weights it currently has and is put in eval mode so the UI
    can still be exercised ("demo mode").

    Updates the module-level ``model_loaded`` and ``model_status``.

    Args:
        model_url: URL of the checkpoint file to download.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        bool: True when the model is usable (downloaded weights or demo
        mode), False when even the demo fallback failed.
    """
    global model, model_loaded, model_status

    # Method 1: Try GitHub releases with streaming
    try:
        print("π Attempting to load model from GitHub releases...")
        # Add headers to avoid being blocked
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }

        # Both context managers release their resources on every path: the
        # buffer is freed and the streamed connection is closed even when the
        # download or the checkpoint parsing raises.
        with BytesIO() as model_data:
            with requests.get(
                model_url, headers=headers, timeout=timeout, stream=True
            ) as response:
                response.raise_for_status()
                for chunk in response.iter_content(chunk_size=8192):
                    model_data.write(chunk)

            # After purely sequential writes the position equals the size.
            print(f"β Downloaded model: {model_data.tell()} bytes")
            model_data.seek(0)

            # SECURITY: torch.load unpickles the downloaded bytes, so
            # model_url must point at a trusted source.
            # NOTE(review): consider weights_only=True if the checkpoint only
            # holds tensors/primitives and the installed torch supports it.
            checkpoint = torch.load(model_data, map_location='cpu')

        if 'state_dict' not in checkpoint:
            raise ValueError("No 'state_dict' found in checkpoint")

        model.load_state_dict(checkpoint['state_dict'])
        model.eval()

        # Keep only the reported accuracy, then drop the checkpoint so its
        # tensors can be reclaimed.
        accuracy = checkpoint.get('test_acc', 77.45)
        del checkpoint
        gc.collect()  # Force garbage collection

        model_loaded = True
        model_status = f"β Loaded from GitHub (Accuracy: {accuracy:.2f}%)"
        print(f"β Model loaded successfully! Accuracy: {accuracy:.2f}%")
        return True
    except Exception as e:
        # Startup boundary: any failure here must fall through to demo mode
        # instead of crashing the app.
        print(f"β GitHub method failed: {e}")
        model_status = f"β Failed to load: {str(e)[:100]}..."

    # Method 2: Try to initialize with random weights (for demo purposes)
    try:
        print("π Initializing model with random weights for demo...")
        model.eval()  # Set to eval mode
        model_loaded = True
        # create_interface() looks for "random weights" in this string to
        # label the UI as demo mode.
        model_status = "β οΈ Demo mode (random weights) - Upload any image to test interface"
        return True
    except Exception as e:
        print(f"β Demo initialization failed: {e}")
        model_status = f"β Complete failure: {str(e)}"
        return False
# Define image preprocessing: resize to 32x32, convert to a tensor, then
# normalise each channel.
# NOTE(review): the constants are presumably the CIFAR-100 training-set
# per-channel mean/std - confirm against the training pipeline.
_NORMALIZE_MEAN = (0.5071, 0.4867, 0.4408)
_NORMALIZE_STD = (0.2675, 0.2565, 0.2761)

transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
)
def predict(image):
    """Classify an image and return the five most likely CIFAR-100 labels.

    Args:
        image: A PIL image or a NumPy array. ``None`` (nothing uploaded) is
            reported as an error instead of being processed.

    Returns:
        dict: On success, maps the five top class names to their softmax
        probabilities (floats). On failure, a single-entry dict whose key
        names the kind of error and whose value is a message string.
        NOTE(review): the error dicts carry string values, not floats -
        confirm the Label output component renders them as intended.
    """
    if not model_loaded:
        return {"β Model Error": f"Model not loaded: {model_status}"}

    # An empty submission would otherwise surface as an opaque
    # "'NoneType' object has no attribute 'mode'" message.
    if image is None:
        return {"β Prediction Error": "Failed: No image provided - please upload an image..."}

    try:
        # Convert to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess and add the batch dimension the model expects
        input_tensor = transform(image).unsqueeze(0)

        # Inference only: no gradients needed
        with torch.no_grad():
            probabilities = F.softmax(model(input_tensor), dim=1)
            top5_prob, top5_idx = torch.topk(probabilities, 5, dim=1)

        # Row 0 is the single image in the batch
        return {
            CIFAR100_CLASSES[class_idx]: confidence
            for class_idx, confidence in zip(top5_idx[0].tolist(), top5_prob[0].tolist())
        }
    except Exception as e:
        # UI boundary: report the failure in the output instead of raising
        return {"β Prediction Error": f"Failed: {str(e)[:100]}..."}
# Try to load model on startup
# Runs at import time, before create_interface() reads model_status and
# model_loaded further down the file.
print("π Starting model loading...")
load_success = load_model_with_fallbacks()  # True when the model is usable
# Create Gradio interface
def create_interface():
    """Build the Gradio ``Interface`` around ``predict``.

    The title and description depend on the module-level ``model_status`` and
    ``model_loaded`` at the time of the call: demo mode gets a "(Demo Mode)"
    suffix, and a model that is not loaded gets an error title and
    troubleshooting text.

    Returns:
        gr.Interface: The configured, not yet launched, interface.
    """
    base_title = "π CIFAR-100 ResNet-18 Classifier"
    if "random weights" in model_status:
        title = base_title + " (Demo Mode)"
    elif not model_loaded:
        title = "β CIFAR-100 Model - Loading Error"
    else:
        title = base_title

    if model_loaded:
        description = f"""
**Model Status:** {model_status}
**Upload an image to classify it into one of 100 CIFAR-100 categories!**
π― **Model Performance:** 77.45% test accuracy achieved
ποΈ **Architecture:** ResNet-18 with 11.22M parameters
π **Training:** 100 epochs on Tesla P100, reached target at epoch 58
**Best performing classes:** wardrobe (97%), motorcycle (93%), bicycle (93%), aquarium_fish (92%)
*This model excels at furniture, vehicles, and distinctive objects. For best results, upload clear images similar to CIFAR-100 style.*
"""
    else:
        description = f"""
**β Model Loading Error:** {model_status}
**Possible solutions:**
1. Check if GitHub release exists: https://github.com/santhoshv6/era_v4_s8_assignment/releases
2. Verify model file is accessible
3. Try refreshing the space
**Expected model URL:**
https://github.com/santhoshv6/era_v4_s8_assignment/releases/download/v1.0/model_best.pth
"""

    # Footer text; ends with the live status to help with debugging.
    article = f"""
### π About This Model
This ResNet-18 model was trained on CIFAR-100 dataset achieving **77.45% accuracy**, exceeding the 73% target by 4.45%.
**Key Features:**
- ποΈ **Optimized Architecture:** ResNet-18 with BasicBlocks
- π¨ **Advanced Augmentation:** Albumentations + Mixup + CutMix
- β‘ **Fast Training:** OneCycle learning rate scheduler
- π **Interpretable:** GradCAM visualizations available
**CIFAR-100 Categories:** 100 fine-grained classes across 20 superclasses including animals, vehicles, household items, and natural objects.
π **Full Documentation:** [GitHub Repository](https://github.com/santhoshv6/era_v4_s8_assignment)
---
**π Debugging Info:**
- Model Status: {model_status}
- Expected Model URL: https://github.com/santhoshv6/era_v4_s8_assignment/releases/download/v1.0/model_best.pth
"""

    image_input = gr.Image(type="pil", label="Upload an Image")
    label_output = gr.Label(num_top_classes=5, label="Top 5 Predictions")

    return gr.Interface(
        fn=predict,
        inputs=image_input,
        outputs=label_output,
        title=title,
        description=description,
        examples=[],
        article=article,
        theme=gr.themes.Soft(),
        allow_flagging="never",
    )
# Create and launch the interface
# Built at import time, so `demo` exists as a module-level name even when
# this file is imported rather than executed.
demo = create_interface()
if __name__ == "__main__":
    # Start the server only when the file is run as a script.
    demo.launch()