# CCN / app.py
# (Hugging Face Space page residue — "Ds0uz4's picture / Update app.py /
#  12ff556 verified" — converted to a comment so the file parses as Python.)
import torch
import gradio as gr
from torch import nn
import torch.nn.functional as F
import os
from torchvision import transforms
from torch.utils.data import DataLoader,random_split,Dataset
from PIL import Image,UnidentifiedImageError
import string
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
characters = string.ascii_letters + string.digits
idx_to_char = {idx: char for idx, char in enumerate(characters)}
#AFFN
AFFN_KERNEL=5
AFFN_STRIDE=1
AFFN_DEPTH=1
#CRNN
CRNN_KERNEL=5
CRNN_POOL_KERNEL=2
CRNN_DROPOUT=0.3
CRNN_LATENT=128
LSTM_HIDDEN_DIM=32
VOCAB_SIZE=26*2+10
OUTPUT_LENGTH=5
class Encoder(nn.Sequential):
def __init__(self,n,kernel_size,stride):
super().__init__(
nn.Conv2d(in_channels=4**(n-1),out_channels=4**n,kernel_size=kernel_size,stride=stride),
nn.BatchNorm2d(num_features=4**n),
nn.ReLU(inplace=False)
)
class Decoder(nn.Sequential):
def __init__(self,n,kernel_size,stride):
super().__init__(
nn.ConvTranspose2d(in_channels=4**n,out_channels=4**(n-1),kernel_size=kernel_size,stride=stride),
nn.BatchNorm2d(num_features=4**(n-1)),
nn.ReLU(inplace=False)
)
class AFFN(nn.Module):
def __init__(self,n):
super().__init__()
self.n= n
self.alpha = nn.Parameter(torch.randn(n-1).to(device)).to(device)
self.encoders = []
self.decoders = []
for i in range(1,n+1):
self.encoders.append(Encoder(i,AFFN_KERNEL,AFFN_STRIDE).to(device))
for i in range(n,0,-1):
self.decoders.append(Decoder(i,AFFN_KERNEL,AFFN_STRIDE).to(device))
def forward(self,x):
residuals = []
for i,enc in enumerate(self.encoders):
x= enc(x)
if i < self.n-1:
x = x * (1 - self.alpha[i])
residuals.append(x * self.alpha[i])
for i,dec in enumerate(self.decoders):
x= dec(x)
if i < self.n-1:
x= x + residuals.pop()
return x
class CRNN(nn.Module):
def __init__(self, in_channels, kernel_size, pool_kernel_size, dropout, latent_dim, lstm_hidden_dim, vocab_size, output_length=5):
super().__init__()
self.lstm_hidden_dim = lstm_hidden_dim
self.output_length = output_length # Should be 5 for 5 characters
self.vocab_size = vocab_size
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=in_channels*2, kernel_size=kernel_size, padding=2),
nn.BatchNorm2d(num_features=in_channels*2),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=pool_kernel_size)
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=in_channels*2, out_channels=in_channels*4, kernel_size=kernel_size, padding=2),
nn.BatchNorm2d(num_features=in_channels*4),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=pool_kernel_size)
)
self.flatten = nn.Flatten()
self.dropout = nn.Dropout(dropout)
self.latent_fc = nn.LazyLinear(latent_dim)
self.lstm = nn.LSTM(input_size=latent_dim, hidden_size=lstm_hidden_dim, num_layers=1, batch_first=True)
self.output_fc = nn.Linear(lstm_hidden_dim, vocab_size)
def forward(self, x):
batch_size = x.size(0)
# CNN feature extraction
conv1_out = self.conv1(x)
conv2_out = self.conv2(conv1_out)
flattened = self.flatten(conv2_out)
dropped = self.dropout(flattened)
latent = self.latent_fc(dropped)
lstm_input = latent.unsqueeze(1)
# Initialize hidden and cell states
h0 = torch.zeros(1, batch_size, self.lstm_hidden_dim, device=x.device)
c0 = torch.zeros(1, batch_size, self.lstm_hidden_dim, device=x.device)
outputs = []
# Generate 5 characters sequentially
for _ in range(self.output_length):
out, (h0, c0) = self.lstm(lstm_input, (h0, c0)) # out shape: (batch_size, 1, lstm_hidden_dim)
logits = self.output_fc(out.squeeze(1)) # Shape: (batch_size, vocab_size)
outputs.append(logits)
outputs = torch.stack(outputs, dim=1) # Shape: (batch_size, 5, vocab_size)
return outputs
output=CRNN(64,CRNN_KERNEL,CRNN_POOL_KERNEL,CRNN_DROPOUT,CRNN_LATENT,LSTM_HIDDEN_DIM,VOCAB_SIZE,OUTPUT_LENGTH).to(device)(torch.zeros((2,64,256,256)).to(device))
class CaptchaCrackNet(nn.Module):
def __init__(self):
super().__init__()
self.affn=AFFN(AFFN_DEPTH).to(device)
self.conv1=nn.Sequential(
nn.Conv2d(in_channels=1,out_channels=32,kernel_size=5,padding=2),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2)
)
self.conv2=nn.Sequential(
nn.Conv2d(in_channels=32,out_channels=48,kernel_size=5,padding=2),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2)
)
self.conv3=nn.Sequential(
nn.Conv2d(in_channels=48,out_channels=64,kernel_size=5,padding=2),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2)
)
self.res=nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=2, padding=2)
self.crnn=CRNN(64,CRNN_KERNEL,CRNN_POOL_KERNEL,CRNN_DROPOUT,CRNN_LATENT,LSTM_HIDDEN_DIM,VOCAB_SIZE,OUTPUT_LENGTH).to(device)
def forward(self,x):
affn_out=self.affn(x)
res_out=self.res(x)
conv1_out=self.conv1(affn_out)
conv2_out=self.conv2(conv1_out+res_out)
conv3_out=self.conv3(conv2_out)
output=self.crnn(conv3_out)
return output
torch.manual_seed(42)
model=CaptchaCrackNet().to(device)
optimizer=torch.optim.Adam(model.parameters())
characters = string.ascii_letters + string.digits
idx_to_char = {idx: char for idx, char in enumerate(characters)}
def to_text(arr):
ans=''
for c in arr:
ans=ans+idx_to_char[c.item()]
return ans
def predict_captcha(image):
try:
if image is None:
return "No image provided"
# Handle Gradio's FileData input: dict with 'data' and 'meta'
if isinstance(image, dict) and 'data' in image:
image = image['data']
# Convert to PIL.Image
if isinstance(image, str) and image.startswith('data:image'):
import base64
from io import BytesIO
image_data = base64.b64decode(image.split(',')[1])
image = Image.open(BytesIO(image_data))
elif not isinstance(image, Image.Image):
from io import BytesIO
image = Image.open(BytesIO(image))
# Process image
transform = transforms.Compose([
transforms.Resize((40, 150)),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Lambda(lambda x: x / 255),
])
image_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
output = model(image_tensor)
prediction = output.squeeze(0).argmax(axis=1)
result = to_text(prediction)
print(f"Predicted text: {result}")
return result
except Exception as e:
print(f"Error details: {str(e)}")
return f"Error processing image: {str(e)}"
# Ensure model is loaded and set to eval mode
checkpoint = torch.load(r'model/final.pth', map_location=device)
# Restore states
print("Checkpoint keys:", checkpoint.keys())
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
model.eval()
# Update Gradio interface
iface = gr.Interface(
fn=predict_captcha,
inputs=gr.Image(type="pil", label="Upload CAPTCHA Image"),
outputs=gr.Textbox(label="Predicted Text"),
title="CAPTCHA Recognition",
description="Upload a CAPTCHA image to get the predicted text."
)
if __name__ == "__main__":
iface.launch(
)