# import os # import gradio as gr # import torch # import torch.nn as nn # from torchvision import transforms # from PIL import Image # # ========================================== # # 1. YOUR CUSTOM MODEL ARCHITECTURE # # ========================================== # class BottleneckBlock(nn.Module): # expansion = 4 # def __init__(self, in_channels, mid_channels, stride=1): # super(BottleneckBlock, self).__init__() # out_channels = mid_channels * self.expansion # self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False) # self.bn1 = nn.BatchNorm2d(mid_channels) # self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False) # self.bn2 = nn.BatchNorm2d(mid_channels) # self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False) # self.bn3 = nn.BatchNorm2d(out_channels) # self.relu = nn.ReLU(inplace=True) # self.shortcut = nn.Sequential() # if stride != 1 or in_channels != out_channels: # self.shortcut = nn.Sequential( # nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), # nn.BatchNorm2d(out_channels) # ) # def forward(self, x): # identity = x # out = self.conv1(x) # out = self.bn1(out) # out = self.relu(out) # out = self.conv2(out) # out = self.bn2(out) # out = self.relu(out) # out = self.conv3(out) # out = self.bn3(out) # identity = self.shortcut(identity) # out += identity # out = self.relu(out) # return out # class ResNet50(nn.Module): # def __init__(self, num_classes=16, channels_img=3): # super(ResNet50, self).__init__() # self.in_channels = 64 # self.conv1 = nn.Conv2d(channels_img, 64, kernel_size=7, stride=2, padding=3, bias=False) # self.bn1 = nn.BatchNorm2d(64) # self.relu = nn.ReLU(inplace=True) # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # self.layer1 = self._make_layer(mid_channels=64, num_blocks=3, stride=1) # self.layer2 = self._make_layer(mid_channels=128, num_blocks=4, stride=2) # self.layer3 = 
self._make_layer(mid_channels=256, num_blocks=6, stride=2) # self.layer4 = self._make_layer(mid_channels=512, num_blocks=3, stride=2) # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # self.fc = nn.Linear(512 * 4, num_classes) # def _make_layer(self, mid_channels, num_blocks, stride): # layers = [] # layers.append(BottleneckBlock(self.in_channels, mid_channels, stride)) # self.in_channels = mid_channels * 4 # for _ in range(num_blocks - 1): # layers.append(BottleneckBlock(self.in_channels, mid_channels, stride=1)) # return nn.Sequential(*layers) # def forward(self, x): # x = self.conv1(x) # x = self.bn1(x) # x = self.relu(x) # x = self.maxpool(x) # x = self.layer1(x) # x = self.layer2(x) # x = self.layer3(x) # x = self.layer4(x) # x = self.avgpool(x) # x = torch.flatten(x, 1) # x = self.fc(x) # return x # # ========================================== # # 2. CONFIG & LOADING # # ========================================== # MODEL_FILENAME = "resnet50_epoch_4.pth" # EXAMPLES_DIR = "examples" # Directory containing example images # class_names = [ # 'Advertisement', # 'Budget', # 'Email', # 'File Folder', # 'Form', # 'Handwritten', # 'Invoice', # 'Letter', # 'Memo', # 'News Article', # 'Presentation', # 'Questionnaire', # 'Resume', # 'Scientific Publication', # 'Scientific Report', # 'Specification' # ] # def load_model(): # print(f"Loading {MODEL_FILENAME}...") # model = ResNet50(num_classes=16) # try: # checkpoint = torch.load(MODEL_FILENAME, map_location=torch.device('cpu')) # if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: # model.load_state_dict(checkpoint['state_dict']) # else: # model.load_state_dict(checkpoint) # print("Model loaded successfully.") # except FileNotFoundError: # print(f"Error: Model file '{MODEL_FILENAME}' not found. 
Please ensure it is in the same directory.") # # We don't exit here so the UI can still launch (though prediction will fail) # model.eval() # return model # model = load_model() # # ========================================== # # 3. PREPROCESSING & INTERFACE # # ========================================== # transform = transforms.Compose([ # transforms.Resize((224, 224)), # transforms.ToTensor(), # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ]) # def predict(image): # if image is None: return None # image_tensor = transform(image).unsqueeze(0) # with torch.no_grad(): # outputs = model(image_tensor) # probabilities = torch.nn.functional.softmax(outputs[0], dim=0) # return {class_names[i]: float(probabilities[i]) for i in range(len(class_names))} # # --- Dynamic Example Loading Logic --- # example_list = [] # if os.path.exists(EXAMPLES_DIR): # # Sort files to keep order consistent # for file in sorted(os.listdir(EXAMPLES_DIR)): # if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')): # example_list.append([os.path.join(EXAMPLES_DIR, file)]) # else: # print(f"Warning: '{EXAMPLES_DIR}' directory not found. No examples will be shown.") # # Gradio UI # interface = gr.Interface( # fn=predict, # inputs=gr.Image(type="pil"), # outputs=gr.Label(num_top_classes=3), # title="Document Classifier (ResNet50)", # description="Custom ResNet50 trained on RVL-CDIP to classify 16 document types. Click on an example below to test.", # examples=example_list if example_list else None # Handle case where list is empty # ) # if __name__ == "__main__": # interface.launch() import os import gradio as gr import torch import torch.nn as nn from torchvision import transforms from PIL import Image import numpy as np import cv2 # ========================================== # 1. 
# ==========================================
# 1. MODEL ARCHITECTURE
# ==========================================
class BottleneckBlock(nn.Module):
    """ResNet "bottleneck" residual block: 1x1 reduce -> 3x3 -> 1x1 expand, with a skip connection.

    Output channel count is ``mid_channels * expansion``. When the spatial
    resolution or channel count changes, the skip path uses a 1x1 projection
    conv + BatchNorm instead of the identity.
    """

    # Channel expansion factor of the final 1x1 conv (standard ResNet-50 value).
    expansion = 4

    def __init__(self, in_channels, mid_channels, stride=1):
        """Build the three conv/BN stages and the (possibly projected) shortcut.

        Args:
            in_channels: channels of the incoming feature map.
            mid_channels: channels used inside the bottleneck (output is 4x this).
            stride: stride applied by the middle 3x3 conv (and the projection shortcut).
        """
        super(BottleneckBlock, self).__init__()
        out_channels = mid_channels * self.expansion
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Identity shortcut by default; replaced by a 1x1 projection when the
        # residual branch changes resolution (stride != 1) or channel count.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Apply conv1->bn1->relu, conv2->bn2->relu, conv3->bn3, add shortcut, final ReLU."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        identity = self.shortcut(identity)
        out += identity
        out = self.relu(out)
        return out


class ResNet50(nn.Module):
    """ResNet-50 classifier built from BottleneckBlock (3-4-6-3 block layout)."""

    def __init__(self, num_classes=16, channels_img=3):
        """Construct the stem, the four residual stages, and the classifier head.

        Args:
            num_classes: size of the final fully-connected output layer.
            channels_img: number of input image channels (3 for RGB).
        """
        super(ResNet50, self).__init__()
        # Tracks the channel count fed into the next stage; mutated by _make_layer.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(channels_img, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(mid_channels=64, num_blocks=3, stride=1)
        self.layer2 = self._make_layer(mid_channels=128, num_blocks=4, stride=2)
        self.layer3 = self._make_layer(mid_channels=256, num_blocks=6, stride=2)
        self.layer4 = self._make_layer(mid_channels=512, num_blocks=3, stride=2)
        # Adaptive pooling makes the head independent of input spatial size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * 4, num_classes)  # 512 * expansion(4) = 2048 features

    def _make_layer(self, mid_channels, num_blocks, stride):
        """Stack num_blocks bottleneck blocks; only the first block downsamples.

        Side effect: advances self.in_channels to mid_channels * 4 for the next stage.
        """
        layers = []
        layers.append(BottleneckBlock(self.in_channels, mid_channels, stride))
        self.in_channels = mid_channels * 4
        for _ in range(num_blocks - 1):
            layers.append(BottleneckBlock(self.in_channels, mid_channels, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes) for image batch x."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


# ==========================================
# 2. GRAD-CAM CLASS
# ==========================================
class GradCAM:
    """Grad-CAM visualizer: produces a class-activation heatmap for one input image.

    Hooks the given target_layer to capture its forward activations and the
    gradient flowing back through it, then weights the activation channels by
    the spatially-averaged gradients.
    """

    def __init__(self, model, target_layer):
        """Attach forward/backward hooks to target_layer on the given model.

        NOTE(review): the hook handles returned by register_* are discarded, so
        the hooks live for the model's lifetime — fine for a single long-lived
        instance, but a removable-handle design would be needed for reuse.
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # set by the backward hook on each backward pass
        self.activations = None  # set by the forward hook on each forward pass
        # Register hooks
        self.target_layer.register_forward_hook(self.save_activation)
        self.target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        """Forward hook: stash the target layer's output feature map."""
        self.activations = output

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: stash the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0]

    def __call__(self, x, class_idx=None):
        """Run one forward+backward pass and build the CAM.

        Args:
            x: input tensor of shape (1, C, H, W) — only batch index 0 is used.
            class_idx: class to explain; defaults to the model's argmax prediction.

        Returns:
            (cam, class_idx, output): a 224x224 float heatmap normalized to [0, 1],
            the explained class index as an int, and the raw model logits.
        """
        # 1. Forward pass (hooks record the target layer's activations).
        output = self.model(x)
        if class_idx is None:
            class_idx = torch.argmax(output, dim=1)
        # 2. Backward pass from the chosen class score (hooks record gradients).
        self.model.zero_grad()
        score = output[0, class_idx]
        score.backward()
        # 3. Weight each activation channel by its mean gradient, ReLU, resize, normalize.
        gradients = self.gradients.data.cpu().numpy()[0]
        activations = self.activations.data.cpu().numpy()[0]
        weights = np.mean(gradients, axis=(1, 2))
        cam = np.zeros(activations.shape[1:], dtype=np.float32)
        for i, w in enumerate(weights):
            cam += w * activations[i]
        cam = np.maximum(cam, 0)
        cam = cv2.resize(cam, (224, 224))
        cam = cam - np.min(cam)
        if np.max(cam) != 0:  # guard against division by zero on an all-flat map
            cam = cam / np.max(cam)
        return cam, int(class_idx), output
# ==========================================
# 3. CONFIG & SETUP
# ==========================================
# Checkpoint is expected in the repo root (no 'models/' prefix).
MODEL_FILENAME = "resnet50_epoch_4.pth"
EXAMPLES_DIR = "examples"  # directory scanned for clickable example images

# RVL-CDIP label names; list order must match the training label indices,
# since predictions are mapped back by position.
class_names = [
    'advertisement', 'budget', 'email', 'file folder',
    'form', 'handwritten', 'invoice', 'letter',
    'memo', 'news article', 'presentation', 'questionnaire',
    'resume', 'scientific publication', 'scientific report', 'specification'
]

# Load model weights at import time so the app fails fast if they are missing.
print(f"Loading {MODEL_FILENAME}...")
model = ResNet50(num_classes=16)

if not os.path.exists(MODEL_FILENAME):
    raise RuntimeError(f"CRITICAL ERROR: Model file '{MODEL_FILENAME}' not found in current directory: {os.getcwd()}")

try:
    # SECURITY: weights_only=False unpickles arbitrary objects — acceptable only
    # because this checkpoint file is produced by us, never by untrusted input.
    checkpoint = torch.load(MODEL_FILENAME, map_location='cpu', weights_only=False)
    # Support both raw state_dicts and checkpoints wrapped as {'state_dict': ...}.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Failed to load model weights: {e}")
    # Bare raise re-raises the active exception with its original traceback
    # (``raise e`` would restart the traceback at this line). App must not
    # start with an uninitialized model.
    raise

model.eval()  # inference mode: freezes BatchNorm statistics / disables dropout

# Initialize GradCAM on the last conv layer of the final residual block,
# the deepest spatial feature map available for explanation.
target_layer = model.layer4[2].conv3
grad_cam = GradCAM(model, target_layer)
# ==========================================
# 4. PREDICTION FUNCTION
# ==========================================
# Standard ImageNet preprocessing, matching what the model saw in training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def predict(image):
    """Classify a PIL image and build a Grad-CAM overlay.

    Args:
        image: PIL.Image from Gradio, or None when no image was provided.

    Returns:
        (confidences, final_image): a {class_name: probability} dict for the
        gr.Label output and an RGB numpy heatmap overlay for the gr.Image
        output — or (None, None) when image is None.
    """
    if image is None:
        return None, None

    # Grad-CAM and the model expect 3-channel input; scans often arrive as L/LA.
    if image.mode != "RGB":
        image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0)

    # GradCAM runs the forward pass itself, so we reuse its logits here.
    cam, class_idx, logits = grad_cam(input_tensor)

    # Detach once: logits carry grad history from the Grad-CAM backward pass.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0).detach()
    confidences = dict(zip(class_names, probabilities.tolist()))

    # Colorize the [0, 1] CAM and blend it over the (resized) original document.
    heatmap = np.uint8(255 * cam)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    original_cv = cv2.cvtColor(np.array(image.resize((224, 224))), cv2.COLOR_RGB2BGR)
    superimposed = cv2.addWeighted(original_cv, 0.6, heatmap, 0.4, 0)
    # OpenCV works in BGR; Gradio expects RGB.
    final_image = cv2.cvtColor(superimposed, cv2.COLOR_BGR2RGB)

    return confidences, final_image


# ==========================================
# 5. UI
# ==========================================
# Collect example images (sorted for a stable display order). '.tif' is
# accepted alongside '.tiff' since both spellings are common for scans.
example_list = []
if os.path.exists(EXAMPLES_DIR):
    for file in sorted(os.listdir(EXAMPLES_DIR)):
        if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
            example_list.append([os.path.join(EXAMPLES_DIR, file)])

title = "📄 Intelligent Document Classifier + Explainability"
description = """
**Analyze and categorize scanned documents with AI.**
This demo includes **Grad-CAM Explainability**, which highlights the specific regions the model looked at to make its decision.
The model classifies 16 document types (ResNet-50 trained on RVL-CDIP).
"""

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Document"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Image(label="Explainability Heatmap (What the model 'sees')")
    ],
    title=title,
    description=description,
    examples=example_list if example_list else None  # gr.Interface rejects an empty list
)

if __name__ == "__main__":
    # ssr_mode=False avoids server-side-rendering issues on some Gradio hosts.
    interface.launch(ssr_mode=False)