# app.py — Hugging Face Space by arpit-gour02 (commit 6161904, verified).
# import os
# import gradio as gr
# import torch
# import torch.nn as nn
# from torchvision import transforms
# from PIL import Image
# # ==========================================
# # 1. YOUR CUSTOM MODEL ARCHITECTURE
# # ==========================================
# class BottleneckBlock(nn.Module):
# expansion = 4
# def __init__(self, in_channels, mid_channels, stride=1):
# super(BottleneckBlock, self).__init__()
# out_channels = mid_channels * self.expansion
# self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
# self.bn1 = nn.BatchNorm2d(mid_channels)
# self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False)
# self.bn2 = nn.BatchNorm2d(mid_channels)
# self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
# self.bn3 = nn.BatchNorm2d(out_channels)
# self.relu = nn.ReLU(inplace=True)
# self.shortcut = nn.Sequential()
# if stride != 1 or in_channels != out_channels:
# self.shortcut = nn.Sequential(
# nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
# nn.BatchNorm2d(out_channels)
# )
# def forward(self, x):
# identity = x
# out = self.conv1(x)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# out = self.relu(out)
# out = self.conv3(out)
# out = self.bn3(out)
# identity = self.shortcut(identity)
# out += identity
# out = self.relu(out)
# return out
# class ResNet50(nn.Module):
# def __init__(self, num_classes=16, channels_img=3):
# super(ResNet50, self).__init__()
# self.in_channels = 64
# self.conv1 = nn.Conv2d(channels_img, 64, kernel_size=7, stride=2, padding=3, bias=False)
# self.bn1 = nn.BatchNorm2d(64)
# self.relu = nn.ReLU(inplace=True)
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# self.layer1 = self._make_layer(mid_channels=64, num_blocks=3, stride=1)
# self.layer2 = self._make_layer(mid_channels=128, num_blocks=4, stride=2)
# self.layer3 = self._make_layer(mid_channels=256, num_blocks=6, stride=2)
# self.layer4 = self._make_layer(mid_channels=512, num_blocks=3, stride=2)
# self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
# self.fc = nn.Linear(512 * 4, num_classes)
# def _make_layer(self, mid_channels, num_blocks, stride):
# layers = []
# layers.append(BottleneckBlock(self.in_channels, mid_channels, stride))
# self.in_channels = mid_channels * 4
# for _ in range(num_blocks - 1):
# layers.append(BottleneckBlock(self.in_channels, mid_channels, stride=1))
# return nn.Sequential(*layers)
# def forward(self, x):
# x = self.conv1(x)
# x = self.bn1(x)
# x = self.relu(x)
# x = self.maxpool(x)
# x = self.layer1(x)
# x = self.layer2(x)
# x = self.layer3(x)
# x = self.layer4(x)
# x = self.avgpool(x)
# x = torch.flatten(x, 1)
# x = self.fc(x)
# return x
# # ==========================================
# # 2. CONFIG & LOADING
# # ==========================================
# MODEL_FILENAME = "resnet50_epoch_4.pth"
# EXAMPLES_DIR = "examples" # Directory containing example images
# class_names = [
# 'Advertisement',
# 'Budget',
# 'Email',
# 'File Folder',
# 'Form',
# 'Handwritten',
# 'Invoice',
# 'Letter',
# 'Memo',
# 'News Article',
# 'Presentation',
# 'Questionnaire',
# 'Resume',
# 'Scientific Publication',
# 'Scientific Report',
# 'Specification'
# ]
# def load_model():
# print(f"Loading {MODEL_FILENAME}...")
# model = ResNet50(num_classes=16)
# try:
# checkpoint = torch.load(MODEL_FILENAME, map_location=torch.device('cpu'))
# if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
# model.load_state_dict(checkpoint['state_dict'])
# else:
# model.load_state_dict(checkpoint)
# print("Model loaded successfully.")
# except FileNotFoundError:
# print(f"Error: Model file '{MODEL_FILENAME}' not found. Please ensure it is in the same directory.")
# # We don't exit here so the UI can still launch (though prediction will fail)
# model.eval()
# return model
# model = load_model()
# # ==========================================
# # 3. PREPROCESSING & INTERFACE
# # ==========================================
# transform = transforms.Compose([
# transforms.Resize((224, 224)),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])
# def predict(image):
# if image is None: return None
# image_tensor = transform(image).unsqueeze(0)
# with torch.no_grad():
# outputs = model(image_tensor)
# probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
# return {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
# # --- Dynamic Example Loading Logic ---
# example_list = []
# if os.path.exists(EXAMPLES_DIR):
# # Sort files to keep order consistent
# for file in sorted(os.listdir(EXAMPLES_DIR)):
# if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
# example_list.append([os.path.join(EXAMPLES_DIR, file)])
# else:
# print(f"Warning: '{EXAMPLES_DIR}' directory not found. No examples will be shown.")
# # Gradio UI
# interface = gr.Interface(
# fn=predict,
# inputs=gr.Image(type="pil"),
# outputs=gr.Label(num_top_classes=3),
# title="Document Classifier (ResNet50)",
# description="Custom ResNet50 trained on RVL-CDIP to classify 16 document types. Click on an example below to test.",
# examples=example_list if example_list else None # Handle case where list is empty
# )
# if __name__ == "__main__":
# interface.launch()
import os
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
# ==========================================
# 1. MODEL ARCHITECTURE
# ==========================================
class BottleneckBlock(nn.Module):
    """ResNet bottleneck unit: 1x1 reduce -> 3x3 -> 1x1 expand, plus a skip path.

    The expanded output carries ``mid_channels * expansion`` channels. The
    skip path is a 1x1 projection whenever the stride or channel count
    changes, and an identity otherwise.
    """

    expansion = 4

    def __init__(self, in_channels, mid_channels, stride=1):
        super(BottleneckBlock, self).__init__()
        expanded = mid_channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut only when the shapes differ; identity otherwise.
        if stride != 1 or in_channels != expanded:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        skip = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return self.relu(y + skip)
class ResNet50(nn.Module):
    """ResNet-50 assembled from BottleneckBlock stages with a ``num_classes`` head."""

    def __init__(self, num_classes=16, channels_img=3):
        super(ResNet50, self).__init__()
        self.in_channels = 64

        # Stem: 7x7 stride-2 conv followed by a 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(channels_img, 64, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Standard ResNet-50 stage layout: 3, 4, 6, 3 bottleneck blocks.
        self.layer1 = self._make_layer(mid_channels=64, num_blocks=3, stride=1)
        self.layer2 = self._make_layer(mid_channels=128, num_blocks=4, stride=2)
        self.layer3 = self._make_layer(mid_channels=256, num_blocks=6, stride=2)
        self.layer4 = self._make_layer(mid_channels=512, num_blocks=3, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * 4, num_classes)

    def _make_layer(self, mid_channels, num_blocks, stride):
        """Build one stage: the first block may downsample, the rest keep size."""
        strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for s in strides:
            blocks.append(BottleneckBlock(self.in_channels, mid_channels, s))
            self.in_channels = mid_channels * BottleneckBlock.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        return self.fc(torch.flatten(x, 1))
# ==========================================
# 2. GRAD-CAM CLASS (New Addition)
# ==========================================
class GradCAM:
    """Grad-CAM explainability: class-activation heatmaps from a conv layer.

    Forward/backward hooks on ``target_layer`` capture its activations and
    gradients; calling the instance runs a forward + backward pass and
    returns a [0, 1]-normalized heatmap upsampled to the input's spatial
    size (previously hardcoded to 224x224).
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Keep the handles so the hooks can be detached if ever needed.
        self.hook_handles = [
            self.target_layer.register_forward_hook(self.save_activation),
            self.target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def save_activation(self, module, input, output):
        self.activations = output

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0]

    def remove_hooks(self):
        """Detach both hooks; the instance can no longer produce maps after this."""
        for handle in self.hook_handles:
            handle.remove()

    def __call__(self, x, class_idx=None):
        """Return ``(cam, class_idx, logits)`` for a single-image batch ``x``.

        ``cam`` is an HxW float32 array in [0, 1] matching x's spatial size;
        ``class_idx`` defaults to the model's argmax class.
        """
        # 1. Forward pass (the forward hook records the target activations).
        output = self.model(x)
        if class_idx is None:
            class_idx = torch.argmax(output, dim=1)
        # Normalize tensor/int inputs to a plain int so indexing gives a scalar.
        class_idx = int(class_idx.item() if torch.is_tensor(class_idx) else class_idx)

        # 2. Backward pass on the selected class score.
        self.model.zero_grad()
        output[0, class_idx].backward()

        # 3. Weight each activation map by its spatially averaged gradient.
        gradients = self.gradients.detach().cpu().numpy()[0]      # (C, H, W)
        activations = self.activations.detach().cpu().numpy()[0]  # (C, H, W)
        weights = gradients.mean(axis=(1, 2))                     # (C,)
        cam = np.tensordot(weights, activations, axes=1)          # (H, W)

        # ReLU, upsample to the input size, then min-max normalize to [0, 1].
        cam = np.maximum(cam, 0).astype(np.float32)
        cam = cv2.resize(cam, (x.shape[3], x.shape[2]))
        cam -= cam.min()
        if cam.max() != 0:
            cam /= cam.max()
        return cam, class_idx, output
# ==========================================
# 3. CONFIG & SETUP
# ==========================================
# NOTE: the weights file lives in the repo root (no 'models/' prefix).
MODEL_FILENAME = "resnet50_epoch_4.pth"
EXAMPLES_DIR = "examples"

# The 16 RVL-CDIP document categories; order must match the training label indices.
class_names = [
    'advertisement', 'budget', 'email', 'file folder', 'form', 'handwritten',
    'invoice', 'letter', 'memo', 'news article', 'presentation', 'questionnaire',
    'resume', 'scientific publication', 'scientific report', 'specification'
]

# ---- Load model weights (fail fast so the Space doesn't launch half-broken) ----
print(f"Loading {MODEL_FILENAME}...")
model = ResNet50(num_classes=16)
if not os.path.exists(MODEL_FILENAME):
    raise RuntimeError(f"CRITICAL ERROR: Model file '{MODEL_FILENAME}' not found in current directory: {os.getcwd()}")
try:
    # weights_only=False is acceptable only because this checkpoint is our own
    # file; never unpickle untrusted checkpoints (arbitrary code execution).
    checkpoint = torch.load(MODEL_FILENAME, map_location='cpu', weights_only=False)
    # Support both a raw state_dict and a wrapper dict with extra metadata.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Failed to load model weights: {e}")
    raise  # bare raise preserves the original traceback (was `raise e`)
model.eval()

# Grad-CAM hooks onto the last 1x1 conv of the final bottleneck block.
target_layer = model.layer4[2].conv3
grad_cam = GradCAM(model, target_layer)
# ==========================================
# 4. PREDICTION FUNCTION
# ==========================================
# Resize to the network's 224x224 input and normalize with the standard
# ImageNet channel statistics — presumably mirroring the training-time
# preprocessing; confirm against the training script.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def predict(image):
    """Classify a document image and return (confidences, Grad-CAM overlay).

    Parameters
    ----------
    image : PIL.Image.Image or None
        Uploaded document image; returns ``(None, None)`` when absent.

    Returns
    -------
    tuple
        ``(dict mapping class name -> probability, HxWx3 RGB overlay array)``.
    """
    if image is None:
        return None, None

    # The model expects 3-channel input; scanned documents are often grayscale.
    if image.mode != "RGB":
        image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0)

    # GradCAM runs the forward pass itself and hands back the raw logits.
    cam, class_idx, logits = grad_cam(input_tensor)

    # Softmax over the 16 classes; detach because the Grad-CAM backward pass
    # leaves the logits attached to the autograd graph.
    probabilities = torch.nn.functional.softmax(logits[0], dim=0).detach()
    confidences = {name: float(probabilities[i]) for i, name in enumerate(class_names)}

    # Colorize the heatmap. applyColorMap emits BGR, so convert it to RGB once
    # and blend directly with the RGB image — addWeighted is per-channel, so
    # this equals the old RGB->BGR->blend->RGB round trip.
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # Resize the original to 224x224 so it matches the heatmap's size.
    original_rgb = np.array(image.resize((224, 224)))
    final_image = cv2.addWeighted(original_rgb, 0.6, heatmap, 0.4, 0)
    return confidences, final_image
# ==========================================
# 5. UI
# ==========================================
# Collect example images shipped with the Space (shown as clickable samples).
example_list = []
if os.path.exists(EXAMPLES_DIR):
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    example_list = [
        [os.path.join(EXAMPLES_DIR, fname)]
        for fname in sorted(os.listdir(EXAMPLES_DIR))
        if fname.lower().endswith(image_exts)
    ]

title = "📄 Intelligent Document Classifier + Explainability"
description = """
**Analyze and categorize scanned documents with AI.** This demo includes **Grad-CAM Explainability**, which highlights the specific regions the model looked at to make its decision.
The model classifies 16 document types (ResNet-50 trained on RVL-CDIP).
"""

# Two outputs per prediction: top-3 label probabilities and the Grad-CAM overlay.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Document"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Image(label="Explainability Heatmap (What the model 'sees')"),
    ],
    title=title,
    description=description,
    examples=example_list or None,  # Gradio wants None rather than an empty list
)
if __name__ == "__main__":
    # NOTE(review): ssr_mode=False presumably works around a Gradio SSR issue
    # on this runtime — confirm whether it is still needed when upgrading Gradio.
    interface.launch(ssr_mode=False)