arpit-gour02 committed on
Commit
14fe11e
·
unverified ·
1 Parent(s): 6239f77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -30
app.py CHANGED
@@ -1,72 +1,164 @@
1
  import gradio as gr
2
  import torch
3
- from torchvision import models, transforms
 
 
4
 
5
- # --- 1. CONFIGURATION ---
6
- MODEL_FILENAME = "resnet50_epoch_4.pth" # Loading locally now
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  class_names = [
9
  'letter', 'form', 'email', 'handwritten', 'advertisement', 'scientific report',
10
  'scientific publication', 'specification', 'file folder', 'news article',
11
  'budget', 'invoice', 'presentation', 'questionnaire', 'resume', 'memo'
12
  ]
13
 
14
- # --- 2. LOAD MODEL LOCALLY ---
15
- def load_model_locally():
16
- print(f"Loading {MODEL_FILENAME} from local disk...")
17
 
18
- # Initialize Standard Architecture
19
- model = models.resnet50(num_classes=16)
20
 
21
- # Load the checkpoint locally
22
  checkpoint = torch.load(MODEL_FILENAME, map_location=torch.device('cpu'))
23
 
24
- # Handle if it's nested in 'state_dict'
25
  if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
26
- state_dict = checkpoint['state_dict']
27
  else:
28
- state_dict = checkpoint
29
-
30
- # --- THE FIX: RENAME KEYS ---
31
- # We must still rename 'shortcut' -> 'downsample' because your file
32
- # has custom names, but we are using the standard torchvision model here.
33
- new_state_dict = {}
34
- for key, value in state_dict.items():
35
- new_key = key.replace("shortcut", "downsample")
36
- new_state_dict[new_key] = value
37
- # ----------------------------
38
-
39
- model.load_state_dict(new_state_dict)
40
  model.eval()
41
  return model
42
 
43
- model = load_model_locally()
 
44
 
45
- # --- 3. PREPROCESSING ---
 
 
 
 
46
  transform = transforms.Compose([
47
  transforms.Resize((224, 224)),
48
  transforms.ToTensor(),
49
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50
  ])
51
 
52
- # --- 4. PREDICTION FUNCTION ---
53
  def predict(image):
54
- if image is None:
55
- return None
56
  image_tensor = transform(image).unsqueeze(0)
 
57
  with torch.no_grad():
58
  outputs = model(image_tensor)
59
  probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
60
 
61
  return {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
62
 
63
- # --- 5. LAUNCH INTERFACE ---
64
  interface = gr.Interface(
65
  fn=predict,
66
  inputs=gr.Image(type="pil"),
67
  outputs=gr.Label(num_top_classes=3),
68
  title="Document Classifier (ResNet50)",
69
- description="Classifies documents into 16 categories.",
70
  examples=[
71
  ["1.png"],
72
  ["5022.png"],
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
 
7
+ # ==========================================
8
+ # 1. YOUR CUSTOM MODEL ARCHITECTURE
9
+ # ==========================================
10
 
11
class BottleneckBlock(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus residual.

    The output channel count is ``mid_channels * expansion`` (i.e. 4x). When
    the stride or channel count changes, a 1x1 projection shortcut matches the
    identity branch to the main branch; otherwise the input is added directly.
    """

    # Channel expansion factor applied by the final 1x1 convolution.
    expansion = 4

    def __init__(self, in_channels, mid_channels, stride=1):
        super(BottleneckBlock, self).__init__()

        out_channels = mid_channels * self.expansion

        # 1x1: reduce channels before the expensive 3x3 convolution.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)

        # 3x3: the only conv that may downsample spatially (stride > 1).
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)

        # 1x1: expand back to out_channels.
        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

        # NOTE: the attribute is named 'shortcut' (torchvision uses
        # 'downsample'); the checkpoint's state-dict keys depend on this
        # name, so it must not be renamed.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Run the bottleneck transform and add the (possibly projected) residual."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the residual when the shape changed, then add and activate.
        identity = self.shortcut(identity)
        out += identity
        out = self.relu(out)

        return out
58
+
59
class ResNet50(nn.Module):
    """Custom ResNet-50 built from BottleneckBlock (stages of 3/4/6/3 blocks).

    Matches the standard ResNet-50 topology but keeps this project's own
    parameter names (notably 'shortcut') so the trained checkpoint loads
    without key renaming.

    Args:
        num_classes: size of the final classification layer (16 document types).
        channels_img: number of input image channels (3 for RGB).
    """

    def __init__(self, num_classes=16, channels_img=3):
        super(ResNet50, self).__init__()

        # Running channel count; advanced by _make_layer as stages are built.
        self.in_channels = 64

        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool (4x downsample total).
        self.conv1 = nn.Conv2d(channels_img, 64, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # The four residual stages of ResNet-50.
        self.layer1 = self._make_layer(mid_channels=64, num_blocks=3, stride=1)
        self.layer2 = self._make_layer(mid_channels=128, num_blocks=4, stride=2)
        self.layer3 = self._make_layer(mid_channels=256, num_blocks=6, stride=2)
        self.layer4 = self._make_layer(mid_channels=512, num_blocks=3, stride=2)

        # Global average pool to 1x1, then the classifier head (512 * 4 = 2048).
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * 4, num_classes)

    def _make_layer(self, mid_channels, num_blocks, stride):
        """Build one stage: a (possibly strided) block followed by identity blocks."""
        layers = []
        # Only the first block of a stage downsamples / changes channels.
        layers.append(BottleneckBlock(self.in_channels, mid_channels, stride))
        self.in_channels = mid_channels * 4
        for _ in range(num_blocks - 1):
            layers.append(BottleneckBlock(self.in_channels, mid_channels, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
101
+
102
# ==========================================
# 2. CONFIG & LOADING
# ==========================================

# Checkpoint file expected alongside this script.
MODEL_FILENAME = "resnet50_epoch_5.pth"

# The 16 RVL-CDIP document categories. Index order must match the label
# encoding used at training time, since predictions are mapped by position.
class_names = [
    'letter', 'form', 'email', 'handwritten', 'advertisement', 'scientific report',
    'scientific publication', 'specification', 'file folder', 'news article',
    'budget', 'invoice', 'presentation', 'questionnaire', 'resume', 'memo'
]
 
113
def load_model():
    """Instantiate the custom ResNet50 and load its trained weights on CPU.

    Accepts both raw state-dict checkpoints and checkpoints nested under a
    'state_dict' key. Returns the model switched to eval mode.

    Raises:
        FileNotFoundError: if MODEL_FILENAME is missing from the working dir.
        RuntimeError: if the checkpoint keys do not match the architecture.
    """
    print(f"Loading {MODEL_FILENAME}...")

    # Initialize the custom ResNet50 (16 document classes).
    model = ResNet50(num_classes=16)

    # Load weights (CPU is sufficient for inference).
    # NOTE(review): torch.load unpickles the file — only use trusted
    # checkpoints; consider weights_only=True on newer torch versions.
    checkpoint = torch.load(MODEL_FILENAME, map_location=torch.device('cpu'))

    # Handle dictionary nesting if present.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    return model
130
 
131
# Load the model once at startup so every request reuses the same weights.
model = load_model()
133
 
134
# ==========================================
# 3. PREPROCESSING & INTERFACE
# ==========================================

# Standard ImageNet transforms: fixed 224x224 resize, tensor conversion,
# and per-channel normalization with the ImageNet mean/std — these must
# match the preprocessing used during training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
144
 
 
145
def predict(image):
    """Classify one PIL document image.

    Args:
        image: PIL image from the Gradio input, or None when empty.

    Returns:
        dict mapping each of the 16 class names to its softmax probability
        (the format gr.Label expects), or None when no image was provided.
    """
    if image is None:
        return None

    # Preprocess and add a batch dimension: (1, 3, 224, 224).
    image_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

    # Pair each label with its probability by position.
    return {name: float(p) for name, p in zip(class_names, probabilities)}
154
 
155
+ # Gradio UI
156
  interface = gr.Interface(
157
  fn=predict,
158
  inputs=gr.Image(type="pil"),
159
  outputs=gr.Label(num_top_classes=3),
160
  title="Document Classifier (ResNet50)",
161
+ description="Custom ResNet50 trained on RVL-CDIP to classify 16 document types.",
162
  examples=[
163
  ["1.png"],
164
  ["5022.png"],