7eu7d7 committed
Commit 1f584ff
1 Parent(s): da04e0d
Files changed (6)
  1. app.py +77 -0
  2. cap.py +100 -0
  3. models/__init__.py +1 -0
  4. models/enc_dec.py +28 -0
  5. requirements.txt +7 -0
  6. utils.py +27 -0
app.py ADDED
@@ -0,0 +1,77 @@
+ from functools import lru_cache
+
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+
+ from cap import Predictor
+
+
+ @lru_cache()
+ def load_predictor(model):
+     predictor = Predictor(hf_hub_download(
+         '7eu7d7/CAPTCHA_recognize',
+         model,
+     ))
+     return predictor
+
+
+ def process_image(image):
+     """
+     Process the uploaded image and return the recognized CAPTCHA text.
+     """
+     if image is None:
+         return "Please upload an image first"
+
+     # Normalize the input to an RGB PIL image
+     if isinstance(image, np.ndarray):
+         img = Image.fromarray(image.astype('uint8')).convert('RGB')
+     else:
+         img = image.convert('RGB')
+
+     predictor = load_predictor('captcha-2000.safetensors')
+     text = predictor.pred_img(img, show=False)
+     return text
+
+
+ # Create the Gradio interface
+ with gr.Blocks(title="CAPTCHA Recognize") as demo:
+     with gr.Row():
+         # Left column - input area
+         with gr.Column(scale=1):
+             image_input = gr.Image(
+                 label="Upload CAPTCHA Image",
+                 type="pil",
+                 height=300
+             )
+
+             # Run button
+             process_btn = gr.Button(
+                 "Run",
+                 variant="primary",
+                 size="lg"
+             )
+
+         # Right column - output area
+         with gr.Column(scale=1):
+             text_output = gr.Textbox(
+                 label="Result",
+                 lines=4,
+                 interactive=False
+             )
+
+     # Bind events
+     process_btn.click(
+         fn=process_image,
+         inputs=image_input,
+         outputs=[text_output]
+     )
+
+
+ # Launch the application
+ if __name__ == "__main__":
+     demo.launch()
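
A quick way to exercise `process_image` outside the UI (a sketch, not part of the commit; `example_captcha.png` is a hypothetical local file, and the checkpoint name is the one hard-coded above):

from PIL import Image
from app import process_image

# Importing app builds the Blocks UI but does not launch it (guarded by __main__).
# The first call downloads the checkpoint from the Hub; lru_cache then reuses
# the loaded Predictor across calls.
img = Image.open('example_captcha.png')
print(process_image(img))
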
cap.py ADDED
@@ -0,0 +1,100 @@
+ # -*- coding: utf-8 -*-
+ import torch
+ import argparse
+ from models import ResnetEncoderDecoder
+ from utils import remove_rptch
+ from safetensors import safe_open
+ from torchvision import transforms as T
+ from PIL import Image
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ char_dict = '_0123456789abcdefghijklmnopqrstuvwxyz'
+ id_chr_map = {i: c for i, c in enumerate(char_dict)}
+
+
+ class Predictor:
+     def __init__(self, model_path, char_dict=char_dict):
+         self.model = ResnetEncoderDecoder(char_dict).to(device)
+         self.model.eval()
+         if str(device) == 'cpu':
+             # safe_open expects a device string, not a torch.device
+             check_point = self.load_safetensor(model_path, map_location='cpu')
+         else:
+             check_point = self.load_safetensor(model_path)
+         self.model.load_state_dict(check_point)
+         self.char_dict = char_dict
+
+         self.trans = T.Compose([
+             T.ToTensor(),
+             T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ])
+
+     # >>>>> from RainbowNeko Engine >>>>>
+     @staticmethod
+     def fold_dict(safe_f, split_key=':'):
+         dict_fold = {}
+
+         for k in safe_f.keys():
+             k_list = k.split(split_key)
+             dict_last = dict_fold
+             for item in k_list[:-1]:
+                 if item not in dict_last:
+                     dict_last[item] = {}
+                 dict_last = dict_last[item]
+             dict_last[k_list[-1]] = safe_f.get_tensor(k)
+
+         return dict_fold
+
+     def load_safetensor(self, ckpt_f, map_location='cpu'):
+         with safe_open(ckpt_f, framework="pt", device=map_location) as f:
+             sd_fold = self.fold_dict(f)
+         return sd_fold
+     # <<<<< from RainbowNeko Engine <<<<<
+
+     def pred(self, input):
+         pred = self.model(input.to(device))
+
+         B, H, W, C = pred.size()
+         T_ = H * W
+         pred = pred.view(B, T_, -1)
+         pred = pred + 1e-10
+
+         pred_cls = torch.max(pred, 2)[1].data.cpu().numpy()[0]
+
+         # Read the H x W grid column by column, then drop blanks (class 0)
+         pred_cls = pred_cls.reshape((H, W)).T.reshape((H * W,))
+         final_str = remove_rptch(''.join(self.char_dict[x] for x in pred_cls if x))
+
+         return pred_cls, final_str, (H, W)
+
+     def pred_img(self, image, show=True):
+         if isinstance(image, str):
+             image = Image.open(image).convert('RGB')
+         image = self.trans(image)
+         pred_cls, final_str, (H, W) = self.pred(image.unsqueeze(0))
+
+         if show:
+             pred_string = ''.join(['%2s' % self.char_dict[pn] for pn in pred_cls])
+             pred_string_set = [pred_string[i:i + W * 2] for i in range(0, len(pred_string), W * 2)]
+             print('Prediction: ')
+             for pre_str in pred_string_set:
+                 print(pre_str)
+
+             print('Result:', final_str)
+
+         return final_str
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description='CAPTCHA Recognizer')
+     parser.add_argument('--model_path', type=str, default='exps/captcha/ckpts/model-2000.safetensors', help='Path to the model file')
+     parser.add_argument('--image_path', type=str, default=[
+         '/data1/dzy/CAPTCHA_recognize/data3/test/2.jpg',
+         '/data1/dzy/Verification_Code_CV_v1.1/imgs/00097.png',
+         '/data1/dzy/Verification_Code_CV_v1.1/imgs/00098.png',
+         '/data1/dzy/Verification_Code_CV_v1.1/imgs/00099.png',
+     ], nargs='+', help='Path(s) to the image file(s)')
+     args = parser.parse_args()
+
+     predictor = Predictor(args.model_path)
+     for path in args.image_path:
+         result = predictor.pred_img(path)
+         print(f'Recognized CAPTCHA: {result}')
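
The decode in `Predictor.pred` is a greedy, CTC-style pass: each output cell gets an argmax class, the H x W grid is read column by column (hence the `.T`), blanks (class 0, `_`) are dropped, and `remove_rptch` trims repeats toward the 4-character target. A toy illustration with hypothetical class ids:

import numpy as np
from utils import remove_rptch

char_dict = '_0123456789abcdefghijklmnopqrstuvwxyz'
grid = np.array([[1, 1, 2],
                 [1, 2, 3]])  # hypothetical 2x3 argmax grid (H=2, W=3)
flat = grid.T.reshape(-1)     # column-major read order, as in pred()
raw = ''.join(char_dict[x] for x in flat if x)  # drop blank class 0
print(raw)                # '000112'
print(remove_rptch(raw))  # '0112' -- repeats trimmed toward length 4
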
models/__init__.py ADDED
@@ -0,0 +1 @@
+ from .enc_dec import ResnetEncoderDecoder
models/enc_dec.py ADDED
@@ -0,0 +1,28 @@
+ # -*- coding: utf-8 -*-
+
+ import timm
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class ResnetEncoderDecoder(nn.Module):
+     def __init__(self, char_dict):
+         super(ResnetEncoderDecoder, self).__init__()
+         self.bn = nn.BatchNorm2d(64)
+         resnet = timm.create_model('resnet18', pretrained=True, drop_rate=0.2, drop_path_rate=0.3)
+         self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1)
+         self.cnn = nn.Sequential(*list(resnet.children())[4:-2])  # resnet18 layer1-layer4 (drops stem and head)
+         self.out = nn.Linear(512, len(char_dict))
+
+         self.char_dict = char_dict
+
+     def forward(self, input):
+         input = F.silu(self.bn(self.conv(input)), True)
+         input = F.max_pool2d(input, kernel_size=(2, 2), stride=(2, 2))
+         input = self.cnn(input)
+
+         # (B, C, H, W) -> (B, H, W, num_classes): per-cell class probabilities
+         input = input.permute(0, 2, 3, 1)
+         input = F.softmax(self.out(input), dim=-1)
+
+         return input
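
For orientation: the custom stem (stride-1 conv plus a 2x2 max-pool) followed by resnet18's layer1-layer4 gives 16x total downsampling, so each surviving spatial cell becomes a distribution over the character set. A minimal shape check (the 64x160 input size is an assumed example, not fixed by this commit):

import torch
from models import ResnetEncoderDecoder

model = ResnetEncoderDecoder('_0123456789abcdefghijklmnopqrstuvwxyz').eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 64, 160))  # assumed input size
print(out.shape)  # expected: torch.Size([1, 4, 10, 37]) -- (B, H/16, W/16, classes)
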
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ pillow
+ timm
+ safetensors
+ numpy
+ huggingface_hub
utils.py ADDED
@@ -0,0 +1,27 @@
+
+ def rmchr(text, index):
+     # Remove the character at the given index
+     return text[:index] + text[index + 1:]
+
+
+ def count_rptch(text):
+     # Return (run length, start index) of the longest run of repeated characters
+     maxch = (1, 0)
+     nowch = (0, 0)
+     lastch = None
+     for index, i in enumerate(text):
+         if lastch == i:
+             nowch = (nowch[0] + 1, nowch[1])
+             if nowch[0] > maxch[0]:
+                 maxch = nowch
+         else:
+             nowch = (1, index)
+
+         lastch = i
+
+     return maxch
+
+
+ def remove_rptch(text, tar_len=4):
+     # Drop one character from the longest repeated run at a time, until the
+     # text is tar_len characters long or no repeats remain
+     while len(text) > tar_len:
+         maxch = count_rptch(text)
+         if maxch[0] <= 1:
+             break
+         text = rmchr(text, maxch[1])
+     return text
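
For reference, a sketch of the helpers' expected behavior: `count_rptch` reports the longest repeated run, and `remove_rptch` trims one character from that run at a time, stopping at `tar_len` even if repeats remain:

>>> count_rptch('aabbbc')
(3, 2)
>>> remove_rptch('aabbccdd')  # trimmed down to the 4-char target
'abcd'
>>> remove_rptch('aaaab')     # stops once the length reaches 4
'aaab'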