import base64
import string
from io import BytesIO  # needed at function scope in predict() for raw-bytes input
from itertools import groupby  # kept: present in original file (currently unused)

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torchvision import models

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants
IMG_HEIGHT = 32
IMG_WIDTH = 128
characters = string.ascii_letters + string.digits
char_to_idx = {c: i for i, c in enumerate(characters)}
idx_to_char = {i: c for i, c in enumerate(characters)}
VOCAB_SIZE = len(characters) + 1  # +1 for CTC blank token
BLANK_IDX = VOCAB_SIZE - 1  # CTC blank is the last class index


# --------------------------
# Model Architecture (Same as Training)
# --------------------------
class FastCRNN(nn.Module):
    """CRNN for CAPTCHA recognition: truncated ResNet-18 backbone feeding a
    bidirectional 2-layer LSTM, trained with CTC loss.

    NOTE: the layer structure must stay byte-identical to training so the
    saved state_dict loads — do not rename or reshape any submodule.
    """

    def __init__(self, num_classes):
        super().__init__()
        resnet = models.resnet18(weights=None)
        # Single-channel (grayscale) input instead of RGB.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2,
                                 padding=3, bias=False)
        # Drop layer4/avgpool/fc; keep through layer3. Total downsampling is
        # 16x, so a 32x128 input yields a [B, 256, 2, 8] feature map.
        self.cnn = nn.Sequential(*list(resnet.children())[:-3])
        # Per-timestep feature size = channels * feature height
        # = 256 * (IMG_HEIGHT // 16) = 512; the expression below evaluates to
        # the same 512 (128 * 4) and is kept verbatim for clarity of intent
        # with the training script.
        self.lstm_input_size = 128 * (IMG_HEIGHT // 8)
        self.rnn = nn.LSTM(self.lstm_input_size, 256, num_layers=2,
                           bidirectional=True, dropout=0.1)
        # 512 = 256 hidden units * 2 directions
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return per-timestep class logits of shape [W, B, num_classes]."""
        x = self.cnn(x)                                    # [B, C, H, W]
        x = x.permute(3, 0, 1, 2)                          # [W, B, C, H]
        x = x.contiguous().view(x.size(0), x.size(1), -1)  # [W, B, C*H]
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x


# --------------------------
# Model Loading
# --------------------------
def load_model():
    """Load the trained weights from disk and return the model in eval mode."""
    model = FastCRNN(num_classes=VOCAB_SIZE).to(device)
    model.load_state_dict(
        torch.load('fast_crnn_captcha_model.pth', map_location=device))
    model.eval()
    return model


model = load_model()


# --------------------------
# Prediction Logic
# --------------------------
def decode_predictions(preds):
    """Greedy CTC decode of raw model output.

    Args:
        preds: logits of shape [W, B, num_classes] as produced by the model.

    Returns:
        The decoded string for a single-image batch, or a list of strings
        for a larger batch.

    FIX: the previous implementation compared *characters* when merging
    repeats and never reset on blank, so a doubled letter separated by a
    blank ("a", blank, "a") incorrectly collapsed to "a". Correct CTC
    decoding tracks the previous *index* — a blank breaks the repeat merge.
    """
    preds = preds.permute(1, 0, 2)  # [B, W, C]
    # softmax is monotonic, so argmax over logits gives the same indices;
    # the explicit softmax from the original is unnecessary for decoding.
    pred_indices = torch.argmax(preds, dim=2)

    texts = []
    for seq in pred_indices:
        decoded = []
        prev_idx = BLANK_IDX
        for idx in seq.tolist():
            # Emit only non-blank indices that differ from the previous
            # timestep's index (standard CTC collapse rule).
            if idx != BLANK_IDX and idx != prev_idx:
                decoded.append(idx_to_char[idx])
            prev_idx = idx
        texts.append(''.join(decoded))
    return texts[0] if len(texts) == 1 else texts


def preprocess_image(image):
    """Convert a PIL image to a normalized [1, 1, 32, 128] tensor on `device`.

    Must match the training-time preprocessing exactly.
    """
    transform = transforms.Compose([
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    return transform(image).unsqueeze(0).to(device)


def predict(image):
    """Gradio handler: accept an uploaded CAPTCHA image and return its text.

    Accepts a PIL image (the normal Gradio path), a Gradio dict payload,
    a base64 data-URL string, or raw image bytes. Returns the predicted
    string, or an error message on failure.
    """
    try:
        if image is None:
            return "No image provided"

        # Optional: handle Gradio dictionary format
        if isinstance(image, dict) and 'data' in image:
            image = image['data']

        # Convert to PIL Image if not already.
        if isinstance(image, str) and image.startswith('data:image'):
            image_data = base64.b64decode(image.split(',')[1])
            image = Image.open(BytesIO(image_data))
        elif not isinstance(image, Image.Image):
            # FIX: BytesIO was previously only imported inside the base64
            # branch above, so this path raised NameError on raw bytes.
            image = Image.open(BytesIO(image))

        # Preprocessing (must match training) — shared helper instead of a
        # duplicated inline transform.
        image_tensor = preprocess_image(image)

        with torch.no_grad():
            output = model(image_tensor)

        prediction = decode_predictions(output)
        print(f"Predicted text: {prediction}")
        return prediction

    except Exception as e:
        # Top-level boundary for the web handler: report, don't crash the app.
        print(f"Error details: {str(e)}")
        return f"Error processing image: {str(e)}"


# --------------------------
# Gradio Interface
# --------------------------
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload CAPTCHA Image"),
    outputs=gr.Textbox(label="Predicted Text"),
    title="CAPTCHA Solver (FastCRNN)",
    description="Upload a CAPTCHA image to extract text using ResNet18 + BiLSTM"
)

if __name__ == "__main__":
    iface.launch(
        share=True
    )