import torch import gradio as gr from torch import nn import torch.nn.functional as F import os from torchvision import transforms from torch.utils.data import DataLoader,random_split,Dataset from PIL import Image,UnidentifiedImageError import string import matplotlib.pyplot as plt device = torch.device("cuda" if torch.cuda.is_available() else "cpu") characters = string.ascii_letters + string.digits idx_to_char = {idx: char for idx, char in enumerate(characters)} #AFFN AFFN_KERNEL=5 AFFN_STRIDE=1 AFFN_DEPTH=1 #CRNN CRNN_KERNEL=5 CRNN_POOL_KERNEL=2 CRNN_DROPOUT=0.3 CRNN_LATENT=128 LSTM_HIDDEN_DIM=32 VOCAB_SIZE=26*2+10 OUTPUT_LENGTH=5 class Encoder(nn.Sequential): def __init__(self,n,kernel_size,stride): super().__init__( nn.Conv2d(in_channels=4**(n-1),out_channels=4**n,kernel_size=kernel_size,stride=stride), nn.BatchNorm2d(num_features=4**n), nn.ReLU(inplace=False) ) class Decoder(nn.Sequential): def __init__(self,n,kernel_size,stride): super().__init__( nn.ConvTranspose2d(in_channels=4**n,out_channels=4**(n-1),kernel_size=kernel_size,stride=stride), nn.BatchNorm2d(num_features=4**(n-1)), nn.ReLU(inplace=False) ) class AFFN(nn.Module): def __init__(self,n): super().__init__() self.n= n self.alpha = nn.Parameter(torch.randn(n-1).to(device)).to(device) self.encoders = [] self.decoders = [] for i in range(1,n+1): self.encoders.append(Encoder(i,AFFN_KERNEL,AFFN_STRIDE).to(device)) for i in range(n,0,-1): self.decoders.append(Decoder(i,AFFN_KERNEL,AFFN_STRIDE).to(device)) def forward(self,x): residuals = [] for i,enc in enumerate(self.encoders): x= enc(x) if i < self.n-1: x = x * (1 - self.alpha[i]) residuals.append(x * self.alpha[i]) for i,dec in enumerate(self.decoders): x= dec(x) if i < self.n-1: x= x + residuals.pop() return x class CRNN(nn.Module): def __init__(self, in_channels, kernel_size, pool_kernel_size, dropout, latent_dim, lstm_hidden_dim, vocab_size, output_length=5): super().__init__() self.lstm_hidden_dim = lstm_hidden_dim self.output_length = output_length # Should be 5 for 5 characters self.vocab_size = vocab_size self.conv1 = nn.Sequential( nn.Conv2d(in_channels=in_channels, out_channels=in_channels*2, kernel_size=kernel_size, padding=2), nn.BatchNorm2d(num_features=in_channels*2), nn.ReLU(inplace=False), nn.MaxPool2d(kernel_size=pool_kernel_size) ) self.conv2 = nn.Sequential( nn.Conv2d(in_channels=in_channels*2, out_channels=in_channels*4, kernel_size=kernel_size, padding=2), nn.BatchNorm2d(num_features=in_channels*4), nn.ReLU(inplace=False), nn.MaxPool2d(kernel_size=pool_kernel_size) ) self.flatten = nn.Flatten() self.dropout = nn.Dropout(dropout) self.latent_fc = nn.LazyLinear(latent_dim) self.lstm = nn.LSTM(input_size=latent_dim, hidden_size=lstm_hidden_dim, num_layers=1, batch_first=True) self.output_fc = nn.Linear(lstm_hidden_dim, vocab_size) def forward(self, x): batch_size = x.size(0) # CNN feature extraction conv1_out = self.conv1(x) conv2_out = self.conv2(conv1_out) flattened = self.flatten(conv2_out) dropped = self.dropout(flattened) latent = self.latent_fc(dropped) lstm_input = latent.unsqueeze(1) # Initialize hidden and cell states h0 = torch.zeros(1, batch_size, self.lstm_hidden_dim, device=x.device) c0 = torch.zeros(1, batch_size, self.lstm_hidden_dim, device=x.device) outputs = [] # Generate 5 characters sequentially for _ in range(self.output_length): out, (h0, c0) = self.lstm(lstm_input, (h0, c0)) # out shape: (batch_size, 1, lstm_hidden_dim) logits = self.output_fc(out.squeeze(1)) # Shape: (batch_size, vocab_size) outputs.append(logits) outputs = torch.stack(outputs, dim=1) # Shape: (batch_size, 5, vocab_size) return outputs output=CRNN(64,CRNN_KERNEL,CRNN_POOL_KERNEL,CRNN_DROPOUT,CRNN_LATENT,LSTM_HIDDEN_DIM,VOCAB_SIZE,OUTPUT_LENGTH).to(device)(torch.zeros((2,64,256,256)).to(device)) class CaptchaCrackNet(nn.Module): def __init__(self): super().__init__() self.affn=AFFN(AFFN_DEPTH).to(device) self.conv1=nn.Sequential( nn.Conv2d(in_channels=1,out_channels=32,kernel_size=5,padding=2), nn.ReLU(inplace=False), nn.MaxPool2d(kernel_size=2) ) self.conv2=nn.Sequential( nn.Conv2d(in_channels=32,out_channels=48,kernel_size=5,padding=2), nn.ReLU(inplace=False), nn.MaxPool2d(kernel_size=2) ) self.conv3=nn.Sequential( nn.Conv2d(in_channels=48,out_channels=64,kernel_size=5,padding=2), nn.ReLU(inplace=False), nn.MaxPool2d(kernel_size=2) ) self.res=nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=2, padding=2) self.crnn=CRNN(64,CRNN_KERNEL,CRNN_POOL_KERNEL,CRNN_DROPOUT,CRNN_LATENT,LSTM_HIDDEN_DIM,VOCAB_SIZE,OUTPUT_LENGTH).to(device) def forward(self,x): affn_out=self.affn(x) res_out=self.res(x) conv1_out=self.conv1(affn_out) conv2_out=self.conv2(conv1_out+res_out) conv3_out=self.conv3(conv2_out) output=self.crnn(conv3_out) return output torch.manual_seed(42) model=CaptchaCrackNet().to(device) optimizer=torch.optim.Adam(model.parameters()) characters = string.ascii_letters + string.digits idx_to_char = {idx: char for idx, char in enumerate(characters)} def to_text(arr): ans='' for c in arr: ans=ans+idx_to_char[c.item()] return ans def predict_captcha(image): try: if image is None: return "No image provided" # Handle Gradio's FileData input: dict with 'data' and 'meta' if isinstance(image, dict) and 'data' in image: image = image['data'] # Convert to PIL.Image if isinstance(image, str) and image.startswith('data:image'): import base64 from io import BytesIO image_data = base64.b64decode(image.split(',')[1]) image = Image.open(BytesIO(image_data)) elif not isinstance(image, Image.Image): from io import BytesIO image = Image.open(BytesIO(image)) # Process image transform = transforms.Compose([ transforms.Resize((40, 150)), transforms.Grayscale(), transforms.ToTensor(), transforms.Lambda(lambda x: x / 255), ]) image_tensor = transform(image).unsqueeze(0).to(device) with torch.no_grad(): output = model(image_tensor) prediction = output.squeeze(0).argmax(axis=1) result = to_text(prediction) print(f"Predicted text: {result}") return result except Exception as e: print(f"Error details: {str(e)}") return f"Error processing image: {str(e)}" # Ensure model is loaded and set to eval mode checkpoint = torch.load(r'model/final.pth', map_location=device) # Restore states print("Checkpoint keys:", checkpoint.keys()) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) model.eval() # Update Gradio interface iface = gr.Interface( fn=predict_captcha, inputs=gr.Image(type="pil", label="Upload CAPTCHA Image"), outputs=gr.Textbox(label="Predicted Text"), title="CAPTCHA Recognition", description="Upload a CAPTCHA image to get the predicted text." ) if __name__ == "__main__": iface.launch( )