Aff77 committed on
Commit
41fa3a5
·
verified ·
1 Parent(s): 5a255e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -162
app.py CHANGED
@@ -1,163 +1,182 @@
1
- import gradio as gr
2
- import torch
3
- import string
4
- from PIL import Image
5
- import torchvision.transforms as transforms
6
- from torch import nn
7
- import torch.nn.functional as F
8
- from torchvision import models
9
-
10
# Device configuration: prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Character mapping: 52 ASCII letters + 10 digits = 62 symbols.
characters = string.ascii_letters + string.digits
char_to_idx = {c: i for i, c in enumerate(characters)}
idx_to_char = {i: c for i, c in enumerate(characters)}
VOCAB_SIZE = len(characters) + 1  # +1 for blank token for CTC (blank id = 62)
MAX_LABEL_LENGTH = 6  # presumably the max captcha length seen in training — TODO confirm
IMG_HEIGHT = 32  # fixed input height expected by the network
IMG_WIDTH = 128  # fixed input width expected by the network
21
-
22
- # --------------------------
23
- # Your Model Architecture
24
- # --------------------------
25
class STN(nn.Module):
    """Spatial Transformer Network: learns an affine warp applied to the
    input image before recognition, to undo captcha distortion."""

    def __init__(self):
        super().__init__()
        # Localization network: small CNN that regresses the warp parameters.
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Calculate flattened size by tracing a dummy input at the fixed
        # (IMG_HEIGHT, IMG_WIDTH) model resolution.
        with torch.no_grad():
            dummy = torch.zeros(1, 1, IMG_HEIGHT, IMG_WIDTH)
            out = self.localization(dummy)
            self.flat_size = out.view(1, -1).shape[1]

        # Regressor producing the 6 parameters of a 2x3 affine matrix.
        self.fc_loc = nn.Sequential(
            nn.Linear(self.flat_size, 32),
            nn.ReLU(True),
            nn.Linear(32, 6)
        )

        # Initialize as identity transform so training starts from "no warp".
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, x):
        # Predict theta from the input, then resample the input on the
        # corresponding affine grid.
        xs = self.localization(x)
        xs = xs.view(xs.size(0), -1)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        x = F.grid_sample(x, grid, align_corners=False)
        return x
61
-
62
class FastCRNN(nn.Module):
    """CRNN recognizer: STN rectification -> truncated ResNet-18 features
    -> bidirectional LSTM -> per-timestep class logits (CTC-style output)."""

    def __init__(self, num_classes):
        super().__init__()
        self.stn = STN()
        # NOTE(review): `pretrained=` is deprecated in newer torchvision in
        # favor of `weights=`; works but may warn — confirm torchvision version.
        resnet = models.resnet18(pretrained=False)
        # Replace the stem conv so the network accepts 1-channel grayscale input.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Drop the last residual stage, avgpool and fc: keep spatial feature maps.
        self.cnn = nn.Sequential(*list(resnet.children())[:-3])

        # Feature height after the CNN is IMG_HEIGHT // 8, with 128 channels.
        self.lstm_input_size = 128 * (IMG_HEIGHT // 8)
        self.rnn = nn.LSTM(self.lstm_input_size, 256, num_layers=2, bidirectional=True, dropout=0.1)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.stn(x)
        x = self.cnn(x)                                    # [B, C, H, W]
        x = x.permute(3, 0, 1, 2)                          # [W, B, C, H] — width is the time axis
        x = x.contiguous().view(x.size(0), x.size(1), -1)  # [W, B, C*H]
        x, _ = self.rnn(x)                                 # [W, B, 512]
        x = self.fc(x)                                     # [W, B, num_classes]
        return x
82
-
83
- # --------------------------
84
- # Helper Functions
85
- # --------------------------
86
def decode_predictions(preds):
    """Greedy CTC decode: collapse repeats, drop the blank, map ids to text.

    `preds` is [W, B, C] logits; returns one string for a single-item batch,
    otherwise a list of strings.
    """
    # [W, B, C] -> [B, W]: most likely class id per timestep.
    batch_first = preds.permute(1, 0, 2)
    best_ids = torch.argmax(torch.softmax(batch_first, dim=2), dim=2)

    blank_threshold = len(characters)  # ids >= this are the CTC blank
    results = []
    for sequence in best_ids:
        collapsed = []
        last = -1
        for idx in sequence:
            # Keep a symbol only when it differs from its predecessor and
            # is a real character (not the blank).
            if idx != last and idx < blank_threshold:
                collapsed.append(idx.item())
            last = idx
        results.append(''.join(idx_to_char[c] for c in collapsed))
    return results[0] if len(results) == 1 else results
104
-
105
- # --------------------------
106
- # Model Loading
107
- # --------------------------
108
def load_model():
    """Instantiate FastCRNN, restore trained weights, switch to eval mode."""
    net = FastCRNN(num_classes=VOCAB_SIZE).to(device)
    checkpoint = torch.load('model/fast_crnn.pth', map_location=device)
    net.load_state_dict(checkpoint)
    net.eval()
    return net

# Loaded once at import time so every request reuses the same instance.
model = load_model()
115
-
116
- # --------------------------
117
- # Prediction Function
118
- # --------------------------
119
def predict_captcha(image):
    """Predict the text in a CAPTCHA image uploaded through Gradio.

    Accepts a PIL image, a numpy array, a file path, or the dict some Gradio
    versions pass; returns the decoded string, or an "Error: ..." message.
    """
    try:
        # BUG FIX: `np` was referenced below without ever being imported, so
        # any ndarray input raised NameError (swallowed by the except clause).
        import numpy as np

        # Normalize the Gradio input to a PIL Image.
        if isinstance(image, dict):  # Gradio might pass a dict
            image = image['image'] if 'image' in image else image['data']
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image) if isinstance(image, np.ndarray) else Image.open(image)

        # Preprocess: fixed training resolution, grayscale, scale to [-1, 1].
        transform = transforms.Compose([
            transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

        image_tensor = transform(image).unsqueeze(0).to(device)

        # Predict without tracking gradients.
        with torch.no_grad():
            outputs = model(image_tensor)
            prediction = decode_predictions(outputs)

        return prediction

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error: {str(e)}"
146
-
147
- # --------------------------
148
- # Gradio Interface
149
- # --------------------------
150
# Web UI: one image input, one text output; example files ship with the app.
iface = gr.Interface(
    fn=predict_captcha,
    inputs=gr.Image(type="pil", label="Upload CAPTCHA Image"),
    outputs=gr.Textbox(label="Predicted Text"),
    title="CAPTCHA Recognition with FastCRNN",
    description="Upload a CAPTCHA image to get the predicted text.",
    examples=[
        ["examples/example1.png"],
        ["examples/example2.png"]
    ]
)

# Start the server only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ import string
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
# Device configuration: prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants
VOCAB_SIZE = 26*2 + 10  # Letters (upper/lower) + digits = 62 classes
OUTPUT_LENGTH = 5  # 5-character CAPTCHAs — NOTE(review): not referenced below; confirm intent
AFFN_KERNEL = 5  # conv kernel size in the AFFN encoder/decoder stages
AFFN_STRIDE = 1  # conv stride in the AFFN encoder/decoder stages
AFFN_DEPTH = 4  # number of encoder (and decoder) stages in the AFFN
CRNN_KERNEL = 5  # conv kernel size in the CRNN head
CRNN_POOL_KERNEL = 2  # max-pool kernel in the CRNN head
CRNN_DROPOUT = 0.3  # dropout rate before the latent projection
CRNN_LATENT = 128  # latent feature size fed to the LSTM
LSTM_HIDDEN_DIM = 32  # LSTM hidden state size

# Character mapping: index -> character, matching the training label encoding.
characters = string.ascii_letters + string.digits
idx_to_char = {i: c for i, c in enumerate(characters)}
27
+
28
+ # --------------------------
29
+ # Original Model Architecture (CRNN+AFFN)
30
+ # --------------------------
31
class Encoder(nn.Sequential):
    """One AFFN down-stage: conv (4**(n-1) -> 4**n channels) + BN + ReLU."""

    def __init__(self, n, kernel_size, stride):
        in_ch, out_ch = 4 ** (n - 1), 4 ** n
        super().__init__(
            nn.Conv2d(in_ch, out_ch, kernel_size, stride),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
38
+
39
class Decoder(nn.Sequential):
    """One AFFN up-stage: transposed conv (4**n -> 4**(n-1) channels) + BN + ReLU."""

    def __init__(self, n, kernel_size, stride):
        wide, narrow = 4 ** n, 4 ** (n - 1)
        super().__init__(
            nn.ConvTranspose2d(wide, narrow, kernel_size, stride),
            nn.BatchNorm2d(narrow),
            nn.ReLU()
        )
46
+
47
class AFFN(nn.Module):
    """Feature-fusion autoencoder: `n` Encoder stages down, `n` Decoder
    stages up, with skip connections weighted by learned coefficients."""

    def __init__(self, n):
        super().__init__()
        self.n = n
        # One learned mixing weight per skip connection (n-1 of them).
        self.alpha = nn.Parameter(torch.randn(n - 1))
        self.encoders = nn.ModuleList([Encoder(i, AFFN_KERNEL, AFFN_STRIDE) for i in range(1, n + 1)])
        self.decoders = nn.ModuleList([Decoder(i, AFFN_KERNEL, AFFN_STRIDE) for i in range(n, 0, -1)])

    def forward(self, x):
        # Down path. Order matters: the carried tensor is down-weighted by
        # (1 - alpha) FIRST, then the skip copy is taken from that already
        # scaled tensor (i.e. skip = alpha * (1 - alpha) * stage output).
        skips = []
        last = self.n - 1
        for i, encode in enumerate(self.encoders):
            x = encode(x)
            if i < last:
                x = x * (1 - self.alpha[i])
                skips.append(x * self.alpha[i])

        # Up path: fuse the stashed skips back in, innermost stage first.
        for i, decode in enumerate(self.decoders):
            x = decode(x)
            if i < last:
                x = x + skips.pop()
        return x
68
+
69
class CRNN(nn.Module):
    """Recognition head: two conv+pool stages, a latent projection, a
    length-1 LSTM pass, and a final linear layer producing class logits.

    Attribute names are part of the checkpoint's state_dict contract and
    must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(64, 128, CRNN_KERNEL, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(CRNN_POOL_KERNEL)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(128, 256, CRNN_KERNEL, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(CRNN_POOL_KERNEL)
        )
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(CRNN_DROPOUT)
        # LazyLinear: in-features are inferred on the first forward pass.
        self.latent_fc = nn.LazyLinear(CRNN_LATENT)
        self.lstm = nn.LSTM(CRNN_LATENT, LSTM_HIDDEN_DIM, batch_first=True)
        self.output_fc = nn.Linear(LSTM_HIDDEN_DIM, VOCAB_SIZE)

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        latent = self.latent_fc(self.dropout(self.flatten(features)))
        # Feed the latent vector as a length-1 sequence through the LSTM.
        lstm_out, _ = self.lstm(latent.unsqueeze(1))
        return self.output_fc(lstm_out.squeeze(1))
99
+
100
class CaptchaCrackNet(nn.Module):
    """Full pipeline: AFFN front-end, three conv stages with one strided
    residual shortcut, then the CRNN recognition head.

    Attribute names are part of the checkpoint's state_dict contract and
    must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.affn = AFFN(AFFN_DEPTH)
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 48, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(48, 64, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Strided shortcut: matches conv1's output (32 ch, half resolution)
        # for even input sizes.
        self.res = nn.Conv2d(1, 32, 5, stride=2, padding=2)
        self.crnn = CRNN()

    def forward(self, x):
        cleaned = self.affn(x)
        shortcut = self.res(cleaned)
        out = self.conv1(cleaned)
        out = self.conv2(out + shortcut)  # residual fusion before stage 2
        out = self.conv3(out)
        return self.crnn(out)
129
+
130
+ # --------------------------
131
+ # Model Loading
132
+ # --------------------------
133
def load_model():
    """Build CaptchaCrackNet, load trained weights from 'final.pth', and
    switch the network to inference mode.

    Returns:
        The model on `device`, in eval() mode.
    """
    model = CaptchaCrackNet().to(device)
    # weights_only=True stops torch.load from unpickling arbitrary objects:
    # a state_dict needs only tensors, and an untrusted checkpoint file can
    # otherwise execute code on load. (Requires torch >= 1.13.)
    state_dict = torch.load('final.pth', map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Loaded once at import time so every request reuses the same instance.
model = load_model()
140
+
141
+ # --------------------------
142
+ # Prediction Logic
143
+ # --------------------------
144
def to_text(pred):
    """Map per-position logits to the predicted string.

    Accepts logits of shape [seq_len, VOCAB_SIZE] (one row per character)
    or a single position of shape [VOCAB_SIZE].
    """
    if pred.dim() == 1:
        # BUG FIX: `predict` passes output.squeeze(0), which is 1-D for this
        # model; argmax(dim=1) raised IndexError. Treat it as a length-1 seq.
        pred = pred.unsqueeze(0)
    # BUG FIX: .item() is required — indexing the dict with a 0-dim Tensor
    # raises KeyError, because Tensor hashing is identity-based.
    return ''.join(idx_to_char[i.item()] for i in pred.argmax(dim=1))
146
+
147
def predict(image):
    """Run the model on an uploaded PIL image and return the decoded text.

    Any failure is reported as an "Error: ..." string so the UI never crashes.
    """
    try:
        # Resize to the training resolution, grayscale, scale to [-1, 1].
        preprocess = transforms.Compose([
            transforms.Resize((40, 150)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        batch = preprocess(image).unsqueeze(0).to(device)

        # Inference only: no gradient tracking.
        with torch.no_grad():
            logits = model(batch)
        return to_text(logits.squeeze(0))

    except Exception as e:
        return f"Error: {str(e)}"
166
+
167
# --------------------------
# Gradio Interface
# --------------------------
# Web UI: one image input, one text output; example files ship with the app.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload CAPTCHA"),
    outputs=gr.Textbox(label="Predicted Text"),
    title="CAPTCHA CrackNet",
    examples=[
        ["examples/example1.png"],
        ["examples/example2.png"]
    ]
)

# Start the server only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()