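# NOTE: The commented-out code below (up to the live imports) is the earlier,
# classification-only version of this app (no Grad-CAM), kept for reference.
# import os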
| # import gradio as gr | |
| # import torch | |
| # import torch.nn as nn | |
| # from torchvision import transforms | |
| # from PIL import Image | |
| # # ========================================== | |
| # # 1. YOUR CUSTOM MODEL ARCHITECTURE | |
| # # ========================================== | |
| # class BottleneckBlock(nn.Module): | |
| # expansion = 4 | |
| # def __init__(self, in_channels, mid_channels, stride=1): | |
| # super(BottleneckBlock, self).__init__() | |
| # out_channels = mid_channels * self.expansion | |
| # self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False) | |
| # self.bn1 = nn.BatchNorm2d(mid_channels) | |
| # self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False) | |
| # self.bn2 = nn.BatchNorm2d(mid_channels) | |
| # self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False) | |
| # self.bn3 = nn.BatchNorm2d(out_channels) | |
| # self.relu = nn.ReLU(inplace=True) | |
| # self.shortcut = nn.Sequential() | |
| # if stride != 1 or in_channels != out_channels: | |
| # self.shortcut = nn.Sequential( | |
| # nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), | |
| # nn.BatchNorm2d(out_channels) | |
| # ) | |
| # def forward(self, x): | |
| # identity = x | |
| # out = self.conv1(x) | |
| # out = self.bn1(out) | |
| # out = self.relu(out) | |
| # out = self.conv2(out) | |
| # out = self.bn2(out) | |
| # out = self.relu(out) | |
| # out = self.conv3(out) | |
| # out = self.bn3(out) | |
| # identity = self.shortcut(identity) | |
| # out += identity | |
| # out = self.relu(out) | |
| # return out | |
| # class ResNet50(nn.Module): | |
| # def __init__(self, num_classes=16, channels_img=3): | |
| # super(ResNet50, self).__init__() | |
| # self.in_channels = 64 | |
| # self.conv1 = nn.Conv2d(channels_img, 64, kernel_size=7, stride=2, padding=3, bias=False) | |
| # self.bn1 = nn.BatchNorm2d(64) | |
| # self.relu = nn.ReLU(inplace=True) | |
| # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |
| # self.layer1 = self._make_layer(mid_channels=64, num_blocks=3, stride=1) | |
| # self.layer2 = self._make_layer(mid_channels=128, num_blocks=4, stride=2) | |
| # self.layer3 = self._make_layer(mid_channels=256, num_blocks=6, stride=2) | |
| # self.layer4 = self._make_layer(mid_channels=512, num_blocks=3, stride=2) | |
| # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | |
| # self.fc = nn.Linear(512 * 4, num_classes) | |
| # def _make_layer(self, mid_channels, num_blocks, stride): | |
| # layers = [] | |
| # layers.append(BottleneckBlock(self.in_channels, mid_channels, stride)) | |
| # self.in_channels = mid_channels * 4 | |
| # for _ in range(num_blocks - 1): | |
| # layers.append(BottleneckBlock(self.in_channels, mid_channels, stride=1)) | |
| # return nn.Sequential(*layers) | |
| # def forward(self, x): | |
| # x = self.conv1(x) | |
| # x = self.bn1(x) | |
| # x = self.relu(x) | |
| # x = self.maxpool(x) | |
| # x = self.layer1(x) | |
| # x = self.layer2(x) | |
| # x = self.layer3(x) | |
| # x = self.layer4(x) | |
| # x = self.avgpool(x) | |
| # x = torch.flatten(x, 1) | |
| # x = self.fc(x) | |
| # return x | |
| # # ========================================== | |
| # # 2. CONFIG & LOADING | |
| # # ========================================== | |
| # MODEL_FILENAME = "resnet50_epoch_4.pth" | |
| # EXAMPLES_DIR = "examples" # Directory containing example images | |
| # class_names = [ | |
| # 'Advertisement', | |
| # 'Budget', | |
| # 'Email', | |
| # 'File Folder', | |
| # 'Form', | |
| # 'Handwritten', | |
| # 'Invoice', | |
| # 'Letter', | |
| # 'Memo', | |
| # 'News Article', | |
| # 'Presentation', | |
| # 'Questionnaire', | |
| # 'Resume', | |
| # 'Scientific Publication', | |
| # 'Scientific Report', | |
| # 'Specification' | |
| # ] | |
| # def load_model(): | |
| # print(f"Loading {MODEL_FILENAME}...") | |
| # model = ResNet50(num_classes=16) | |
| # try: | |
| # checkpoint = torch.load(MODEL_FILENAME, map_location=torch.device('cpu')) | |
| # if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: | |
| # model.load_state_dict(checkpoint['state_dict']) | |
| # else: | |
| # model.load_state_dict(checkpoint) | |
| # print("Model loaded successfully.") | |
| # except FileNotFoundError: | |
| # print(f"Error: Model file '{MODEL_FILENAME}' not found. Please ensure it is in the same directory.") | |
| # # We don't exit here so the UI can still launch (though prediction will fail) | |
| # model.eval() | |
| # return model | |
| # model = load_model() | |
| # # ========================================== | |
| # # 3. PREPROCESSING & INTERFACE | |
| # # ========================================== | |
| # transform = transforms.Compose([ | |
| # transforms.Resize((224, 224)), | |
| # transforms.ToTensor(), | |
| # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
| # ]) | |
| # def predict(image): | |
| # if image is None: return None | |
| # image_tensor = transform(image).unsqueeze(0) | |
| # with torch.no_grad(): | |
| # outputs = model(image_tensor) | |
| # probabilities = torch.nn.functional.softmax(outputs[0], dim=0) | |
| # return {class_names[i]: float(probabilities[i]) for i in range(len(class_names))} | |
| # # --- Dynamic Example Loading Logic --- | |
| # example_list = [] | |
| # if os.path.exists(EXAMPLES_DIR): | |
| # # Sort files to keep order consistent | |
| # for file in sorted(os.listdir(EXAMPLES_DIR)): | |
| # if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')): | |
| # example_list.append([os.path.join(EXAMPLES_DIR, file)]) | |
| # else: | |
| # print(f"Warning: '{EXAMPLES_DIR}' directory not found. No examples will be shown.") | |
| # # Gradio UI | |
| # interface = gr.Interface( | |
| # fn=predict, | |
| # inputs=gr.Image(type="pil"), | |
| # outputs=gr.Label(num_top_classes=3), | |
| # title="Document Classifier (ResNet50)", | |
| # description="Custom ResNet50 trained on RVL-CDIP to classify 16 document types. Click on an example below to test.", | |
| # examples=example_list if example_list else None # Handle case where list is empty | |
| # ) | |
| # if __name__ == "__main__": | |
| # interface.launch() | |
| import os | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| # ========================================== | |
| # 1. MODEL ARCHITECTURE | |
| # ========================================== | |
class BottleneckBlock(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The block emits ``mid_channels * expansion`` channels. When the stride or the
    channel count changes, the skip path is a 1x1 conv + BatchNorm projection;
    otherwise it is an empty ``nn.Sequential`` (identity).

    Submodule names become state_dict keys when checkpoints are loaded, so they
    must not be renamed.
    """

    expansion = 4

    def __init__(self, in_channels, mid_channels, stride=1):
        super().__init__()
        widened = mid_channels * self.expansion
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        # Any spatial downsampling happens in the 3x3 convolution.
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        # Project the skip path only when its shape would not match the main path.
        needs_projection = stride != 1 or in_channels != widened
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, widened, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(widened),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        skip = self.shortcut(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        # Residual addition, then the final non-linearity.
        h += skip
        return self.relu(h)
class ResNet50(nn.Module):
    """ResNet-50 classifier assembled from ``BottleneckBlock`` stages.

    Layout: 7x7/2 stem conv -> BN -> ReLU -> 3x3/2 max-pool, then four
    bottleneck stages with 3, 4, 6 and 3 blocks (``layer1``..``layer4``),
    global average pooling and a linear head.

    Args:
        num_classes: Number of output logits.
        channels_img: Number of channels in the input image.

    Submodule names define the state_dict keys used when loading checkpoints,
    so they must stay stable.
    """

    def __init__(self, num_classes=16, channels_img=3):
        super().__init__()
        # Running channel count; _make_layer reads it and advances it.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(channels_img, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # (mid_channels, num_blocks, stride) for layer1..layer4, built in this order.
        stage_plan = ((64, 3, 1), (128, 4, 2), (256, 6, 2), (512, 3, 2))
        for number, (width, depth, step) in enumerate(stage_plan, start=1):
            setattr(self, f"layer{number}", self._make_layer(mid_channels=width, num_blocks=depth, stride=step))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # After the last stage, in_channels == 512 * 4 == 2048.
        self.fc = nn.Linear(self.in_channels, num_classes)

    def _make_layer(self, mid_channels, num_blocks, stride):
        """Build one stage; only its first block may downsample or project.

        Side effect: advances ``self.in_channels`` to this stage's output width.
        """
        head = BottleneckBlock(self.in_channels, mid_channels, stride)
        self.in_channels = mid_channels * BottleneckBlock.expansion
        tail = [BottleneckBlock(self.in_channels, mid_channels, stride=1) for _ in range(num_blocks - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.fc(torch.flatten(self.avgpool(x), 1))
| # ========================================== | |
# 2. GRAD-CAM EXPLAINABILITY
| # ========================================== | |
class GradCAM:
    """Gradient-weighted Class Activation Mapping (Grad-CAM) on one layer.

    Hooks ``target_layer`` to capture its forward output (activations) and the
    gradient flowing back into that output. These are combined into a saliency
    map for the requested class.

    Args:
        model: The classifier. It is called as ``model(x)`` and its output is
            indexed as ``output[0, class_idx]``, i.e. one row of logits per sample.
        target_layer: Submodule of ``model`` whose output feature map is used.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Keep the handles so the hooks can be detached later via remove().
        self._handles = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def save_activation(self, module, input, output):
        """Forward hook: store the layer output, detached from the autograd graph."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient of the score w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Detach both hooks from ``target_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __call__(self, x, class_idx=None):
        """Run a forward and backward pass on ``x`` and build the Grad-CAM map.

        Only the first sample of the batch is explained.

        Args:
            x: Input tensor of shape ``(N, C, H, W)``.
            class_idx: Class to explain. Defaults to the top-scoring class of sample 0.

        Returns:
            ``(cam, class_idx, logits)``:
            - ``cam`` is a float32 NumPy array of shape ``(H, W)``, min-max scaled
              to ``[0, 1]`` (all zeros if the map is flat).
            - ``class_idx`` is an ``int``.
            - ``logits`` is the model output, detached from the graph.
        """
        # Gradients are needed even when the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(x)
            class_idx = int(output[0].argmax()) if class_idx is None else int(class_idx)
            self.model.zero_grad()
            output[0, class_idx].backward()

        grads = self.gradients[0]
        acts = self.activations[0]
        # Channel weights are the spatially averaged gradients; the map is the
        # ReLU of the weighted sum of activation channels.
        weights = grads.mean(dim=(1, 2))
        cam = torch.relu((weights[:, None, None] * acts).sum(dim=0))
        # Upsample to the input's spatial size (bilinear with half-pixel
        # centres, matching cv2.resize's default interpolation).
        cam = torch.nn.functional.interpolate(
            cam[None, None], size=tuple(x.shape[-2:]), mode="bilinear", align_corners=False
        )[0, 0]
        cam = cam - cam.min()
        peak = cam.max()
        if peak != 0:
            cam = cam / peak
        return cam.cpu().numpy().astype(np.float32), class_idx, output.detach()
# ==========================================
# 3. CONFIG & SETUP
# ==========================================
# The model checkpoint is read from the current working directory (project root).
MODEL_FILENAME = "resnet50_epoch_4.pth"
EXAMPLES_DIR = "examples"

# Output labels, indexed by the model's logit position.
# NOTE(review): this list is alphabetical, which matches an ImageFolder-style
# label mapping. The official RVL-CDIP label files use a different index order
# (0=letter, 1=form, 2=email, ...). Confirm against the training pipeline.
class_names = [
    'advertisement', 'budget', 'email', 'file folder', 'form', 'handwritten',
    'invoice', 'letter', 'memo', 'news article', 'presentation', 'questionnaire',
    'resume', 'scientific publication', 'scientific report', 'specification'
]

# Load Model
print(f"Loading {MODEL_FILENAME}...")
# Derive the head size from the label list so the two cannot drift apart.
model = ResNet50(num_classes=len(class_names))
if not os.path.exists(MODEL_FILENAME):
    raise RuntimeError(f"CRITICAL ERROR: Model file '{MODEL_FILENAME}' not found in current directory: {os.getcwd()}")
try:
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. This is acceptable only because the checkpoint is a trusted,
    # self-produced file. Never point it at a downloaded or untrusted checkpoint.
    checkpoint = torch.load(MODEL_FILENAME, map_location='cpu', weights_only=False)
    # Accept a bare state_dict, or a training checkpoint that wraps it under
    # 'state_dict' or 'model_state_dict'.
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ('state_dict', 'model_state_dict'):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break
    model.load_state_dict(state_dict)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Failed to load model weights: {e}")
    raise  # Stop the app if loading fails; a bare raise keeps the original traceback.
model.eval()

# Grad-CAM targets the last convolution of the final bottleneck block.
target_layer = model.layer4[-1].conv3
grad_cam = GradCAM(model, target_layer)
| # ========================================== | |
| # 4. PREDICTION FUNCTION | |
| # ========================================== | |
# ImageNet per-channel statistics used by the Normalize step below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to each uploaded image before inference:
# 1. Squash to 224x224 (aspect ratio is not preserved).
# 2. Convert to a float tensor.
# 3. Standardize each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def predict(image):
    """Classify a document image and render its Grad-CAM overlay.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing was uploaded.

    Returns:
        ``(confidences, overlay)``:
        - ``confidences`` maps each class name to its softmax probability.
        - ``overlay`` is an RGB ``uint8`` array blending the resized input (60%)
          with a JET-coloured Grad-CAM heatmap (40%).
        Returns ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    # The network expects 3-channel input (scans are often grayscale or have alpha).
    if image.mode != "RGB":
        image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0)

    # Grad-CAM also runs the forward pass, so the logits come back alongside the map.
    cam, _, logits = grad_cam(input_tensor)

    probabilities = torch.nn.functional.softmax(logits[0].detach(), dim=0).tolist()
    confidences = {name: probabilities[i] for i, name in enumerate(class_names)}

    # Colourise the [0, 1] map. OpenCV colour maps are BGR, so convert once to RGB
    # and blend in RGB; the blend is per-channel, so this matches blending in BGR.
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Resize the input to the heatmap's size rather than a hard-coded 224
    # (PIL's resize takes (width, height)).
    height, width = cam.shape[:2]
    base = np.array(image.resize((width, height)))
    overlay = cv2.addWeighted(base, 0.6, heatmap, 0.4, 0)
    return confidences, overlay
| # ========================================== | |
| # 5. UI | |
| # ========================================== | |
# Example images for the Gradio gallery. Each example is a one-element list
# because the interface has a single input component.
# '.tif' is included because RVL-CDIP scans are distributed as .tif files.
_EXAMPLE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
example_list = []
if os.path.isdir(EXAMPLES_DIR):
    # Sorted so the gallery order is stable across platforms.
    for name in sorted(os.listdir(EXAMPLES_DIR)):
        path = os.path.join(EXAMPLES_DIR, name)
        if name.lower().endswith(_EXAMPLE_EXTENSIONS) and os.path.isfile(path):
            example_list.append([path])
else:
    print(f"Warning: '{EXAMPLES_DIR}' directory not found. No examples will be shown.")
title = "📄 Intelligent Document Classifier + Explainability"
description = """
**Analyze and categorize scanned documents with AI.** This demo includes **Grad-CAM Explainability**, which highlights the specific regions the model looked at to make its decision.
The model classifies 16 document types (ResNet-50 trained on RVL-CDIP).
"""

# Output components, in the same order as predict()'s return tuple.
output_components = [
    gr.Label(num_top_classes=3, label="Predictions"),
    gr.Image(label="Explainability Heatmap (What the model 'sees')"),
]

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Document"),
    outputs=output_components,
    title=title,
    description=description,
    # Pass None rather than an empty list when no example images were found.
    examples=example_list or None,
)
def main():
    """Start the Gradio server for this demo."""
    # NOTE(review): ssr_mode=False disables Gradio's server-side rendering. The
    # reason is not recorded here; confirm before re-enabling it.
    interface.launch(ssr_mode=False)


if __name__ == "__main__":
    main()