LTPhat commited on
Commit
1f1fc6b
·
1 Parent(s): 954d8ce
create_dataset/create_fontstyle.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import random
3
+ from PIL import Image, ImageDraw, ImageFont
4
+
5
+ font_folder = 'font'
6
+ font_name = ['arial', 'bodoni','calibri','futura','heveltica','times-new-roman']
7
+
8
def fontstyle_list(font_folder, font_name):
    """Collect all .ttf font file paths for the given font families.

    Args:
        font_folder: root folder containing one sub-folder per font family.
        font_name: list of family sub-folder names to scan.

    Returns:
        Flat list of paths to every .ttf file found (empty for missing folders).
    """
    import os  # local import: portable path joining

    font_list = []
    for family in font_name:
        # BUG FIX: the original concatenated a hard-coded Windows "\\"
        # separator, which silently finds nothing on POSIX systems.
        pattern = os.path.join(font_folder, family, "*.ttf")
        font_list.extend(glob.glob(pattern))
    return font_list
15
+
16
+
17
+
18
def draw_img(label, font_list):
    """Render `label` as a 28x28 grayscale digit image.

    A random font, size and offset are used so repeated calls produce
    varied training samples.

    Args:
        label: value (typically an int digit 0-9) drawn onto the image.
        font_list: non-empty list of .ttf font file paths.

    Returns:
        (img, label) where img is a 28x28 PIL 'L'-mode image.
    """
    img = Image.new('L', (256, 256))
    size = random.randint(150, 250)
    x = random.randint(60, 90)
    y = random.randint(30, 60)
    draw = ImageDraw.Draw(img)
    # BUG FIX: pick a random font instead of always font_list[0], so the
    # generated samples actually cover every available font style.
    font = ImageFont.truetype(random.choice(font_list), size)
    draw.text((x, y), str(label), (200), font=font)

    # Downscale the large canvas to MNIST resolution for smoother glyphs.
    img = img.resize((28, 28), Image.BILINEAR)
    return img, label
30
+
31
if __name__ == "__main__":
    # Quick sanity check: list and count the discovered font files.
    discovered = fontstyle_list(font_folder, font_name)
    print(discovered)
    print(len(discovered))
create_dataset/data_generation_sample.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageFont, ImageDraw
2
+ import random
3
+ import cv2
4
+ import numpy as np
5
+
6
# Candidate fonts for rendering a sample digit.
fonts = ['font\Arial\Arial.ttf', 'font\calibri\Calibri-Bold.ttf','font\heveltica\helvetica_bold.ttf','font\Times-new-roman\Times.ttf']

img = Image.new('L', (256, 256))

# Random digit, size and position so each run produces a new sample.
target = random.randint(0, 9)

size = random.randint(150, 250)
x = random.randint(60, 90)
y = random.randint(30, 60)
draw = ImageDraw.Draw(img)
# font = ImageFont.truetype(, )
# Render the digit with a randomly chosen font.
font = ImageFont.truetype(fonts[random.randint(0, 3)], size)
draw.text((x, y), str(target), (200), font=font)

# Downscale to MNIST resolution and display the result.
img = img.resize((28, 28), Image.BILINEAR)
img = np.array(img)
cv2.imshow("Image", img)
cv2.waitKey(0)
create_dataset/digital_mnist_digits.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from create_fontstyle import fontstyle_list
3
+ import torch
4
+
5
+ from PIL import Image
6
+ from PIL import ImageFont
7
+ from PIL import ImageDraw
8
+
9
+ import glob
10
+ import random
11
+ import os
12
+
13
+ font_folder = 'font'
14
+ font_name = ['arial', 'bodoni','calibri','futura','heveltica','times-new-roman']
15
+
16
+ fonts = fontstyle_list(font_folder, font_name)
17
+
18
class PrintedMNIST(Dataset):
    """Synthetic printed-digit, MNIST-style dataset.

    Each item is generated on the fly: a random digit 0-9 rendered with a
    random font, size, position and brightness, then downscaled to 28x28.

    Args:
        samples: number of items the dataset reports via __len__.
        random_state: seed used for reproducible generation.
        transform: optional torchvision-style transform applied to each image.
    """

    def __init__(self, samples, random_state, transform=None):
        self.samples = samples
        self.random_state = random_state
        # NOTE(review): attribute name keeps the original "transfrom" typo
        # in case external code pokes at it; only used internally here.
        self.transfrom = transform
        self.fonts = fonts

        random.seed(random_state)

    def __len__(self):
        return self.samples

    def __getitem__(self, index):
        color = random.randint(200, 255)
        # Draw on a large canvas first; downscaling later smooths the glyph.
        img = Image.new("L", (256, 256))
        label = random.randint(0, 9)
        size = random.randint(180, 220)
        x = random.randint(60, 80)
        y = random.randint(30, 60)

        draw = ImageDraw.Draw(img)
        # Choose a random font style from the font style list.
        font = ImageFont.truetype(random.choice(self.fonts), size)
        # BUG FIX: pass the loaded font object, not the module-level list of
        # font paths (the original `font=fonts` raised at draw time).
        draw.text((x, y), str(label), color, font=font)

        img = img.resize((28, 28), Image.BILINEAR)
        if self.transfrom:
            img = self.transfrom(img)
        return img, label
50
+
51
class AddSPNoise(object):
    """Transform adding salt noise: each element is bumped by the tensor's
    maximum value with probability `prob`."""

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, tensor):
        # Bernoulli mask, scaled by the tensor's maximum value.
        mask = torch.rand(tensor.size()) < self.prob
        return tensor + mask * tensor.max()

    def __repr__(self):
        return self.__class__.__name__ + "(prob={0})".format(self.prob)
61
+
62
+
63
class AddGaussianNoise(object):
    """Transform adding i.i.d. Gaussian noise N(mean, std^2) elementwise."""

    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        noise = torch.randn(tensor.size()) * self.std + self.mean
        return tensor + noise

    def __repr__(self):
        return self.__class__.__name__ + "(mean={0}, std={1})".format(
            self.mean, self.std
        )
helper_number_page.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from processing import *
4
+ from utils import *
5
+ from sudoku_solve import Sudoku_solver
6
+
7
+
8
+ input_str = "000000000008236400010050020500000009100000007080000050005000200000807000000020000"
9
+ input_str2 = "800010009050807010004090700060701020508060107010502090007040600080309040300050008"
10
+
11
def draw_grid():
    """Draw an empty 9x9 sudoku grid on a white 600x600 canvas.

    Thick lines mark the outer border and the 3x3 box boundaries; thin
    lines separate individual cells.

    Returns:
        600x600x3 float image of the grid.
    """
    base_img = 1 * np.ones((600, 600, 3))
    cell = base_img.shape[0] // 9
    # Outer border.
    cv2.rectangle(base_img, (0, 0), (base_img.shape[0], base_img.shape[1]), (0, 0, 0), 10)
    for i in range(1, 10):
        # Box boundaries (every third line) are drawn thicker.
        thickness = 6 if i % 3 == 0 else 2
        cv2.line(base_img, (i * cell, 0), (i * cell, base_img.shape[1]), (0, 0, 0), thickness)
        cv2.line(base_img, (0, i * cell), (base_img.shape[0], i * cell), (0, 0, 0), thickness)
    return base_img
23
+
24
+
25
def draw_digit(base_img, input_str):
    """Draw the puzzle's given digits onto an empty grid image.

    Args:
        base_img: grid image from draw_grid().
        input_str: 81-character puzzle string (0 = blank).

    Returns:
        (base_img, board): the image with digits drawn, and the parsed board.
    """
    width = base_img.shape[0] // 9
    board = convert_str_to_board(input_str)
    for row in range(9):
        for col in range(9):
            if board[row][col] == 0:
                continue  # blank cells stay empty
            # Cell bounding box.
            p1 = (col * width, row * width)
            p2 = ((col + 1) * width, (row + 1) * width)
            # Center the digit inside its cell.
            center = ((p1[0] + p2[0]) // 2, (p1[1] + p2[1]) // 2)
            text_size, _ = cv2.getTextSize(str(board[row][col]),
                                           cv2.FONT_HERSHEY_SIMPLEX, 1, 6)
            text_origin = (center[0] - text_size[0] // 2, center[1] + text_size[1] // 2)
            cv2.putText(base_img, str(board[row][col]), text_origin,
                        cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 0), 6)
    return base_img, board
43
+
44
def solve(board):
    """Solve `board` with the backtracking Sudoku_solver.

    Returns:
        (res_board, unsolved_board): the solved board and a copy of the
        board as it was before solving.
    """
    unsolved_board = board.copy()
    solver = Sudoku_solver(board, 9)
    solver.solve()
    return solver.board, unsolved_board
50
+
51
+
52
def draw_result(base_img, unsolved_board, solved_board):
    """Draw the solved board: digits filled in by the solver are red,
    originally given digits are black.

    Returns:
        base_img with all 81 digits drawn.
    """
    width = base_img.shape[0] // 9
    for row in range(9):
        for col in range(9):
            # Cell bounding box and its center.
            p1 = (col * width, row * width)
            p2 = ((col + 1) * width, (row + 1) * width)
            center = ((p1[0] + p2[0]) // 2, (p1[1] + p2[1]) // 2)
            text_size, _ = cv2.getTextSize(str(solved_board[row][col]),
                                           cv2.FONT_HERSHEY_SIMPLEX, 1, 6)
            text_origin = (center[0] - text_size[0] // 2, center[1] + text_size[1] // 2)
            # Red for solver-filled cells, black for given digits.
            filled_by_solver = unsolved_board[row][col] != solved_board[row][col]
            color = (0, 0, 255) if filled_by_solver else (0, 0, 0)
            cv2.putText(base_img, str(solved_board[row][col]), text_origin,
                        cv2.FONT_HERSHEY_SIMPLEX, 1.5, color, 6)
    return base_img
70
+
71
+ ### CHECK VALID SODOKU PUZZLE INPUT FROM USER
72
def get_column(board, index):
    """Return column `index` of a 2-D board as a 1-D numpy array."""
    return np.array(board)[:, index]
74
+
75
+
76
def valid_row_or_col(array):
    """True if the non-zero entries of `array` contain no duplicates."""
    nonzero = array[array != 0]
    if nonzero.size == 0:
        return True  # all blanks is trivially valid
    return len(set(nonzero)) == nonzero.size
80
+
81
def valid_single_box(board, box_x, box_y):
    """True if the 3x3 box at box-coordinates (box_x, box_y) has no
    duplicate non-zero digits."""
    box = board[box_x * 3: box_x * 3 + 3, box_y * 3: box_y * 3 + 3]
    digits = box[box != 0]
    return digits.size == 0 or len(set(digits)) == digits.size
86
+
87
def valid_input_str(input_str):
    """Validate a sudoku puzzle given as an 81-character digit string.

    Returns True when no row, column or 3x3 box contains a duplicate
    non-zero digit. Delegates to valid_board so the row/column/box checks
    are not duplicated in two places (the original body was a copy of
    valid_board's).
    """
    return valid_board(convert_str_to_board(input_str))
103
+
104
def valid_board(board):
    """True if no row, column or 3x3 box of `board` contains duplicate
    non-zero digits (zeros are treated as blanks)."""
    rows_ok = all(valid_row_or_col(board[i]) for i in range(len(board)))
    cols_ok = all(valid_row_or_col(get_column(board, j)) for j in range(len(board[0])))
    boxes_ok = all(valid_single_box(board, i, j)
                   for i in range(3) for j in range(3))
    return rows_ok and cols_ok and boxes_ok
119
+
120
if __name__ == "__main__":
    # Demo: draw the puzzle, display it, solve it, display the solution,
    # then validate a second puzzle string.
    base_img = draw_grid()
    res_img = base_img.copy()
    base_img, board = draw_digit(base_img, input_str)
    cv2.imshow("IMG", base_img)
    cv2.waitKey(0)

    res_board, unsolved_board = solve(board)
    res_img = draw_result(res_img, unsolved_board, res_board)
    cv2.imshow("Show result", res_img)
    cv2.waitKey(0)

    print(valid_input_str(input_str2))
image_solver.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from utils import *
4
+ from processing import *
5
+ from threshold import preprocess
6
+ import time
7
+ import cv2
8
+ from sudoku_solve import Sudoku_solver
9
+ import matplotlib.pyplot as plt
10
+
11
+ # This module performs sudoku solver which input is a image file.
12
+
13
+ classifier = torch.load('digit_model.h5',map_location ='cpu')
14
+ classifier.eval()
15
+
16
+
17
def image_solver(img, model):
    """Solve the sudoku puzzle contained in an image.

    Args:
        img: BGR image array, or a path string (loaded with cv2.imread).
        model: digit classifier used to recognize the printed digits.

    Returns:
        (dst_img, solved_board) on success, where dst_img is the original
        view with the solution drawn in; None when the puzzle cannot be
        detected/warped.
    """
    # Accept a file path as well as an already-loaded image. The original
    # crashed when handed a path (str has no .copy()) -- and the module's
    # own __main__ did exactly that.
    if isinstance(img, str):
        img = cv2.imread(img)
    original_img = img.copy()
    threshold = preprocess(img)
    corners_img, corners_list, org_img = find_contours(threshold, img)
    try:
        # Warp the original image so the puzzle fills a square.
        warped, matrix = warp_image(corners_list, corners_img)
        # Threshold the warped image (gray-scale).
        warped_processed = preprocess(warped)

        # Extract the horizontal / vertical grid lines.
        horizontal = grid_line_helper(warped_processed, shape_location=0)
        vertical = grid_line_helper(warped_processed, shape_location=1)

        # Build a mask of the grid so only the digits remain.
        grid_mask = create_grid_mask(horizontal, vertical)
        if img.shape[0] > 600 or img.shape[1] > 600:
            # Resizing large inputs tends to improve recognition results.
            grid_mask = cv2.resize(grid_mask, (600, 600), cv2.INTER_AREA)
            number_img = cv2.bitwise_and(
                cv2.resize(warped_processed, (600, 600), cv2.INTER_AREA), grid_mask)
        else:
            number_img = cv2.bitwise_and(warped_processed, grid_mask)

        # Split into 81 cells and clean noise from each.
        squares = split_squares(number_img)
        cleaned_squares = clean_square_all_images(squares)

        # Resize cells to classifier input size and normalize pixels.
        resized_list = resize_square(cleaned_squares)
        norm_resized = normalize(resized_list)

        # Recognize the digits and build the board.
        rec_str = recognize_digits(model, norm_resized, original_img)
        board = convert_str_to_board(rec_str)

        # Solve with backtracking, timing the solve step.
        unsolved_board = board.copy()
        sudoku = Sudoku_solver(board, 9)
        start_time = time.time()
        sudoku.solve()
        solved_board = sudoku.board

        # Draw the solution on the warped view, then unwarp back.
        _, warp_with_nums = draw_digits_on_warped(warped, solved_board, unsolved_board)
        dst_img = unwarp_image(warp_with_nums, corners_img, corners_list,
                               time.time() - start_time)
        return dst_img, solved_board
    except TypeError:
        print("Can not warp image. Please try another image")
        return None
66
+
67
if __name__ == "__main__":
    url = "streamlit_app\image_from_user\Test40.jpg"  # Url for test image
    # BUG FIX: load the image first -- image_solver calls .copy() on its
    # argument, so passing the path string crashed.
    img = cv2.imread(url)
    res, solved_board = image_solver(img, classifier)
    cv2.imshow("Result", cv2.resize(res, (700, 700), cv2.INTER_AREA))
    cv2.waitKey(0)
72
+
model/evaluation.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageFont, ImageDraw
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from train_classifier import model, device, train_loader, val_loader, train_acc, train_loss, val_loss, val_acc
8
+
9
+ # Just visualize model results
10
+
11
def visualize_sample():
    """Show one training sample: print the batch shapes, display the first
    image of the batch, and print its label."""
    train_features, train_labels = next(iter(train_loader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")
    sample_img = train_features[0].squeeze()
    label = train_labels[0]
    plt.imshow(sample_img, cmap="gray")
    plt.show()
    print(f"Label: {label}")
23
+
24
+
25
+
26
def plot_metrics(train_loss, train_acc, val_loss, val_acc):
    """Plot loss (left panel) and accuracy (right panel) curves over epochs."""
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Loss curves.
    axes[0].plot(train_loss, label='train')
    axes[0].plot(val_loss, label='val')
    axes[0].legend()
    axes[0].set_title('Loss versus epochs')

    # Accuracy curves.
    axes[1].plot(train_acc, label='train')
    axes[1].plot(val_acc, label='test')
    axes[1].legend()
    axes[1].set_title('Accuracy versus epochs')
    plt.show()
38
+
39
+
40
def predict_batch(model, data_loader):
    """Run the model on one randomly chosen batch of `data_loader`.

    Returns:
        (inputs, preds, labels): the batch inputs (on `device`) plus
        predictions and ground-truth labels as numpy arrays.
    """
    batch_id = np.random.randint(0, len(data_loader))
    for index, (inputs, labels) in enumerate(data_loader):
        if index != batch_id:
            continue
        # Move model and inputs to the shared device before inference.
        model = model.to(device)
        inputs = inputs.to(device)
        outputs = model(inputs)
        preds = outputs.argmax(dim=1).cpu().numpy()
        labels = labels.numpy()
        return inputs, preds, labels
56
+
57
+
58
if __name__ == "__main__":
    visualize_sample()
    plot_metrics(train_loss, train_acc, val_loss, val_acc)
    inputs, preds, labels = predict_batch(model, val_loader)
    print(preds)
    print(labels)
    correct = np.sum(preds == labels)
    print("Accuracy on random batch: {}/{}".format(correct, len(preds)))
model/get_model.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.models import resnet18, resnet101, resnet50
3
+ import torchvision
4
+ import torch.nn as nn
5
+
6
+
7
def get_model(model_name, pretrained=True):
    """Build a ResNet classifier adapted to 1-channel digit images.

    Args:
        model_name: one of "resnet18", "resnet50", "resnet101".
        pretrained: whether to load ImageNet-pretrained weights.

    Returns:
        torch.nn.Module with a 1-channel first conv and a 10-class head.

    Raises:
        ValueError: for an unsupported model_name (the original code fell
        through and crashed with an UnboundLocalError instead).
    """
    builders = {
        "resnet18": torchvision.models.resnet18,
        "resnet50": torchvision.models.resnet50,
        "resnet101": torchvision.models.resnet101,
    }
    if model_name not in builders:
        raise ValueError("Unsupported model name: {}".format(model_name))

    net = builders[model_name](pretrained=pretrained)

    # Replace the first layer so the network accepts grayscale images.
    net.conv1 = nn.Conv2d(
        1,
        64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    # BUG FIX: resnet18 ends with 512 features, not 2048; the original
    # hard-coded 2048 for every architecture, which breaks resnet18.
    # Reading net.fc.in_features is correct for all three variants.
    net.fc = nn.Linear(in_features=net.fc.in_features, out_features=10, bias=True)
    return net
model/train_classifier.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageFont, ImageDraw
2
+ import random
3
+ import glob
4
+ from torch.utils.data import Dataset, DataLoader
5
+ import torch.nn as nn
6
+ import torchvision.transforms as transforms
7
+ import numpy as np
8
+ import torch.optim as optim
9
+ import torch
10
+ import time
11
+ import copy
12
+ from create_dataset.digital_mnist_digits import PrintedMNIST
13
+ from get_model import get_model
14
+
15
+ # Define parameters
16
+ batch_size = 64
17
+ net = get_model("resnet50")
18
+ n_epochs = 10
19
+ device = "cuda" if torch.cuda.is_available() == True else "cpu"
20
+
21
+ #Define optimizer
22
+ learning_rate = 1e-3
23
+ criterion = nn.CrossEntropyLoss()
24
+ optimizer = optim.Adam(net.parameters(), lr=learning_rate)
25
+
26
+
27
def load_dataset(batch_size):
    """Build train/val DataLoaders over the synthetic printed-digit dataset.

    Args:
        batch_size: batch size used by both loaders.

    Returns:
        (train_loader, val_loader)
    """
    train_transform = transforms.Compose([
        transforms.RandomRotation(10),  # light augmentation
        transforms.ToTensor(),
        # AddGaussianNoise(0, 1.0),
        # AddSPNoise(0.1),
    ])
    val_transforms = transforms.Compose([transforms.ToTensor()])

    # Fixed seeds keep the generated datasets reproducible across runs.
    train_set = PrintedMNIST(50000, 42, train_transform)
    val_set = PrintedMNIST(5000, 33, val_transforms)

    train_loader = DataLoader(train_set, batch_size=batch_size)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    return train_loader, val_loader
46
+
47
+
48
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    """Train `model`, validating after every epoch, and return the best weights.

    Args:
        model: network to train.
        train_loader / val_loader: DataLoaders for training / validation.
        criterion: loss function.
        optimizer: optimizer over model.parameters().
        num_epochs: number of training epochs.

    Returns:
        (model, train_loss_track, train_acc_track, val_loss_track,
        val_acc_track) where model carries the weights of the epoch with the
        highest validation correct-count.
    """
    # BUG FIX: move the model onto the shared device; the original never
    # did, so training crashed whenever device == "cuda".
    model = model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_loss_track = []
    train_acc_track = []
    val_loss_track = []
    val_acc_track = []
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)

        # ---- Training loop ----
        train_loss, train_correct = 0, 0
        model.train()
        for batch in train_loader:
            images, labels = batch[0].to(device), batch[1].to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1).cpu().numpy()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            np_labels_train = labels.cpu().numpy()
            # BUG FIX: weight the loss by the actual batch size, not the
            # configured global one -- the last batch may be smaller.
            train_loss += loss.item() * images.size(0)
            train_correct += np.sum(preds == np_labels_train)
        train_loss_avg = train_loss / len(train_loader.sampler)
        train_acc_avg = train_correct / len(train_loader.sampler)
        print('Train Loss: ', train_loss_avg)
        print('Train Accuracy: ', train_acc_avg)
        train_loss_track.append(train_loss_avg)
        train_acc_track.append(train_acc_avg)

        # ---- Validation loop ----
        model.eval()
        with torch.no_grad():
            valid_loss, valid_correct = 0, 0

            for batch in val_loader:
                images, labels = batch[0].to(device), batch[1].to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                preds = outputs.argmax(dim=1).cpu().numpy()
                np_label_val = labels.cpu().numpy()
                valid_loss += loss.item() * images.size(0)
                valid_correct += np.sum(preds == np_label_val)
            # Snapshot the weights of the best-performing epoch
            # (best_acc tracks the raw correct-count, as in the original).
            if valid_correct > best_acc:
                best_acc = valid_correct
                best_model_wts = copy.deepcopy(model.state_dict())
            valid_loss_avg = valid_loss / len(val_loader.sampler)
            valid_acc_avg = valid_correct / len(val_loader.sampler)
            print('Validation Loss: ', valid_loss_avg)
            print('Validation Accuracy: ', valid_acc_avg)
            val_loss_track.append(valid_loss_avg)
            val_acc_track.append(valid_acc_avg)

    # Restore the best validation weights before returning.
    model.load_state_dict(best_model_wts)
    return model, train_loss_track, train_acc_track, val_loss_track, val_acc_track
107
+
108
if __name__ == "__main__":
    train_loader, val_loader = load_dataset(batch_size)
    (model, train_loss, train_acc,
     val_loss, val_acc) = train(net, train_loader, val_loader,
                                criterion, optimizer, n_epochs)
processing.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from threshold import preprocess
4
+ from utils import find_corners, draw_circle_at_corners, grid_line_helper, draw_line
5
+ from utils import clean_square_helper, classify_one_digit
6
+
7
+ #----------------Process pipe line------------------------------#
8
+
9
+ # 1) Threshold Adaptive to get gray-scale image to find contours
10
+ # 2) Find contours from original image
11
+ # 3) Image alignment (warp image) on original image
12
+ # 4) Get horizontal, vertical line and create grid mask
13
+ # 5) Extract numbers and split gray-scale image into 81 squares
14
+ # 6) Clean noise pixels of each square
15
+ # 7) Recognize digits
16
+ # 8) Solve sudoku
17
+ # 9) Draw solved board on warped image
18
+ # 10) Unwarped image --> Result
19
+
20
+
21
def find_contours(img, original):
    """Locate the sudoku puzzle (largest 4-corner contour) in a threshold image.

    contours: tuple of point arrays, one per detected contour.
    hierachy: [Next, Previous, First_Child, Parent]
    contour approximation: https://pyimagesearch.com/2021/10/06/opencv-contour-approximation/

    Args:
        img: thresholded (binary) image used for contour detection.
        original: the corresponding original image.

    Returns:
        (draw_original, corner_list, original) where draw_original shows the
        detected contour/corners and corner_list is
        [top_left, top_right, bot_right, bot_left]; ([], [], []) on failure.
    """
    # Find contours on the threshold image, then sort largest-first.
    contours, hierachy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    # The puzzle is the largest contour that approximates to 4 points.
    polygon = None
    for con in contours:
        area = cv2.contourArea(con)
        perimeter = cv2.arcLength(con, closed=True)
        approx = cv2.approxPolyDP(con, epsilon=0.01 * perimeter, closed=True)
        if len(approx) == 4 and area > 1000:
            polygon = con  # found the puzzle
            break
    if polygon is None:
        print("Can not detect puzzle")
        return [], [], []

    # Corner points of the detected quadrilateral.
    top_left = find_corners(polygon, limit_func=min, compare_func=np.add)
    top_right = find_corners(polygon, limit_func=max, compare_func=np.subtract)
    bot_left = find_corners(polygon, limit_func=min, compare_func=np.subtract)
    bot_right = find_corners(polygon, limit_func=max, compare_func=np.add)

    # BUG FIX: check the zero-height case BEFORE dividing by that height --
    # the original divided first and could raise ZeroDivisionError.
    if bot_right[1] - top_right[1] == 0:
        print("Exception 2 : Get another image to get square-shape puzzle")
        return [], [], []
    # Reject clearly non-square quadrilaterals (width/height ratio bound).
    if not (0.5 < ((top_right[0] - top_left[0]) / (bot_right[1] - top_right[1])) < 1.5):
        print("Exception 1 : Get another image to get square-shape puzzle")
        return [], [], []

    corner_list = [top_left, top_right, bot_right, bot_left]
    draw_original = original.copy()
    cv2.drawContours(draw_original, [polygon], 0, (0, 255, 0), 3)
    # Mark each corner with a circle.
    for corner in corner_list:
        draw_circle_at_corners(draw_original, corner)

    return draw_original, corner_list, original
69
+
70
+
71
+
72
def warp_image(corner_list, original):
    """Perspective-warp the puzzle region into a square image.

    Perspective transformation: https://theailearner.com/tag/cv2-warpperspective/

    Args:
        corner_list: [top_left, top_right, bot_right, bot_left] corner points.
        original: image to warp.

    Returns:
        (transformed_image, transform_matrix), or None when warping fails.
    """
    try:
        corners = np.array(corner_list, dtype="float32")
        # BUG FIX: corner_list is ordered TL, TR, BR, BL; the original
        # unpacked it as TL, TR, BL, BR, so "side" below mixed in the
        # diagonals and came out too large.
        top_left, top_right, bot_right, bot_left = corners
        # Use the longest edge as the side of the square output.
        side = int(max(
            np.linalg.norm(top_right - bot_right),
            np.linalg.norm(top_left - bot_left),
            np.linalg.norm(bot_right - bot_left),
            np.linalg.norm(top_left - top_right),
        ))
        out_ptr = np.array([[0, 0], [side - 1, 0], [side - 1, side - 1], [0, side - 1]],
                           dtype="float32")
        transfrom_matrix = cv2.getPerspectiveTransform(corners, out_ptr)
        transformed_image = cv2.warpPerspective(original, transfrom_matrix, (side, side))
        return transformed_image, transfrom_matrix
    except IndexError:
        print("Can not detect corners")
    except Exception:
        # Narrowed from a bare `except:`; still best-effort, but no longer
        # swallows SystemExit/KeyboardInterrupt.
        print("Something went wrong. Try another image")
96
+
97
+
98
+
99
+
100
def get_grid_line(img, length=10):
    """Extract the grid lines from a warped puzzle image.

    Returns:
        (vertical, horizontal) line images.
    """
    # NOTE(review): `length` is currently unused -- kept for interface
    # compatibility with existing callers.
    horizontal = grid_line_helper(img, shape_location=1)
    vertical = grid_line_helper(img, shape_location=0)
    return vertical, horizontal
108
+
109
+
110
+
111
+
112
def create_grid_mask(horizontal, vertical):
    """Build a mask of the full grid so digits can be isolated.

    Combines the horizontal and vertical line images, dilates them, completes
    the lines with a Hough transform, and inverts the result so grid pixels
    are removed when AND-ed with the warped image.
    """
    # Combine the two line images into one grid image.
    grid = cv2.add(horizontal, vertical)
    # grid = cv2.adaptiveThreshold(grid, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 235, 2)
    # Thicken the lines so the Hough transform finds them reliably.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    grid = cv2.dilate(grid, kernel, iterations=2)
    # Detect and redraw the complete grid lines.
    lines = cv2.HoughLines(grid, 0.3, np.pi / 90, 200)
    lines_img = draw_line(grid, lines)
    # Invert: everything except the grid lines survives a bitwise AND.
    return cv2.bitwise_not(lines_img)
130
+
131
+
132
+
133
+
134
def split_squares(number_img):
    """Split a square board image into 81 cell images, row-major order.

    Args:
        number_img: square image whose side is a multiple of 9.

    Returns:
        List of 81 sub-images, one per sudoku cell.
    """
    side = number_img.shape[0] // 9
    return [
        number_img[row * side:(row + 1) * side, col * side:(col + 1) * side]
        for row in range(9)
        for col in range(9)
    ]
149
+
150
+
151
+
152
+
153
def clean_square(square_list):
    """Clean each cell image and count how many cells contain a digit.

    Returns:
        (cleaned_squares, count): a list with a cleaned image for each digit
        cell and 0 for each empty cell, plus the number of digit cells.
    """
    cleaned_squares = []
    count = 0
    for square in square_list:
        cleaned, has_digit = clean_square_helper(square)
        if has_digit:
            cleaned_squares.append(cleaned)
            count += 1
        else:
            cleaned_squares.append(0)
    return cleaned_squares, count
170
+
171
+
172
+
173
def clean_square_all_images(square_list):
    """Clean every cell image; cells without a digit come back as a black
    image after cleaning.

    Returns:
        List of cleaned cell images (one per input square).
    """
    return [clean_square_helper(square)[0] for square in square_list]
184
+
185
def recognize_digits(model, resized, org_img):
    """Classify every cell image and concatenate the digits into one
    81-character board string."""
    return "".join(str(classify_one_digit(model, cell, org_img))
                   for cell in resized)
191
+
192
+
193
+
194
def draw_digits_on_warped(warped_img, solved_board, unsolved_board):
    """Draw the solver-filled digits onto the warped puzzle image.

    Only cells that were blank in the unsolved board are drawn, so the
    digits already printed on the puzzle are not painted over.

    Returns:
        (img_w_text, warped_img): a black canvas of the same shape (kept
        for interface compatibility) and the warped image with digits drawn.
    """
    width = warped_img.shape[0] // 9
    img_w_text = np.zeros_like(warped_img)

    for row in range(9):
        for col in range(9):
            if unsolved_board[row][col] != 0:
                continue  # already printed on the puzzle
            # Cell bounding box and its center.
            p1 = (col * width, row * width)
            p2 = ((col + 1) * width, (row + 1) * width)
            center = ((p1[0] + p2[0]) // 2, (p1[1] + p2[1]) // 2)
            text_size, _ = cv2.getTextSize(str(solved_board[row][col]),
                                           cv2.FONT_HERSHEY_SIMPLEX, 1, 6)
            text_origin = (center[0] - text_size[0] // 2, center[1] + text_size[1] // 2)
            cv2.putText(warped_img, str(solved_board[row][col]), text_origin,
                        cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 6)

    return img_w_text, warped_img
220
+
221
+
222
+
223
def unwarp_image(img_src, img_dest, pts, time):
    """Project the solved (warped) puzzle back onto the original image.

    Args:
        img_src: warped image with the solution drawn on it.
        img_dest: original image to composite the solution into.
        pts: the four puzzle corners in img_dest.
        time: solve duration in seconds, printed on the result.

    Returns:
        img_dest with the solved puzzle and timing text composited in.
    """
    pts = np.array(pts)
    height, width = img_src.shape[0], img_src.shape[1]
    # BUG FIX: the bottom-left source corner is (0, height-1); the original
    # used width-1 for the y coordinate, skewing the homography whenever the
    # source image was not exactly square.
    pts_source = np.array(
        [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]],
        dtype='float32')

    matrix, status = cv2.findHomography(pts_source, pts)
    # Map the solution back into the original view's perspective.
    warped = cv2.warpPerspective(img_src, matrix, (img_dest.shape[1], img_dest.shape[0]))
    # Black out the puzzle area in the destination before compositing.
    cv2.fillConvexPoly(img_dest, pts, 0, 16)
    dst_img = cv2.add(img_dest, warped)
    dst_img_height, dst_img_width = dst_img.shape[0], dst_img.shape[1]
    cv2.putText(dst_img, "Time solved: {} s".format(str(np.round(time, 4))),
                (dst_img_width - 360, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

    return dst_img
241
+
realtime_solver.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from utils import *
4
+ from processing import *
5
+ from threshold import preprocess
6
+ import time
7
+ import cv2
8
+ from sudoku_solve import Sudoku_solver
9
+ from PIL import Image
10
+ from helper_number_page import valid_board
11
+
12
+
13
# Load the digit classifier once at startup (CPU inference).
classifier = torch.load('digit_classifier.h5', map_location='cpu')
classifier.eval()


frameWidth = 960
frameHeight = 720

cap = cv2.VideoCapture(0)
frame_rate = 60

# Capture property id 3 is frame width, id 4 is frame height.
cap.set(3, frameWidth)
cap.set(4, frameHeight)

# Property id 10 is brightness.
cap.set(10, 150)
prev = 0

while cap.isOpened():
    time_elapsed = time.time() - prev
    success, img = cap.read()
    # BUG FIX: bail out when the frame grab fails; the original ignored
    # `success` and crashed on img.copy() with img == None.
    if not success:
        break
    if time_elapsed > 1. / frame_rate:
        prev = time.time()
        final_img = img.copy()
        to_process_img = img.copy()
        # --- Detection ---
        thresholded_img = preprocess(to_process_img)  # gray-scale threshold
        corners_img, corners_list, org_img = find_contours(thresholded_img, to_process_img)

        if corners_list:
            # Warp the puzzle to a square view and threshold it.
            warped, matrix = warp_image(corners_list, corners_img)
            warped_processed = preprocess(warped)  # gray-scaled warped img

            # Grid lines.
            horizontal = grid_line_helper(warped_processed, shape_location=0)
            vertical = grid_line_helper(warped_processed, shape_location=1)

            # Grid mask; resizing tends to improve recognition.
            grid_mask = create_grid_mask(horizontal, vertical)
            grid_mask = cv2.resize(grid_mask, (600, 600), cv2.INTER_AREA)
            # Extract only the digits.
            number_img = cv2.bitwise_and(
                cv2.resize(warped_processed, (600, 600), cv2.INTER_AREA), grid_mask)
            # Split into 81 cells and clean noise from each.
            squares = split_squares(number_img)
            cleaned_squares = clean_square_all_images(squares)

            # Resize cells and normalize pixel values for the classifier.
            resized_list = resize_square(cleaned_squares)
            norm_resized = normalize(resized_list)

            # Recognize the digits and build the board.
            rec_str = recognize_digits(classifier, norm_resized, org_img)
            board = convert_str_to_board(rec_str)

            # Solve with backtracking, timing the solve step.
            unsolved_board = board.copy()
            sudoku = Sudoku_solver(board, 9)
            start_time = time.time()
            sudoku.solve()
            solved_board = sudoku.board

            # Draw the solution on the warped view and unwarp back.
            _, warp_with_nums = draw_digits_on_warped(warped, solved_board, unsolved_board)
            final_img = unwarp_image(warp_with_nums, corners_img, corners_list,
                                     time.time() - start_time)
            cv2.imshow("Result", final_img)
            if valid_board(solved_board):
                # Hold a correctly solved frame on screen for a second.
                cv2.waitKey(1000)
        else:
            cv2.imshow("Result", final_img)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
cv2.destroyAllWindows()
cap.release()
sudoku_solve.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
class Sudoku_solver():
    """
    Solve a Sudoku puzzle in place using the backtracking algorithm.

    Attributes:
        board: 2-D indexable grid (list of lists or numpy array) where 0
               marks an empty cell and 1..size are filled digits.
        size:  side length of the board (9 for standard Sudoku).  Must be a
               perfect square so the grid splits into sub-boxes.
               (The original stored `size` but hard-coded 9/3 everywhere;
               it is now actually used, with identical behavior for size=9.)
    """

    def __init__(self, board, size):
        self.board = board
        self.size = size
        # Side length of one sub-box (3 for a standard 9x9 board).
        self._box = int(size ** 0.5)

    def print_board(self):
        """
        Visualize the result board with box separators (debug helper).
        """
        last_col = len(self.board[0]) - 1
        for i in range(len(self.board)):
            if i % self._box == 0 and i != 0:
                print("- - - - - - - - - - - - - ")

            for j in range(len(self.board[0])):
                if j % self._box == 0 and j != 0:
                    print(" | ", end="")

                if j == last_col:
                    print(self.board[i][j])
                else:
                    print(str(self.board[i][j]) + " ", end="")

    def valid(self, num, pos):
        """
        Return True if placing `num` at `pos` (row, col) keeps the board
        valid.  The cell at `pos` itself is excluded from each comparison so
        an already-placed digit can be re-checked.
        """
        row, col = pos

        # Row: the same number may not appear twice in this row.
        for j in range(len(self.board[0])):
            if self.board[row][j] == num and col != j:
                return False

        # Column: the same number may not appear twice in this column.
        for i in range(len(self.board)):
            if self.board[i][col] == num and row != i:
                return False

        # Sub-box containing pos (one of size boxes on the board).
        box_x = row // self._box
        box_y = col // self._box
        for i in range(box_x * self._box, box_x * self._box + self._box):
            for j in range(box_y * self._box, box_y * self._box + self._box):
                if self.board[i][j] == num and (i, j) != pos:
                    return False
        return True

    def find_empty_cell(self):
        """
        Return the (row, col) of the first empty (0) cell, scanning
        row-major, or None when the board is full.
        """
        for i in range(len(self.board)):
            for j in range(len(self.board[0])):
                if self.board[i][j] == 0:
                    return (i, j)
        return None

    def solve(self):
        """
        Fill the board in place via backtracking.
        Returns True when a complete solution was found, False otherwise.
        """
        pos = self.find_empty_cell()
        # Base case: no empty cell left, the board is complete.
        if not pos:
            return True
        row, col = pos

        # Try every candidate digit for this cell.
        for num in range(1, self.size + 1):
            if self.valid(num, (row, col)):
                self.board[row][col] = num

                if self.solve():
                    return True

                # Dead end: undo the placement and try the next digit.
                self.board[row][col] = 0
        return False
threshold.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
def preprocess(img):
    """
    Threshold an input photo so grid lines and digits are white on black.

    Input:  BGR colour image (H, W, 3) or an already gray-scale image.
    Output: binary gray-scale image after blur, adaptive threshold,
            inversion, opening and dilation.
    """
    # Convert to gray-scale only when the input actually has colour channels.
    # The original indexed shape[2] unconditionally (IndexError on 2-D gray
    # input) and left gray_img unbound for single-channel 3-D input.
    if img.ndim == 3 and img.shape[2] != 1:
        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        gray_img = img if img.ndim == 2 else img[:, :, 0]
    # Gaussian blur smooths sensor noise before thresholding.
    blured = cv2.GaussianBlur(gray_img, (9, 9), 0)
    # Adaptive threshold copes with uneven lighting across the photo.
    thresh = cv2.adaptiveThreshold(blured, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
    # Invert so that the grid lines and text are white, the rest black.
    # (The stray `0` dst argument of the original was meaningless.)
    inverted = cv2.bitwise_not(thresh)
    morphy_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    # Opening removes small white speckles (dots etc.).
    morph = cv2.morphologyEx(inverted, cv2.MORPH_OPEN, morphy_kernel)
    # Dilate to thicken the surviving strokes/borders.
    result = cv2.dilate(morph, morphy_kernel, iterations=1)
    return result
25
+
26
+
27
if __name__ == "__main__":
    # Quick visual sanity check of the preprocessing pipeline.
    path = r"testimg\sudoku_real_4.jpeg"  # raw string so the backslash is literal
    img = cv2.imread(path)
    processed = preprocess(img)
    # Show the processed result — the original displayed the raw image and
    # discarded `processed`.  Note interpolation must be passed by keyword:
    # as the third positional argument it lands in the `dst` slot of
    # cv2.resize and is silently ignored.
    cv2.imshow("img", cv2.resize(processed, (600, 600), interpolation=cv2.INTER_AREA))
    cv2.waitKey(0)
utils.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import operator
4
+ import torch
5
+
6
+
7
+
8
def find_corners(polygon, limit_func, compare_func):
    """
    Locate one corner of a contour polygon.

    polygon:      contour points shaped like cv2 output, i.e. [[[x, y]], ...]
    limit_func:   min or max — which extreme to take
    compare_func: np.add or np.subtract — how x and y are combined

    With the origin at the top-left of the image:
        top-left:  min of (x + y)     top-right: max of (x - y)
        bot-left:  min of (x - y)     bot-right: max of (x + y)

    Returns the (x, y) of the selected corner point.
    """
    scores = [compare_func(point[0][0], point[0][1]) for point in polygon]
    best_index, _ = limit_func(enumerate(scores), key=operator.itemgetter(1))
    corner = polygon[best_index][0]
    return corner[0], corner[1]
25
+
26
+
27
+
28
def draw_circle_at_corners(original, ptr):
    """Mark a corner: draw a filled radius-5 green circle at `ptr` onto `original`."""
    green = (0, 255, 0)
    cv2.circle(original, ptr, 5, green, cv2.FILLED)
34
+
35
+
36
+
37
def grid_line_helper(img, shape_location, length = 10):
    """
    Extract the grid lines running along one axis of a thresholded image.

    shape_location = 0 -> keep horizontal lines
    shape_location = 1 -> keep vertical lines
    length: the minimum line length is one `length`-th of the image side.
    """
    work = img.copy()
    # Approximate line length from the relevant image dimension.
    size = work.shape[shape_location] // length

    # An elongated rectangular kernel: erosion erases everything that is not
    # a long white run along the chosen axis, dilation restores the survivors.
    kernel_shape = (size, 1) if shape_location == 0 else (1, size)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_shape)

    work = cv2.erode(work, kernel)
    work = cv2.dilate(work, kernel)
    return work
62
+
63
+
64
+
65
def draw_line(img, lines):
    """
    Draw every (rho, theta) line from cv2.HoughLines() onto a copy of img.

    Returns the annotated copy; the input image is left untouched.
    """
    clone_img = img.copy()
    # cv2.HoughLines() returns shape (n, 1, 2).  Reshape to (n, 2) explicitly:
    # np.squeeze would collapse a single detected line to shape (2,), which
    # breaks the `for rho, theta in lines` unpacking below.
    lines = np.reshape(lines, (-1, 2))
    for rho, theta in lines:
        a = np.cos(theta)
        b = np.sin(theta)
        # (x0, y0) is the point on the line closest to the origin.
        x0 = a * rho
        y0 = b * rho
        # Extend 1000 px in both directions so the segment crosses the image.
        x1 = int(x0 + 1000 * (-b))
        y1 = int(y0 + 1000 * a)
        x2 = int(x0 - 1000 * (-b))
        y2 = int(y0 - 1000 * a)
        # Draw each line in white, thickness 4.
        cv2.line(clone_img, (x1, y1), (x2, y2), (255, 255, 255), 4)
    return clone_img
86
+
87
+
88
+
89
def clean_square_helper(img):
    """
    Clean up a single extracted cell of the sudoku grid.

    Input:  one of the 81 square images (white digit on black background).
    Output: (cleaned image, has_digit flag).  Empty or edge-only squares come
            back as all-black with the flag False; otherwise the largest
            connected component (the digit) is re-centred in the square.
    """
    height, width = img.shape

    # Almost entirely black -> empty cell.
    black_ratio = np.isclose(img, 0).sum() / (height * width)
    if black_ratio >= 0.96:
        return np.zeros_like(img), False

    # If the central vertical band is almost black, the white we found is a
    # grid edge caught at the border, not a digit.
    mid = width // 2
    band = img[:, int(mid - width * 0.38):int(mid + width * 0.38)]
    if np.isclose(band, 0).sum() / (2 * width * 0.38 * height) >= 0.98:
        return np.zeros_like(img), False

    # Re-centre the digit: take the bounding box of the biggest contour and
    # paste it into the middle of a fresh black square.
    contours, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    biggest = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(biggest)

    offset_x = (width - w) // 2
    offset_y = (height - h) // 2
    centered = np.zeros_like(img)
    centered[offset_y:offset_y + h, offset_x:offset_x + w] = img[y:y + h, x:x + w]

    return centered, True
116
+
117
+
118
+
119
def resize_square(clean_square_list):
    """Resize each cleaned square to 28x28 so it can be fed to the classifier."""
    return [cv2.resize(square, (28, 28), interpolation=cv2.INTER_AREA)
            for square in clean_square_list]
129
+
130
+
131
+
132
def resize_square32(clean_square_list):
    """Resize each cleaned square to 32x32 so it can be fed to the tf classifier."""
    return [cv2.resize(square, (32, 32), interpolation=cv2.INTER_AREA)
            for square in clean_square_list]
141
+
142
+
143
def classify_one_digit(model, resize_square, org_image):
    """
    Classify the digit in one 28x28 square, or detect a blank square.

    model:         torch classifier accepting a (1, 1, 28, 28) float tensor
    resize_square: 28x28 numpy array of the cell
    org_image:     the original photo, used only to pick a noise threshold

    Returns the predicted digit as a string; "0" for a blank square.
    """
    # Larger photos produce noisier, thinner cells, so require fewer lit
    # pixels there.  The original condition was redundant and also indexed
    # org_image.shape[2] (the channel count, never > 600), which raised
    # IndexError on 2-D gray-scale input.
    if org_image.shape[0] > 600 or org_image.shape[1] > 600:
        threshold = 40
    else:
        threshold = 60

    # Blank square: too few pixels differ from the background (min) value.
    if (resize_square != resize_square.min()).sum() < threshold:
        return str(0)

    model.eval()
    # Shape (1, 1, 28, 28): batch and channel axes the classifier expects.
    iin = torch.Tensor(resize_square).unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        out = model(iin)
        # Index of the highest logit is the predicted digit.
        _, index = torch.max(out, 1)

    pred_digit = index.item()

    return str(pred_digit)
172
+
173
+
174
+
175
def normalize(resized_list):
    """Scale pixel values from [0, 255] down to [0, 1] for recognition."""
    scaled = []
    for square in resized_list:
        scaled.append(square / 255)
    return scaled
181
+
182
+
183
+
184
def convert_str_to_board(string, step = 9):
    """
    Turn the recognized digit string into a 2-D numpy board for solving.

    string: digits read cell-by-cell, row after row
    step:   number of cells per row (9 for a standard sudoku)
    """
    rows = [[int(char) for char in string[start:start + step]]
            for start in range(0, len(string), step)]
    return np.array(rows)