Aff77 committed on
Commit
e23252d
·
verified ·
1 Parent(s): 41fa3a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -126
app.py CHANGED
@@ -5,134 +5,48 @@ from PIL import Image
5
  import torchvision.transforms as transforms
6
  from torch import nn
7
  import torch.nn.functional as F
 
 
8
 
9
# Device configuration: use CUDA when PyTorch reports it as available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Constants
VOCAB_SIZE = 2 * 26 + 10  # upper- and lower-case letters plus the ten digits
OUTPUT_LENGTH = 5  # 5-character CAPTCHAs

# AFFN hyper-parameters
AFFN_KERNEL = 5
AFFN_STRIDE = 1
AFFN_DEPTH = 4

# CRNN hyper-parameters
CRNN_KERNEL = 5
CRNN_POOL_KERNEL = 2
CRNN_DROPOUT = 0.3
CRNN_LATENT = 128
LSTM_HIDDEN_DIM = 32

# Character mapping: class index -> character, in `characters` order.
characters = string.ascii_letters + string.digits
idx_to_char = dict(enumerate(characters))
 
27
 
28
  # --------------------------
29
- # Original Model Architecture (CRNN+AFFN)
30
  # --------------------------
31
class Encoder(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU stage that widens 4**(n-1) channels to 4**n.

    Args:
        n: stage index (1-based); fixes the input and output channel counts.
        kernel_size: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, n, kernel_size, stride):
        in_channels = 4 ** (n - 1)
        out_channels = 4 ** n
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        super().__init__(*stages)
38
-
39
class Decoder(nn.Sequential):
    """ConvTranspose2d -> BatchNorm2d -> ReLU stage that narrows 4**n channels to 4**(n-1).

    Args:
        n: stage index (1-based); fixes the input and output channel counts.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.
    """

    def __init__(self, n, kernel_size, stride):
        in_channels = 4 ** n
        out_channels = 4 ** (n - 1)
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        super().__init__(*stages)
46
-
47
class AFFN(nn.Module):
    """Encoder/decoder stack with learnable, alpha-weighted skip connections.

    Args:
        n: number of encoder stages (and of decoder stages).

    Attributes:
        alpha: n-1 learnable coefficients drawn from a standard normal; each one
            splits an encoder output between the main path and a skip branch.
        encoders: `Encoder(1) .. Encoder(n)`.
        decoders: `Decoder(n) .. Decoder(1)`.
    """

    def __init__(self, n):
        super().__init__()
        self.n = n
        self.alpha = nn.Parameter(torch.randn(n - 1))
        self.encoders = nn.ModuleList(
            [Encoder(level, AFFN_KERNEL, AFFN_STRIDE) for level in range(1, n + 1)]
        )
        self.decoders = nn.ModuleList(
            [Decoder(level, AFFN_KERNEL, AFFN_STRIDE) for level in range(n, 0, -1)]
        )

    def forward(self, x):
        """Encode `x` through every stage, then decode while adding the stored skips back."""
        last_gated = self.n - 1
        skips = []
        for level, encoder in enumerate(self.encoders):
            x = encoder(x)
            if level < last_gated:
                gate = self.alpha[level]
                x = x * (1 - gate)
                # The skip is taken from the already down-weighted activation.
                skips.append(x * gate)

        for level, decoder in enumerate(self.decoders):
            x = decoder(x)
            if level < last_gated:
                # Skips are consumed last-in, first-out.
                x = x + skips.pop()
        return x
68
-
69
class CRNN(nn.Module):
    """Two conv/pool stages, a lazily sized linear bottleneck, an LSTM and a classifier.

    The input is expected to have 64 channels (the first convolution's
    `in_channels`); the output has `VOCAB_SIZE` features per sample.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # Conv (padding 2) -> BatchNorm -> ReLU -> max-pool.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, CRNN_KERNEL, padding=2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(CRNN_POOL_KERNEL),
            )

        self.conv1 = conv_stage(64, 128)
        self.conv2 = conv_stage(128, 256)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(CRNN_DROPOUT)
        # LazyLinear infers its input width on the first forward pass.
        self.latent_fc = nn.LazyLinear(CRNN_LATENT)
        self.lstm = nn.LSTM(CRNN_LATENT, LSTM_HIDDEN_DIM, batch_first=True)
        self.output_fc = nn.Linear(LSTM_HIDDEN_DIM, VOCAB_SIZE)

    def forward(self, x):
        """Run the conv stack, project to the latent size, and classify via a length-1 LSTM sequence."""
        features = self.conv2(self.conv1(x))
        latent = self.latent_fc(self.dropout(self.flatten(features)))
        # Insert a sequence axis of length 1 for the batch-first LSTM, then drop it again.
        sequence = latent.unsqueeze(1)
        lstm_out, _ = self.lstm(sequence)
        return self.output_fc(lstm_out.squeeze(1))
99
-
100
class CaptchaCrackNet(nn.Module):
    """Full recogniser: AFFN front end, three conv/pool stages with one shortcut, then the CRNN head."""

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # 5x5 conv (padding 2) -> ReLU -> 2x2 max-pool.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 5, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        self.affn = AFFN(AFFN_DEPTH)
        self.conv1 = conv_stage(1, 32)
        self.conv2 = conv_stage(32, 48)
        self.conv3 = conv_stage(48, 64)
        # Strided shortcut from the AFFN output, added to conv1's output in forward().
        self.res = nn.Conv2d(1, 32, 5, stride=2, padding=2)
        self.crnn = CRNN()

    def forward(self, x):
        """Return the CRNN head's output for the image batch `x`."""
        filtered = self.affn(x)
        shortcut = self.res(filtered)
        features = self.conv1(filtered)
        features = self.conv2(features + shortcut)
        features = self.conv3(features)
        return self.crnn(features)
129
 
130
  # --------------------------
131
  # Model Loading
132
  # --------------------------
133
def load_model():
    """Build a CaptchaCrackNet on `device`, load weights from 'final.pth', and return it in eval mode."""
    net = CaptchaCrackNet()
    net = net.to(device)
    weights = torch.load('final.pth', map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net
138
 
@@ -141,26 +55,50 @@ model = load_model()
141
  # --------------------------
142
  # Prediction Logic
143
  # --------------------------
144
def to_text(pred):
    """Map per-position class scores to a string via `idx_to_char`.

    Args:
        pred: score tensor whose last axis is the class axis. A 2-D
            (positions, classes) tensor yields one character per position;
            a 1-D (classes,) tensor yields a single character.

    Returns:
        The decoded string.
    """
    # `.tolist()` yields plain ints. Iterating the tensor directly yields 0-d
    # tensors, which hash by identity and therefore never match the int keys
    # of `idx_to_char`. `reshape(-1)` also covers the scalar argmax produced
    # by a 1-D input.
    indices = pred.argmax(dim=-1).reshape(-1).tolist()
    return ''.join(idx_to_char[i] for i in indices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
def predict(image):
    """Gradio callback: preprocess a PIL image, run the model, and return the decoded text.

    Args:
        image: the PIL image supplied by the Gradio input, or None when
            nothing was uploaded.

    Returns:
        The predicted string, or a message starting with "Error: " on failure.
    """
    # Report a missing upload explicitly instead of surfacing a library traceback string.
    if image is None:
        return "Error: no image provided"
    try:
        # Preprocess: fixed 40x150 size, single channel, normalised with mean 0.5 / std 0.5.
        transform = transforms.Compose([
            transforms.Resize((40, 150)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        img_tensor = transform(image).unsqueeze(0).to(device)

        # Predict without tracking gradients.
        with torch.no_grad():
            output = model(img_tensor)
        return to_text(output.squeeze(0))
    except Exception as e:
        # UI boundary: turn any failure into text the Textbox can display.
        return f"Error: {str(e)}"
166
 
@@ -169,9 +107,10 @@ def predict(image):
169
  # --------------------------
170
  iface = gr.Interface(
171
  fn=predict,
172
- inputs=gr.Image(type="pil", label="Upload CAPTCHA"),
173
  outputs=gr.Textbox(label="Predicted Text"),
174
- title="CAPTCHA CrackNet",
 
175
  examples=[
176
  ["examples/example1.png"],
177
  ["examples/example2.png"]
 
5
  import torchvision.transforms as transforms
6
  from torch import nn
7
  import torch.nn.functional as F
8
+ from torchvision import models
9
+ from itertools import groupby
10
 
11
# Device configuration: use CUDA when PyTorch reports it as available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Constants: size every input image is resized to before inference.
IMG_HEIGHT, IMG_WIDTH = 32, 128

# Alphabet and the index <-> character lookup tables, in `characters` order.
characters = string.ascii_letters + string.digits
idx_to_char = dict(enumerate(characters))
char_to_idx = {char: index for index, char in idx_to_char.items()}
VOCAB_SIZE = 1 + len(characters)  # one extra class for the CTC blank token
21
 
22
  # --------------------------
23
+ # Model Architecture (Same as Training)
24
  # --------------------------
25
class FastCRNN(nn.Module):
    """CRNN recogniser: truncated ResNet-18 trunk -> 2-layer BiLSTM -> per-step classifier.

    Args:
        num_classes: number of output classes per time step.

    Attribute names (`cnn`, `rnn`, `fc`) are part of the checkpoint's
    state-dict keys and must not be renamed.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Called without arguments, resnet18() builds an untrained network on
        # every torchvision release; the `pretrained=` keyword is deprecated.
        resnet = models.resnet18()
        # Single-channel (grayscale) stem instead of the stock 3-channel one.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Dropping the last three children (layer4, avgpool, fc) keeps the
        # trunk through layer3: 256 channels at 1/16 resolution, i.e.
        # [B, 256, 2, 8] for a 32x128 input.
        self.cnn = nn.Sequential(*list(resnet.children())[:-3])

        # The LSTM consumes one feature column per step: channels * feature-map height.
        # Every stride-2 stage rounds up, hence the ceiling division.
        feature_channels = 256
        feature_height = (IMG_HEIGHT + 15) // 16
        self.lstm_input_size = feature_channels * feature_height  # 256 * 2 = 512 for IMG_HEIGHT = 32
        self.rnn = nn.LSTM(self.lstm_input_size, 256, num_layers=2, bidirectional=True, dropout=0.1)
        # 512 = 2 directions * 256 hidden units.
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return per-step class scores shaped [W, B, num_classes] for an image batch `x`."""
        x = self.cnn(x)
        x = x.permute(3, 0, 1, 2)  # [W, B, C, H]
        x = x.contiguous().view(x.size(0), x.size(1), -1)  # [W, B, C*H]
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x
43
 
44
  # --------------------------
45
  # Model Loading
46
  # --------------------------
47
def load_model():
    """Build a FastCRNN on `device`, load 'fast_crnn_captcha_model.pth', and return it in eval mode."""
    net = FastCRNN(num_classes=VOCAB_SIZE)
    net = net.to(device)
    weights = torch.load('fast_crnn_captcha_model.pth', map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net
52
 
 
55
  # --------------------------
56
  # Prediction Logic
57
  # --------------------------
58
def decode_predictions(preds):
    """Convert model output to text using greedy CTC decoding.

    Args:
        preds: score tensor indexed [W, B, C] (time step, batch, class).

    Returns:
        The decoded string when the batch holds exactly one sample,
        otherwise a list with one string per sample.
    """
    blank_idx = VOCAB_SIZE - 1
    preds = preds.permute(1, 0, 2)  # [B, W, C]
    _, pred_indices = preds.max(2)

    texts = []
    # One tolist() for the whole batch replaces an .item() call per time step.
    for path in pred_indices.tolist():
        # CTC decoding: merge consecutive repeats first, then drop the blank.
        collapsed = (idx for idx, _ in groupby(path))
        texts.append(''.join(
            idx_to_char.get(idx, '') for idx in collapsed if idx != blank_idx
        ))
    return texts[0] if len(texts) == 1 else texts
75
+
76
def preprocess_image(image):
    """Turn one PIL image into a batched tensor on `device`.

    The image is resized to (IMG_HEIGHT, IMG_WIDTH), reduced to a single
    channel, converted to a tensor and normalised with mean 0.5 / std 0.5;
    a leading batch axis is then added.
    """
    pipeline = transforms.Compose([
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    single = pipeline(image)
    batch = single.unsqueeze(0)
    return batch.to(device)
85
 
86
def predict(image):
    """Gradio callback: normalise the input, run the model, and return the decoded text.

    Args:
        image: a PIL image, an array accepted by `Image.fromarray`, a dict
            carrying the image under 'image' or 'data', or None when nothing
            was uploaded.

    Returns:
        The predicted string, or a message starting with "Error: " on failure.
    """
    try:
        # Handle Gradio input types
        if isinstance(image, dict):
            image = image['image'] if 'image' in image else image['data']
        # Report a missing upload explicitly rather than letting
        # Image.fromarray(None) fail with an obscure message.
        if image is None:
            return "Error: no image provided"
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Process and predict
        image_tensor = preprocess_image(image)
        with torch.no_grad():
            outputs = model(image_tensor)
        return decode_predictions(outputs)

    except Exception as e:
        # UI boundary: turn any failure into text the Textbox can display.
        return f"Error: {str(e)}"
104
 
 
107
  # --------------------------
108
  iface = gr.Interface(
109
  fn=predict,
110
+ inputs=gr.Image(type="pil", label="Upload CAPTCHA Image"),
111
  outputs=gr.Textbox(label="Predicted Text"),
112
+ title="CAPTCHA Solver (FastCRNN)",
113
+ description="Upload a CAPTCHA image to extract text using ResNet18 + BiLSTM",
114
  examples=[
115
  ["examples/example1.png"],
116
  ["examples/example2.png"]