Karthikraj Sivakumar commited on
Commit
f9929da
·
1 Parent(s): 7070853

first commit

Browse files
Files changed (2) hide show
  1. app.py +235 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ # ==========================================
9
+ # 1. Model Architecture (Copy from notebook)
10
+ # ==========================================
11
+
12
+ class ResBlock(nn.Module):
13
+ def __init__(self, in_channels, out_channels, stride=1):
14
+ super().__init__()
15
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
16
+ stride=stride, padding=1, bias=False)
17
+ self.bn1 = nn.BatchNorm2d(out_channels)
18
+ self.relu = nn.ReLU(inplace=True)
19
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
20
+ stride=1, padding=1, bias=False)
21
+ self.bn2 = nn.BatchNorm2d(out_channels)
22
+
23
+ self.shortcut = nn.Sequential()
24
+ if stride != 1 or in_channels != out_channels:
25
+ self.shortcut = nn.Sequential(
26
+ nn.Conv2d(in_channels, out_channels, kernel_size=1,
27
+ stride=stride, bias=False),
28
+ nn.BatchNorm2d(out_channels)
29
+ )
30
+
31
+ def forward(self, x):
32
+ out = self.relu(self.bn1(self.conv1(x)))
33
+ out = self.bn2(self.conv2(out))
34
+ out += self.shortcut(x)
35
+ out = self.relu(out)
36
+ return out
37
+
38
+ class CRNN(nn.Module):
39
+ def __init__(self, num_classes, img_height=80, img_width=280, hidden_size=128):
40
+ super().__init__()
41
+
42
+ # CNN layers
43
+ self.conv1 = nn.Sequential(
44
+ nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False),
45
+ nn.BatchNorm2d(64),
46
+ nn.ReLU(inplace=True)
47
+ )
48
+ self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
49
+
50
+ self.layer1 = ResBlock(64, 128)
51
+ self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
52
+
53
+ self.layer2 = ResBlock(128, 256)
54
+ self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
55
+
56
+ self.layer3 = ResBlock(256, 512)
57
+ self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
58
+
59
+ self.dropout = nn.Dropout2d(0.2)
60
+
61
+ # RNN layers
62
+ rnn_input_size = 512 * 5
63
+ self.rnn = nn.LSTM(rnn_input_size, hidden_size, num_layers=2,
64
+ bidirectional=True, dropout=0.1, batch_first=False)
65
+
66
+ # FC layer
67
+ self.fc = nn.Linear(hidden_size * 2, num_classes)
68
+ self.log_softmax = nn.LogSoftmax(dim=2)
69
+
70
+ def forward(self, x):
71
+ x = self.conv1(x)
72
+ x = self.pool1(x)
73
+ x = self.layer1(x)
74
+ x = self.pool2(x)
75
+ x = self.layer2(x)
76
+ x = self.pool3(x)
77
+ x = self.layer3(x)
78
+ x = self.pool4(x)
79
+ conv_out = self.dropout(x)
80
+
81
+ batch_size, channels, height, width = conv_out.size()
82
+ conv_out = conv_out.view(batch_size, channels * height, width)
83
+ conv_out = conv_out.permute(2, 0, 1)
84
+
85
+ rnn_out, _ = self.rnn(conv_out)
86
+ output = self.fc(rnn_out)
87
+ log_probs = self.log_softmax(output)
88
+
89
+ return log_probs
90
+
91
+ # ==========================================
92
+ # 2. Preprocessing Functions
93
+ # ==========================================
94
+
95
+ def resize_and_pad(img, target_size=(80, 280)):
96
+ target_h, target_w = target_size
97
+ h, w = img.shape[:2]
98
+
99
+ scale = min(target_w / w, target_h / h)
100
+ new_w, new_h = int(w * scale), int(h * scale)
101
+ resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
102
+
103
+ padded = np.ones((target_h, target_w), dtype=img.dtype) * 255
104
+
105
+ x_offset = (target_w - new_w) // 2
106
+ y_offset = (target_h - new_h) // 2
107
+ padded[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = resized
108
+
109
+ return padded
110
+
111
+ def remove_black_lines(img):
112
+ hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
113
+ lower_black = np.array([0, 0, 0])
114
+ upper_black = np.array([180, 255, 80])
115
+ mask_black = cv2.inRange(hsv, lower_black, upper_black)
116
+ cleaned = cv2.inpaint(img, mask_black, inpaintRadius=1, flags=cv2.INPAINT_TELEA)
117
+ return cleaned
118
+
119
+ def preprocess_image(image):
120
+ """Preprocess image for model inference"""
121
+ # Convert PIL to OpenCV format
122
+ img = np.array(image)
123
+
124
+ # If RGB, convert to BGR for OpenCV
125
+ if len(img.shape) == 3 and img.shape[2] == 3:
126
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
127
+
128
+ # Remove noise lines
129
+ img = remove_black_lines(img)
130
+
131
+ # Convert to grayscale
132
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
133
+
134
+ # Resize and pad
135
+ img = resize_and_pad(img, target_size=(80, 280))
136
+
137
+ # Normalize
138
+ img = img.astype('float32') / 255.0
139
+ img = torch.tensor(img).unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
140
+
141
+ return img
142
+
143
+ # ==========================================
144
+ # 3. Load Model & Character Mapping
145
+ # ==========================================
146
+
147
+ CHARS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
148
+ char_to_idx = {c: i + 1 for i, c in enumerate(CHARS)}
149
+ idx_to_char = {i + 1: c for i, c in enumerate(CHARS)}
150
+ idx_to_char[0] = "" # blank token
151
+
152
+ num_classes = len(CHARS) + 1
153
+
154
+ # Load model
155
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
156
+ model = CRNN(num_classes=num_classes).to(device)
157
+
158
+ # Load checkpoint (update path to your .pth file)
159
+ checkpoint = torch.load('best_model.pth', map_location=device)
160
+ model.load_state_dict(checkpoint['model_state_dict'])
161
+ model.eval()
162
+
163
+ print(f"✅ Model loaded successfully! Using device: {device}")
164
+
165
+ # ==========================================
166
+ # 4. Prediction Function
167
+ # ==========================================
168
+
169
+ def predict_captcha(image):
170
+ """Predict CAPTCHA text from image"""
171
+
172
+ # Preprocess
173
+ img_tensor = preprocess_image(image).to(device)
174
+
175
+ # Inference
176
+ with torch.no_grad():
177
+ log_probs = model(img_tensor)
178
+
179
+ # Greedy decoding
180
+ _, max_indices = torch.max(log_probs, dim=2)
181
+ max_indices = max_indices.squeeze(1).cpu().numpy()
182
+
183
+ # CTC collapse (remove blanks and repeated tokens)
184
+ collapsed = []
185
+ prev = None
186
+ for token in max_indices:
187
+ if token != 0 and token != prev:
188
+ collapsed.append(token)
189
+ prev = token
190
+
191
+ # Decode to text
192
+ prediction = ''.join([idx_to_char.get(t, '') for t in collapsed])
193
+
194
+ # Return with confidence info
195
+ return {
196
+ "Prediction": prediction,
197
+ "Length": len(prediction),
198
+ "Device": str(device)
199
+ }
200
+
201
+ # ==========================================
202
+ # 5. Gradio Interface
203
+ # ==========================================
204
+
205
+ demo = gr.Interface(
206
+ fn=predict_captcha,
207
+ inputs=gr.Image(type="pil", label="Upload CAPTCHA Image"),
208
+ outputs=gr.JSON(label="Prediction Results"),
209
+ title="🔐 CAPTCHA Recognition System",
210
+ description="""
211
+ **CS4243 Mini Project - CAPTCHA Recognition using CRNN + CTC Loss**
212
+
213
+ Upload a CAPTCHA image to see the model's prediction.
214
+
215
+ **Model Architecture:**
216
+ - ResNet-based CNN feature extraction
217
+ - Bidirectional LSTM for sequence modeling
218
+ - CTC Loss for alignment-free training
219
+
220
+ **Performance:**
221
+ - Sequence Accuracy: ~54%
222
+ - Character Accuracy: ~86%
223
+ - Trained on 9,000 samples with heavy augmentation
224
+ """,
225
+ examples=[
226
+ # Add example image paths here if you want
227
+ # ["example1.png"],
228
+ # ["example2.png"],
229
+ ],
230
+ theme=gr.themes.Soft(),
231
+ allow_flagging="never"
232
+ )
233
+
234
+ if __name__ == "__main__":
235
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ opencv-python-headless
4
+ numpy
5
+ pillow
6
+ gradio