# captcha_stn / app.py
# Commit 79321e5 by Aff77: "Update app.py" (verified)
import gradio as gr
import torch
import string
from PIL import Image
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F
from torchvision import models
from itertools import groupby
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input geometry fed to the network: height x width in pixels.
IMG_HEIGHT, IMG_WIDTH = 32, 128

# Recognisable symbols: a-z, A-Z, 0-9. Their positions 0..len(characters)-1 are the
# class indices; one extra class (index len(characters)) is the CTC blank.
characters = string.ascii_letters + string.digits
idx_to_char = dict(enumerate(characters))
char_to_idx = {ch: pos for pos, ch in idx_to_char.items()}
VOCAB_SIZE = len(characters) + 1  # +1 for CTC blank token
# --------------------------
# Model Architecture (Same as Training)
# --------------------------
class FastCRNN(nn.Module):
    """CRNN text recognizer: truncated ResNet-18 feature extractor + 2-layer BiLSTM.

    The layer layout must stay identical to the training script so the saved
    state dict loads.

    Args:
        num_classes: Size of the output layer (character classes + 1 CTC blank).
    """

    def __init__(self, num_classes):
        super().__init__()
        resnet = models.resnet18(weights=None)
        # Single-channel (grayscale) stem instead of the default 3-channel one.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Keep conv1 .. layer3; drop layer4, avgpool and fc.
        # conv1, maxpool, layer2 and layer3 each halve H and W (rounding up), so the
        # backbone downsamples by 16 and layer3 emits 256 channels:
        # a 32x128 input becomes [B, 256, 2, 8].
        self.cnn = nn.Sequential(*list(resnet.children())[:-3])
        # LSTM features per time step = channels * remaining height = 256 * ceil(H / 16)
        # (= 256 * 2 = 512 for IMG_HEIGHT = 32).
        self.lstm_input_size = 256 * ((IMG_HEIGHT + 15) // 16)
        self.rnn = nn.LSTM(self.lstm_input_size, 256, num_layers=2, bidirectional=True, dropout=0.1)
        # Bidirectional LSTM with hidden size 256 -> 2 * 256 = 512 features per step.
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a batch of images to per-time-step class scores.

        Args:
            x: Image batch of shape [B, 1, H, W] (as built by ``preprocess_image``).

        Returns:
            Unnormalized scores of shape [W', B, num_classes], where W' is the
            feature-map width (ceil(W / 16), i.e. 8 for 128-pixel-wide inputs).
            The layout is time-major, matching the LSTM's default
            ``batch_first=False``.
        """
        x = self.cnn(x)                                     # [B, C, H', W']
        x = x.permute(3, 0, 1, 2)                           # [W', B, C, H']: width is the time axis
        x = x.contiguous().view(x.size(0), x.size(1), -1)   # [W', B, C*H']
        x, _ = self.rnn(x)                                  # [W', B, 512]
        return self.fc(x)                                   # [W', B, num_classes]
# --------------------------
# Model Loading
# --------------------------
def load_model():
    """Create the FastCRNN network, restore its trained weights and return it in eval mode.

    The checkpoint 'fast_crnn_captcha_model.pth' is opened via a relative path
    (so relative to the current working directory). Its tensors are mapped onto
    ``device``.
    """
    net = FastCRNN(num_classes=VOCAB_SIZE)
    net = net.to(device)
    weights = torch.load('fast_crnn_captcha_model.pth', map_location=device)
    net.load_state_dict(weights)
    # Module.eval() switches off dropout / uses BN running stats and returns the module itself.
    return net.eval()


# Built once at import time; `predict` uses this shared instance for every request.
model = load_model()
# --------------------------
# Prediction Logic
# --------------------------
def decode_predictions(preds, *, idx_map=None, blank_idx=None):
    """Greedy (best-path) CTC decoding of raw model outputs.

    Args:
        preds: Model output laid out as [W, B, C] (time steps, batch, classes),
            as returned by ``FastCRNN.forward``.
        idx_map: Optional class-index -> character mapping. Defaults to the
            module-level ``idx_to_char``.
        blank_idx: Optional index of the CTC blank class. Defaults to
            ``VOCAB_SIZE - 1``.

    Returns:
        The decoded string when the batch holds exactly one item, otherwise a
        list of strings (one per batch item).
    """
    if idx_map is None:
        idx_map = idx_to_char
    if blank_idx is None:
        blank_idx = VOCAB_SIZE - 1

    # Softmax is monotonic, so argmax over the raw scores gives the same best path.
    best_path = preds.permute(1, 0, 2).argmax(dim=2)  # [B, W]

    texts = []
    for seq in best_path.tolist():
        # Standard CTC collapse: merge consecutive repeats first (blanks included),
        # then drop blanks. A blank between two equal labels therefore keeps both
        # ("a", blank, "a" -> "aa").
        chars = [idx_map[k] for k, _ in groupby(seq) if k != blank_idx and k in idx_map]
        texts.append(''.join(chars))
    return texts[0] if len(texts) == 1 else texts
def preprocess_image(image):
    """Prepare a PIL image for the network.

    The image is resized to IMG_HEIGHT x IMG_WIDTH, converted to a single
    grayscale channel and scaled to [-1, 1]. It is returned as a one-item batch
    of shape [1, 1, IMG_HEIGHT, IMG_WIDTH] placed on ``device``.
    """
    steps = [
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.Grayscale(),
        transforms.ToTensor(),                  # [0, 1] floats, shape [1, H, W]
        transforms.Normalize((0.5,), (0.5,)),   # (x - 0.5) / 0.5 -> [-1, 1]
    ]
    single = transforms.Compose(steps)(image)
    batch = single.unsqueeze(0)
    return batch.to(device)
def predict(image):
    """Gradio callback: read the text in a CAPTCHA image.

    Args:
        image: A PIL image (what ``gr.Image(type="pil")`` delivers). Also accepted:
            - a dict carrying the image under ``'data'``;
            - a base64 ``data:image/...`` URL string;
            - raw encoded image bytes.

    Returns:
        The predicted string. When no image was given or processing failed, a
        human-readable message instead (the UI displays either).
    """
    # Imported before any branch so every input path can use them.
    import base64
    import traceback
    from io import BytesIO

    try:
        # Optional: handle Gradio dictionary format
        if isinstance(image, dict) and 'data' in image:
            image = image['data']
        # Checked after unwrapping so an empty dict payload is reported the same way.
        if image is None:
            return "No image provided"
        # Convert to PIL Image if not already
        if isinstance(image, str) and image.startswith('data:image'):
            image_data = base64.b64decode(image.split(',')[1])
            image = Image.open(BytesIO(image_data))
        elif not isinstance(image, Image.Image):
            image = Image.open(BytesIO(image))
        # Same preprocessing as training (resize 32x128, grayscale, normalize).
        image_tensor = preprocess_image(image)
        with torch.no_grad():
            output = model(image_tensor)
        prediction = decode_predictions(output)
        print(f"Predicted text: {prediction}")
        return prediction
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the app,
        # and keep the full traceback in the server log for debugging.
        traceback.print_exc()
        print(f"Error details: {str(e)}")
        return f"Error processing image: {str(e)}"
# --------------------------
# Gradio Interface
# --------------------------
# Web UI: one image upload in, one text box out; `predict` does the work.
iface = gr.Interface(
    predict,
    gr.Image(type="pil", label="Upload CAPTCHA Image"),
    gr.Textbox(label="Predicted Text"),
    title="CAPTCHA Solver (FastCRNN)",
    description="Upload a CAPTCHA image to extract text using ResNet18 + BiLSTM",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    iface.launch(share=True)