# Hugging Face Spaces page residue (status header "Spaces: Sleeping") — not part of the app.
# Standard library
import base64
import string
from io import BytesIO
from itertools import groupby

# Third-party
import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torchvision import models
# Device configuration: prefer CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input geometry expected by the network (must match training).
IMG_HEIGHT = 32
IMG_WIDTH = 128

# Vocabulary: a-z, A-Z, 0-9 — order must match the training-time mapping.
characters = string.ascii_letters + string.digits
char_to_idx = {ch: i for i, ch in enumerate(characters)}
idx_to_char = dict(enumerate(characters))
VOCAB_SIZE = len(characters) + 1  # +1 for CTC blank token
# --------------------------
# Model Architecture (Same as Training)
# --------------------------
class FastCRNN(nn.Module):
    """CRNN for CAPTCHA recognition: truncated ResNet-18 features feeding a
    2-layer bidirectional LSTM, decoded with CTC.

    Input : [B, 1, 32, 128] grayscale images.
    Output: [W, B, num_classes] per-timestep logits (W = 8 for 128-px width).
    """

    def __init__(self, num_classes):
        super().__init__()
        resnet = models.resnet18(weights=None)
        # Accept single-channel (grayscale) input instead of RGB.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Drop layer4, avgpool and fc: features come from layer3, which for a
        # 32x128 input yields [B, 256, 2, 8] (the original comment's
        # "[B, 256, 4, 16]" was wrong — layer3 downsamples by 16x overall).
        self.cnn = nn.Sequential(*list(resnet.children())[:-3])
        # Flattened C*H per timestep: 256 * 2 == 128 * (32 // 8) == 512.
        self.lstm_input_size = 128 * (IMG_HEIGHT // 8)
        self.rnn = nn.LSTM(self.lstm_input_size, 256, num_layers=2,
                           bidirectional=True, dropout=0.1)
        self.fc = nn.Linear(512, num_classes)  # 512 = 2 * 256 (bidirectional)

    def forward(self, x):
        x = self.cnn(x)                                    # [B, 256, 2, 8]
        x = x.permute(3, 0, 1, 2)                          # [W, B, C, H]
        x = x.contiguous().view(x.size(0), x.size(1), -1)  # [W, B, C*H]
        x, _ = self.rnn(x)                                 # [W, B, 512]
        return self.fc(x)                                  # [W, B, num_classes]
# --------------------------
# Model Loading
# --------------------------
def load_model(weights_path='fast_crnn_captcha_model.pth'):
    """Build a FastCRNN and load trained weights.

    Args:
        weights_path: path to the saved state dict; defaults to the
            training-script output name (backward compatible).

    Returns:
        The model on `device`, switched to eval mode.
    """
    model = FastCRNN(num_classes=VOCAB_SIZE).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (on torch >= 2.0, weights_only=True would be safer here).
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model

model = load_model()
# --------------------------
# Prediction Logic
# --------------------------
def decode_predictions(preds):
    """Greedy CTC decode of model output.

    Args:
        preds: logits of shape [W, B, num_classes] as produced by FastCRNN.
            (Softmax is unnecessary before argmax — it is monotonic, so the
            argmax of the raw logits is identical.)

    Returns:
        The decoded string for a single-image batch, or a list of strings
        for a multi-image batch.
    """
    blank = VOCAB_SIZE - 1
    pred_indices = preds.permute(1, 0, 2).argmax(dim=2)  # [B, W]
    texts = []
    for pred in pred_indices:
        # Correct CTC collapse: merge *consecutive identical indices* first,
        # then drop blanks. The previous version tracked the last emitted
        # character and never reset it on blanks, so a genuine repeat
        # separated by a blank ("a", blank, "a") wrongly decoded to "a"
        # instead of "aa".
        chars = [idx_to_char[idx] for idx, _ in groupby(pred.tolist()) if idx != blank]
        texts.append(''.join(chars))
    return texts[0] if len(texts) == 1 else texts
def preprocess_image(image):
    """Convert a PIL image into a model-ready [1, 1, 32, 128] tensor.

    Mirrors the training pipeline: resize to (IMG_HEIGHT, IMG_WIDTH),
    convert to single-channel grayscale, scale to [0, 1], then normalize
    to [-1, 1] with mean/std 0.5. The batch dimension is added and the
    tensor is moved to the configured device.
    """
    pipeline = transforms.Compose([
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    tensor = pipeline(image)
    return tensor.unsqueeze(0).to(device)
def predict(image):
    """Run the CAPTCHA solver on one uploaded image.

    Accepts a PIL image (Gradio ``type="pil"``), a Gradio dict payload,
    a ``data:image`` base64 data-URL string, or raw image bytes.

    Returns:
        The predicted text, or an error-message string — errors are surfaced
        in the output textbox rather than raised (best-effort UI behavior).
    """
    try:
        if image is None:
            return "No image provided"
        # Optional: handle Gradio dictionary format.
        if isinstance(image, dict) and 'data' in image:
            image = image['data']
        # Normalize the input into a PIL Image.
        if isinstance(image, str) and image.startswith('data:image'):
            image = Image.open(BytesIO(base64.b64decode(image.split(',')[1])))
        elif not isinstance(image, Image.Image):
            # BUG FIX: BytesIO was previously imported only inside the
            # base64 branch above, so raw-bytes input raised NameError here.
            # BytesIO/base64 are now module-level imports.
            image = Image.open(BytesIO(image))
        # Reuse the shared preprocessing pipeline (must match training);
        # the transform was previously duplicated inline here.
        image_tensor = preprocess_image(image)
        with torch.no_grad():
            output = model(image_tensor)
        prediction = decode_predictions(output)
        print(f"Predicted text: {prediction}")
        return prediction
    except Exception as e:
        # Broad catch is deliberate: any failure becomes a readable message
        # in the Gradio textbox instead of a server error.
        print(f"Error details: {str(e)}")
        return f"Error processing image: {str(e)}"
# --------------------------
# Gradio Interface
# --------------------------
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload CAPTCHA Image"),
    outputs=gr.Textbox(label="Predicted Text"),
    title="CAPTCHA Solver (FastCRNN)",
    description="Upload a CAPTCHA image to extract text using ResNet18 + BiLSTM",
)

if __name__ == "__main__":
    # share=True publishes a temporary public URL (Spaces/demo usage).
    iface.launch(share=True)