Spaces:
Runtime error
Runtime error
code
Browse files- create_dataset/create_fontstyle.py +34 -0
- create_dataset/data_generation_sample.py +24 -0
- create_dataset/digital_mnist_digits.py +74 -0
- helper_number_page.py +132 -0
- image_solver.py +72 -0
- model/evaluation.py +64 -0
- model/get_model.py +50 -0
- model/train_classifier.py +110 -0
- processing.py +241 -0
- realtime_solver.py +89 -0
- sudoku_solve.py +85 -0
- threshold.py +32 -0
- utils.py +192 -0
create_dataset/create_fontstyle.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import random
|
| 3 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 4 |
+
|
| 5 |
+
font_folder = 'font'
|
| 6 |
+
font_name = ['arial', 'bodoni','calibri','futura','heveltica','times-new-roman']
|
| 7 |
+
|
| 8 |
+
def fontstyle_list(font_folder, font_name):
    """Collect the .ttf font file paths for the given font families.

    Args:
        font_folder: Root directory that holds one sub-directory per family.
        font_name: Iterable of family sub-directory names to scan.

    Returns:
        list[str]: Path of every .ttf file found under the listed families.
    """
    font_list = []
    for family in font_name:
        # Forward slash works in glob patterns on both Windows and POSIX;
        # the original hard-coded "\\" and only worked on Windows.
        font_list.extend(glob.glob(font_folder + "/" + family + "/*.ttf"))
    return font_list
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def draw_img(label, font_list):
    """Render *label* with the first font in *font_list*.

    The digit is drawn at a random size and offset on a 256x256 grayscale
    canvas, then downscaled to 28x28 (MNIST-like). Returns (image, label).
    """
    canvas = Image.new('L', (256, 256))
    font_size = random.randint(150, 250)
    left = random.randint(60, 90)
    top = random.randint(30, 60)
    pen = ImageDraw.Draw(canvas)
    # NOTE(review): always uses font_list[0]; presumably a random choice was
    # intended (compare data_generation_sample.py) — confirm before relying.
    typeface = ImageFont.truetype(font_list[0], font_size)
    pen.text((left, top), str(label), (200), font=typeface)

    small = canvas.resize((28, 28), Image.BILINEAR)
    return small, label
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
    # Quick sanity check: list the font files that were discovered.
    available = fontstyle_list(font_folder, font_name)
    print(available)
    print(len(available))
|
create_dataset/data_generation_sample.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageFont, ImageDraw
import random
import cv2
import numpy as np

# Sample script: draw one random digit with a random font and display it.

# Raw strings so the backslashes in these Windows paths can never be read
# as escape sequences (a lowercase '\t...' component would become a TAB).
fonts = [r'font\Arial\Arial.ttf',
         r'font\calibri\Calibri-Bold.ttf',
         r'font\heveltica\helvetica_bold.ttf',
         r'font\Times-new-roman\Times.ttf']

img = Image.new('L', (256, 256))

target = random.randint(0, 9)        # digit to render

size = random.randint(150, 250)      # font size in points
x = random.randint(60, 90)           # left offset of the glyph
y = random.randint(30, 60)           # top offset of the glyph
draw = ImageDraw.Draw(img)
# random.choice works for any list length (was randint(0, 3), which would
# silently skip fonts if the list ever grew or shrank).
font = ImageFont.truetype(random.choice(fonts), size)
draw.text((x, y), str(target), (200), font=font)

img = img.resize((28, 28), Image.BILINEAR)
img = np.array(img)
cv2.imshow("Image", img)
cv2.waitKey(0)
|
create_dataset/digital_mnist_digits.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Dataset
|
| 2 |
+
from create_fontstyle import fontstyle_list
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from PIL import ImageFont
|
| 7 |
+
from PIL import ImageDraw
|
| 8 |
+
|
| 9 |
+
import glob
|
| 10 |
+
import random
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
font_folder = 'font'
|
| 14 |
+
font_name = ['arial', 'bodoni','calibri','futura','heveltica','times-new-roman']
|
| 15 |
+
|
| 16 |
+
fonts = fontstyle_list(font_folder, font_name)
|
| 17 |
+
|
| 18 |
+
class PrintedMNIST(Dataset):
    """Synthetic MNIST-style dataset of printed (computer-font) digits.

    Each item is generated on the fly: a random digit 0-9 is drawn with a
    random font, size, and offset on a 256x256 canvas, then downscaled
    to 28x28.
    """

    def __init__(self, samples, random_state, transform=None):
        """
        Args:
            samples: Number of items the dataset reports via __len__.
            random_state: Seed for the module-level RNG (reproducibility).
            transform: Optional callable applied to the generated PIL image.
        """
        self.samples = samples
        self.random_state = random_state
        self.transform = transform  # fixed 'transfrom' typo
        self.fonts = fonts

        random.seed(random_state)

    def __len__(self):
        return self.samples

    def __getitem__(self, index):
        color = random.randint(200, 255)
        # Generate image
        img = Image.new("L", (256, 256))
        label = random.randint(0, 9)
        size = random.randint(180, 220)
        x = random.randint(60, 80)
        y = random.randint(30, 60)

        draw = ImageDraw.Draw(img)
        # Choose a random font style from the font list.
        font = ImageFont.truetype(random.choice(self.fonts), size)
        # BUG FIX: pass the selected truetype font object — the original
        # passed 'fonts' (the module-level list of paths), which raises
        # at runtime inside PIL.
        draw.text((x, y), str(label), color, font=font)

        img = img.resize((28, 28), Image.BILINEAR)
        if self.transform:
            img = self.transform(img)
        return img, label
|
| 50 |
+
|
| 51 |
+
class AddSPNoise(object):
    """Tensor transform adding "salt" noise: each element is bumped by the
    tensor's maximum value with probability *prob*."""

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, tensor):
        # Bernoulli mask scaled by the tensor's maximum ("salt" pixels).
        mask = torch.rand(tensor.size()) < self.prob
        return tensor + mask * tensor.max()

    def __repr__(self):
        return self.__class__.__name__ + "(prob={0})".format(self.prob)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AddGaussianNoise(object):
    """Tensor transform adding i.i.d. Gaussian noise N(mean, std**2)."""

    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        noise = torch.randn(tensor.size())
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
|
helper_number_page.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
from processing import *
|
| 4 |
+
from utils import *
|
| 5 |
+
from sudoku_solve import Sudoku_solver
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
input_str = "000000000008236400010050020500000009100000007080000050005000200000807000000020000"
|
| 9 |
+
input_str2 = "800010009050807010004090700060701020508060107010502090007040600080309040300050008"
|
| 10 |
+
|
| 11 |
+
def draw_grid():
    """Create a blank 600x600 sudoku board image.

    White (value 1.0) background with a black 9x9 grid; the 3x3 block
    borders are drawn thicker than the single-cell borders.
    """
    board_img = np.ones((600, 600, 3))
    cell = board_img.shape[0] // 9
    # Outer frame.
    cv2.rectangle(board_img, (0, 0), (board_img.shape[0], board_img.shape[1]), (0, 0, 0), 10)
    for idx in range(1, 10):
        thickness = 6 if idx % 3 == 0 else 2
        cv2.line(board_img, (idx * cell, 0), (idx * cell, board_img.shape[1]), (0, 0, 0), thickness)
        cv2.line(board_img, (0, idx * cell), (board_img.shape[0], idx * cell), (0, 0, 0), thickness)
    return board_img
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def draw_digit(base_img, input_str):
    """Draw the digits of an 81-character puzzle string onto a grid image.

    Zero characters mean empty cells and are skipped (avoids overlapping
    text when the solution is drawn later). Returns (image, 9x9 board).
    """
    cell = base_img.shape[0] // 9
    board = convert_str_to_board(input_str)
    for row in range(9):
        for col in range(9):
            digit = board[row][col]
            if digit == 0:
                continue
            # Cell bounding box.
            p1 = (col * cell, row * cell)
            p2 = ((col + 1) * cell, (row + 1) * cell)
            # Center the glyph inside the cell.
            center = ((p1[0] + p2[0]) // 2, (p1[1] + p2[1]) // 2)
            text_size, _ = cv2.getTextSize(str(digit), cv2.FONT_HERSHEY_SIMPLEX, 1, 6)
            origin = (center[0] - text_size[0] // 2, center[1] + text_size[1] // 2)
            cv2.putText(base_img, str(digit),
                        origin, cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 0), 6)
    return base_img, board
|
| 43 |
+
|
| 44 |
+
def solve(board):
    """Solve *board* with Sudoku_solver; return (solved_board, original_board)."""
    before = board.copy()
    solver = Sudoku_solver(board, 9)
    solver.solve()
    return solver.board, before
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def draw_result(base_img, unsolved_board, solved_board):
    """Draw the solved board onto *base_img*.

    Digits filled in by the solver are drawn in red; the original givens
    are drawn in black. Returns the image.
    """
    cell = base_img.shape[0] // 9
    for row in range(9):
        for col in range(9):
            digit = str(solved_board[row][col])
            p1 = (col * cell, row * cell)                 # cell top-left
            p2 = ((col + 1) * cell, (row + 1) * cell)     # cell bottom-right
            center = ((p1[0] + p2[0]) // 2, (p1[1] + p2[1]) // 2)
            text_size, _ = cv2.getTextSize(digit, cv2.FONT_HERSHEY_SIMPLEX, 1, 6)
            origin = (center[0] - text_size[0] // 2, center[1] + text_size[1] // 2)
            # Red when the solver filled the cell, black for the givens.
            was_filled = unsolved_board[row][col] != solved_board[row][col]
            color = (0, 0, 255) if was_filled else (0, 0, 0)
            cv2.putText(base_img, digit,
                        origin, cv2.FONT_HERSHEY_SIMPLEX, 1.5, color, 6)
    return base_img
|
| 70 |
+
|
| 71 |
+
### CHECK VALID SODOKU PUZZLE INPUT FROM USER
|
| 72 |
+
def get_column(board, index):
    """Return column *index* of *board* as a numpy array."""
    column = [row[index] for row in board]
    return np.array(column)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def valid_row_or_col(array):
    """True when the non-zero entries of *array* contain no duplicates.

    Zeros denote empty cells and are ignored; an all-empty line is valid.
    """
    filled = array[array != 0]
    if filled.size == 0:
        return True
    return len(set(filled)) == len(filled)
|
| 80 |
+
|
| 81 |
+
def valid_single_box(board, box_x, box_y):
    """True when 3x3 box (box_x, box_y) of *board* repeats no non-zero digit."""
    rows = slice(box_x * 3, box_x * 3 + 3)
    cols = slice(box_y * 3, box_y * 3 + 3)
    filled = board[rows, cols]
    filled = filled[filled != 0]
    if filled.size == 0:
        return True
    return len(set(filled)) == len(filled)
|
| 86 |
+
|
| 87 |
+
def valid_input_str(input_str):
    """Validate an 81-character puzzle string (0 = empty cell).

    Parses the string and delegates to valid_board so the row/column/box
    rules live in one place — the original body duplicated valid_board
    verbatim.
    """
    return valid_board(convert_str_to_board(input_str))
|
| 103 |
+
|
| 104 |
+
def valid_board(board):
    """True when no row, column, or 3x3 box of *board* repeats a non-zero digit."""
    # Rows.
    for row in board:
        if not valid_row_or_col(row):
            return False
    # Columns.
    for col_idx in range(len(board[0])):
        if not valid_row_or_col(get_column(board, col_idx)):
            return False
    # 3x3 boxes.
    for bx in range(3):
        for by in range(3):
            if not valid_single_box(board, bx, by):
                return False
    return True
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
    # Demo: render an empty grid, draw the puzzle, solve it, and show both.
    base_img = draw_grid()
    res_img = base_img.copy()
    base_img, board = draw_digit(base_img, input_str)
    cv2.imshow("IMG", base_img)
    cv2.waitKey(0)

    # Solve and overlay the solution (solver-filled digits drawn in red).
    res_board, unsolved_board = solve(board)
    res_img = draw_result(res_img, unsolved_board, res_board)
    cv2.imshow("Show result", res_img)
    cv2.waitKey(0)
    # Sanity-check the second sample puzzle string.
    res = valid_input_str(input_str2)
    print(res)
|
image_solver.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from utils import *
|
| 4 |
+
from processing import *
|
| 5 |
+
from threshold import preprocess
|
| 6 |
+
import time
|
| 7 |
+
import cv2
|
| 8 |
+
from sudoku_solve import Sudoku_solver
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
# This module performs sudoku solver which input is a image file.
|
| 12 |
+
|
| 13 |
+
classifier = torch.load('digit_model.h5',map_location ='cpu')
|
| 14 |
+
classifier.eval()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def image_solver(img, model):
    """Detect, recognize, and solve a sudoku puzzle found in *img*.

    Args:
        img: BGR image (numpy array) containing a sudoku board.
        model: Digit classifier used to read the printed cells.

    Returns:
        (result_image, solved_board) on success, or None when the board
        cannot be warped (a message is printed).
    """
    original_img = img.copy()
    threshold = preprocess(img)
    corners_img, corners_list, org_img = find_contours(threshold, img)
    try:
        # Warp the original image to a square, top-down view of the board.
        warped, matrix = warp_image(corners_list, corners_img)
        # Threshold warped img (gray-scaled image).
        warped_processed = preprocess(warped)

        # Grid lines.
        horizontal = grid_line_helper(warped_processed, shape_location=0)
        vertical = grid_line_helper(warped_processed, shape_location=1)

        # Create the grid mask used to isolate the digits.
        grid_mask = create_grid_mask(horizontal, vertical)
        if img.shape[0] > 600 or img.shape[1] > 600:
            # Downscale large boards: empirically gives cleaner digit crops.
            # BUG FIX: 'interpolation' must be passed by keyword — the third
            # positional argument of cv2.resize is 'dst', not interpolation.
            grid_mask = cv2.resize(grid_mask, (600, 600), interpolation=cv2.INTER_AREA)
            resized = cv2.resize(warped_processed, (600, 600), interpolation=cv2.INTER_AREA)
            number_img = cv2.bitwise_and(resized, grid_mask)
        else:
            # Extract numbers at native resolution.
            number_img = cv2.bitwise_and(warped_processed, grid_mask)
        # Split into 81 squares and remove border noise from each.
        squares = split_squares(number_img)
        cleaned_squares = clean_square_all_images(squares)

        # Resize and scale pixels for the classifier.
        resized_list = resize_square(cleaned_squares)
        norm_resized = normalize(resized_list)

        # Recognize digits and parse into a 9x9 board.
        rec_str = recognize_digits(model, norm_resized, original_img)
        board = convert_str_to_board(rec_str)

        # Solve.
        unsolved_board = board.copy()
        sudoku = Sudoku_solver(board, 9)
        start_time = time.time()
        sudoku.solve()
        solved_board = sudoku.board
        # Draw the solution on the warped view, then unwarp onto the photo.
        _, warp_with_nums = draw_digits_on_warped(warped, solved_board, unsolved_board)

        dst_img = unwarp_image(warp_with_nums, corners_img, corners_list, time.time() - start_time)
        return dst_img, solved_board
    except TypeError:
        # warp_image returns None when the corners are unusable.
        print("Can not warp image. Please try another image")
        return None
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
    # Path of the test image (raw string keeps the backslashes literal).
    url = r"streamlit_app\image_from_user\Test40.jpg"
    # BUG FIX: image_solver expects an image array (it calls img.copy()),
    # not a path string — load the file first.
    img = cv2.imread(url)
    result = image_solver(img, classifier)
    if result is not None:
        res, solved_board = result
        # interpolation must be a keyword argument (3rd positional is 'dst').
        cv2.imshow("Result", cv2.resize(res, (700, 700), interpolation=cv2.INTER_AREA))
        cv2.waitKey(0)
|
| 72 |
+
|
model/evaluation.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageFont, ImageDraw
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
from train_classifier import model, device, train_loader, val_loader, train_acc, train_loss, val_loss, val_acc
|
| 8 |
+
|
| 9 |
+
# Just visualize model results
|
| 10 |
+
|
| 11 |
+
def visualize_sample():
    """Show one image (with its label) from the training dataloader."""
    features, labels = next(iter(train_loader))
    print(f"Feature batch shape: {features.size()}")
    print(f"Labels batch shape: {labels.size()}")
    sample_img = features[0].squeeze()
    sample_label = labels[0]
    plt.imshow(sample_img, cmap="gray")
    plt.show()
    print(f"Label: {sample_label}")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def plot_metrics(train_loss, train_acc, val_loss, val_acc):
    """Plot loss and accuracy curves (train vs. validation) side by side."""
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: loss curves.
    axes[0].plot(train_loss, label='train')
    axes[0].plot(val_loss, label='val')
    axes[0].legend()
    axes[0].set_title('Loss versus epochs')

    # Right panel: accuracy curves.
    axes[1].plot(train_acc, label='train')
    axes[1].plot(val_acc, label='test')
    axes[1].legend()
    axes[1].set_title('Accuracy versus epochs')
    plt.show()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def predict_batch(model, data_loader):
    """Run *model* on one randomly chosen batch of *data_loader*.

    Returns:
        (inputs, preds, labels): device-resident inputs plus numpy arrays
        of predicted and true labels for that batch.
    """
    target_batch = np.random.randint(0, len(data_loader))
    for batch_idx, batch in enumerate(data_loader):
        if batch_idx != target_batch:
            continue
        inputs, labels = batch[0], batch[1]
        model = model.to(device)
        inputs = inputs.to(device)
        outputs = model(inputs)
        preds = outputs.argmax(dim=1).cpu().numpy()
        return inputs, preds, labels.numpy()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
    # Show a sample, plot the training curves, then score one random batch.
    visualize_sample()
    plot_metrics(train_loss, train_acc, val_loss, val_acc)
    inputs, preds, labels = predict_batch(model, val_loader)
    print(preds)
    print(labels)
    print("Accuracy on random batch: {}/{}".format(np.sum(preds==labels), len(preds)))
|
model/get_model.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.models import resnet18, resnet101, resnet50
|
| 3 |
+
import torchvision
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_model(model_name, pretrained=True):
    """Build a ResNet adapted to 10-class grayscale digit classification.

    Args:
        model_name: One of "resnet18", "resnet50", "resnet101".
        pretrained: Load ImageNet weights before adapting the model.

    Returns:
        The adapted torchvision model.

    Raises:
        ValueError: If *model_name* is not a supported name (the original
        fell through and crashed with a NameError on the unbound 'net').
    """
    builders = {
        "resnet18": torchvision.models.resnet18,
        "resnet50": torchvision.models.resnet50,
        "resnet101": torchvision.models.resnet101,
    }
    if model_name not in builders:
        raise ValueError(f"Unsupported model name: {model_name!r}")

    net = builders[model_name](pretrained=pretrained)

    # Replace the first conv layer so the network accepts 1-channel
    # (grayscale) input instead of RGB.
    net.conv1 = nn.Conv2d(
        1,
        64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    # BUG FIX: resnet18's final feature width is 512, not 2048 — read the
    # existing layer instead of hard-coding 2048 for every depth.
    net.fc = nn.Linear(in_features=net.fc.in_features, out_features=10, bias=True)
    return net
|
model/train_classifier.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageFont, ImageDraw
|
| 2 |
+
import random
|
| 3 |
+
import glob
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
import torch
|
| 10 |
+
import time
|
| 11 |
+
import copy
|
| 12 |
+
from create_dataset.digital_mnist_digits import PrintedMNIST
|
| 13 |
+
from get_model import get_model
|
| 14 |
+
|
| 15 |
+
# Define parameters
batch_size = 64                  # samples per gradient step
net = get_model("resnet50")      # backbone adapted for grayscale digits
n_epochs = 10
# NOTE(review): 'torch.cuda.is_available() == True' is redundant (the call
# already returns a bool), and 'net' is never moved to this device here —
# confirm the training code handles the device transfer.
device = "cuda" if torch.cuda.is_available() == True else "cpu"

#Define optimizer
learning_rate = 1e-3
criterion = nn.CrossEntropyLoss()                 # multi-class classification loss
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_dataset(batch_size):
    """Build train/validation DataLoaders over the synthetic printed digits.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        (train_loader, val_loader)
    """
    train_transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        # AddGaussianNoise(0, 1.0),
        # AddSPNoise(0.1),
    ])
    val_transforms = transforms.Compose([transforms.ToTensor()])

    # Fixed seeds (42 / 33) keep the generated datasets reproducible.
    dataset_train = PrintedMNIST(50000, 42, train_transform)
    dataset_val = PrintedMNIST(5000, 33, val_transforms)

    loader_train = DataLoader(dataset_train, batch_size=batch_size)
    loader_val = DataLoader(dataset_val, batch_size=batch_size)
    return loader_train, loader_val
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    """Train *model*, tracking per-epoch loss/accuracy on train and val.

    Returns:
        (model, train_loss_track, train_acc_track, val_loss_track,
        val_acc_track), where *model* carries the weights of the epoch
        with the highest validation accuracy.
    """
    # BUG FIX: the model was never moved to the target device, so training
    # crashed on CUDA (weights on CPU, batches on GPU).
    model = model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_loss_track = []
    train_acc_track = []
    val_loss_track = []
    val_acc_track = []
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)

        # ---- Training loop ----
        train_loss, train_correct = 0, 0
        model.train()
        for batch in train_loader:
            images, labels = batch[0].to(device), batch[1].to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1).cpu().numpy()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            np_labels_train = labels.cpu().numpy()
            # BUG FIX: weight the loss by the actual batch size — the last
            # batch may be smaller than the configured global batch_size.
            train_loss += loss.item() * images.size(0)
            train_correct += np.sum(preds == np_labels_train)
        train_loss_avg = train_loss / len(train_loader.sampler)
        train_acc_avg = train_correct / len(train_loader.sampler)
        print('Train Loss: ', train_loss_avg)
        print('Train Accuracy: ', train_acc_avg)
        train_loss_track.append(train_loss_avg)
        train_acc_track.append(train_acc_avg)

        # ---- Validation loop ----
        model.eval()
        with torch.no_grad():
            valid_loss, valid_correct = 0, 0

            for batch in val_loader:
                images, labels = batch[0].to(device), batch[1].to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                preds = outputs.argmax(dim=1).cpu().numpy()
                np_label_val = labels.cpu().numpy()
                valid_loss += loss.item() * images.size(0)
                valid_correct += np.sum(preds == np_label_val)
            # Keep the weights of the best-performing epoch (correct-count
            # comparison is equivalent to accuracy: same denominator).
            if valid_correct > best_acc:
                best_acc = valid_correct
                best_model_wts = copy.deepcopy(model.state_dict())
            valid_loss_avg = valid_loss / len(val_loader.sampler)
            valid_acc_avg = valid_correct / len(val_loader.sampler)
            print('Validation Loss: ', valid_loss_avg)
            print('Validation Accuracy: ', valid_acc_avg)
            val_loss_track.append(valid_loss_avg)
            val_acc_track.append(valid_acc_avg)

    # Restore the best weights before returning.
    model.load_state_dict(best_model_wts)
    return model, train_loss_track, train_acc_track, val_loss_track, val_acc_track
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
    # Build the loaders and train; 'train' returns the best-validation model
    # plus per-epoch loss/accuracy histories.
    train_loader, val_loader = load_dataset(batch_size)
    model, train_loss, train_acc, val_loss, val_acc = train(net, train_loader, val_loader, criterion, optimizer, n_epochs)
|
processing.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from threshold import preprocess
|
| 4 |
+
from utils import find_corners, draw_circle_at_corners, grid_line_helper, draw_line
|
| 5 |
+
from utils import clean_square_helper, classify_one_digit
|
| 6 |
+
|
| 7 |
+
#----------------Process pipe line------------------------------#
|
| 8 |
+
|
| 9 |
+
# 1) Threshold Adaptive to get gray-scale image to find contours
|
| 10 |
+
# 2) Find contours from original image
|
| 11 |
+
# 3) Image alignment (warp image) on original image
|
| 12 |
+
# 4) Get horizontal, vertical line and create grid mask
|
| 13 |
+
# 5) Extract numbers and split gray-scale image into 81 squares
|
| 14 |
+
# 6) Clean noise pixels of each square
|
| 15 |
+
# 7) Recognize digits
|
| 16 |
+
# 8) Solve sudoku
|
| 17 |
+
# 9) Draw solved board on warped image
|
| 18 |
+
# 10) Unwarped image --> Result
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def find_contours(img, original):
    """Locate the sudoku puzzle (largest roughly-square quad) in *img*.

    Args:
        img: Thresholded (binary) image used for contour detection.
        original: Corresponding color image used for drawing.

    Returns:
        (annotated_image, corner_list, original) on success, where
        corner_list is [top_left, top_right, bot_right, bot_left];
        ([], [], []) when no usable puzzle is found.

    contours: tuple of point arrays; hierachy: [Next, Previous, First_Child, Parent]
    contour approximation: https://pyimagesearch.com/2021/10/06/opencv-contour-approximation/
    """
    # find contours on the threshold image
    contours, hierachy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Largest contours first — the puzzle should dominate the frame.
    contours = sorted(contours, key=cv2.contourArea, reverse=True)
    polygon = None
    # Keep the first sufficiently large quadrilateral.
    for con in contours:
        area = cv2.contourArea(con)
        perimeter = cv2.arcLength(con, closed=True)
        approx = cv2.approxPolyDP(con, epsilon=0.01 * perimeter, closed=True)
        if len(approx) == 4 and area > 1000:
            polygon = con  # found the puzzle
            break
    if polygon is None:
        print("Can not detect puzzle")
        return [], [], []

    # Corner extraction: min/max of (x+y) and (x-y) select the 4 corners.
    top_left = find_corners(polygon, limit_func=min, compare_func=np.add)
    top_right = find_corners(polygon, limit_func=max, compare_func=np.subtract)
    bot_left = find_corners(polygon, limit_func=min, compare_func=np.subtract)
    bot_right = find_corners(polygon, limit_func=max, compare_func=np.add)

    # BUG FIX: guard against a zero height *before* dividing by it — the
    # original computed the aspect ratio first and could raise
    # ZeroDivisionError before reaching its zero check.
    height = bot_right[1] - top_right[1]
    if height == 0:
        print("Exception 2 : Get another image to get square-shape puzzle")
        return [], [], []
    # The bounding quad must be roughly square (width/height in (0.5, 1.5)).
    if not (0.5 < (top_right[0] - top_left[0]) / height < 1.5):
        print("Exception 1 : Get another image to get square-shape puzzle")
        return [], [], []

    corner_list = [top_left, top_right, bot_right, bot_left]
    draw_original = original.copy()
    cv2.drawContours(draw_original, [polygon], 0, (0, 255, 0), 3)
    # Mark each detected corner.
    for corner in corner_list:
        draw_circle_at_corners(draw_original, corner)

    # draw_original: image with contour + corners drawn
    # corner_list: list of the 4 corner points
    # original: untouched input image
    return draw_original, corner_list, original
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def warp_image(corner_list, original):
    """Perspective-warp the puzzle to a square, top-down view.

    Args:
        corner_list: [top_left, top_right, bot_right, bot_left] points.
        original: Image to warp.

    Returns:
        (transformed_image, transform_matrix) on success; None (implicit)
        on failure — callers detect this via the TypeError on unpacking.

    Perspective transformation: https://theailearner.com/tag/cv2-warpperspective/
    """
    try:
        corners = np.array(corner_list, dtype="float32")
        # BUG FIX: corner_list is ordered TL, TR, BR, BL — the original
        # unpacked it as TL, TR, BL, BR, so the side length below was
        # measured across diagonals instead of edges.
        top_left, top_right, bot_right, bot_left = corners
        # Use the longest edge as the output side so nothing is shrunk.
        side = int(max([
            np.linalg.norm(top_right - bot_right),
            np.linalg.norm(top_left - bot_left),
            np.linalg.norm(bot_right - bot_left),
            np.linalg.norm(top_left - top_right)
        ]))
        out_ptr = np.array([[0, 0], [side - 1, 0], [side - 1, side - 1], [0, side - 1]],
                           dtype="float32")
        transfrom_matrix = cv2.getPerspectiveTransform(corners, out_ptr)
        transformed_image = cv2.warpPerspective(original, transfrom_matrix, (side, side))
        return transformed_image, transfrom_matrix
    except IndexError:
        print("Can not detect corners")
    except Exception:
        # Narrowed from a bare 'except:' so KeyboardInterrupt/SystemExit
        # still propagate.
        print("Something went wrong. Try another image")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_grid_line(img, length = 10):
    """
    Get vertical and horizontal grid lines from the warped puzzle image.

    :param img: thresholded (binary) warped image
    :param length: line-length divisor forwarded to grid_line_helper
                   (previously accepted but silently ignored)
    :return: (vertical, horizontal) line masks
    """
    # NOTE(review): shape_location=1 is labelled "horizontal" here while
    # grid_line_helper's docstring says the opposite -- preserved as-is,
    # confirm against actual usage.
    horizontal = grid_line_helper(img, shape_location=1, length=length)
    vertical = grid_line_helper(img, shape_location=0, length=length)
    return vertical, horizontal
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def create_grid_mask(horizontal, vertical):
    """
    Build a mask of all grid lines for digit extraction.

    The two directional line images are merged and thickened, the lines are
    completed with a Hough transform, and the result is inverted so that
    AND-ing it with the warped image keeps only the digits.
    """
    # Merge the two directional line images into a single grid image.
    merged = cv2.add(horizontal, vertical)
    # Thicken the strokes so small gaps are bridged before line detection.
    rect_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    merged = cv2.dilate(merged, rect_kernel, iterations=2)
    # Detect straight lines and redraw them to fill any remaining breaks.
    hough_lines = cv2.HoughLines(merged, 0.3, np.pi / 90, 200)
    completed = draw_line(merged, hough_lines)
    # Invert: grid becomes black, everything else white.
    return cv2.bitwise_not(completed)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def split_squares(number_img):
    """
    Cut the (square) digit image into its 81 cells.

    Cells are returned row by row, left to right.
    """
    # Side length of one cell; the image is assumed to be 9 cells across.
    cell = number_img.shape[0] // 9
    return [
        number_img[row * cell:(row + 1) * cell, col * cell:(col + 1) * cell]
        for row in range(9)
        for col in range(9)
    ]
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def clean_square(square_list):
    """
    Clean every cell image and count how many contain a digit.

    :param square_list: list of 81 cell images
    :return: (cleaned list, digit count) -- digit cells hold their cleaned
             image, empty cells hold the integer 0
    """
    cleaned = []
    digit_count = 0
    for cell in square_list:
        cleaned_img, has_digit = clean_square_helper(cell)
        if has_digit:
            cleaned.append(cleaned_img)
            digit_count += 1
        else:
            # Empty cell: store 0 instead of an image.
            cleaned.append(0)
    return cleaned, digit_count
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def clean_square_all_images(square_list):
    """
    Clean every cell image, keeping an image for each cell.

    Unlike clean_square, cells without a digit come back as all-black
    images rather than the integer 0.
    """
    # clean_square_helper returns (image, has_digit); keep only the image.
    return [clean_square_helper(cell)[0] for cell in square_list]
|
| 184 |
+
|
| 185 |
+
def recognize_digits(model, resized, org_img):
    """
    Classify every cell image and concatenate the digits into one string.

    Blank cells contribute "0" (decided inside classify_one_digit).
    """
    return "".join(
        str(classify_one_digit(model, cell, org_img)) for cell in resized
    )
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def draw_digits_on_warped(warped_img, solved_board, unsolved_board):
    """
    Write the solver's digits onto the warped puzzle image (in place).

    Only cells that were blank in the unsolved board are drawn, so the
    puzzle's original printed digits are never painted over.

    :return: (blank canvas of the same shape, annotated warped image)
    """
    cell = warped_img.shape[0] // 9
    canvas = np.zeros_like(warped_img)

    for row in range(9):
        for col in range(9):
            if unsolved_board[row][col] != 0:
                continue  # keep the puzzle's own digits untouched

            # Bounding box of this cell.
            box_tl = (col * cell, row * cell)
            box_br = ((col + 1) * cell, (row + 1) * cell)

            # Centre the text inside the cell.
            center = ((box_tl[0] + box_br[0]) // 2, (box_tl[1] + box_br[1]) // 2)
            digit_text = str(solved_board[row][col])
            text_size, _ = cv2.getTextSize(digit_text, cv2.FONT_HERSHEY_SIMPLEX, 1, 6)
            text_origin = (center[0] - text_size[0] // 2, center[1] + text_size[1] // 2)

            cv2.putText(warped_img, digit_text,
                        text_origin, cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 6)

    return canvas, warped_img
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def unwarp_image(img_src, img_dest, pts, time):
    """
    Project the solved (warped) puzzle back into the original camera view.

    :param img_src: warped puzzle image with the solution drawn on it
    :param img_dest: original frame to composite the result into (modified)
    :param pts: the four puzzle corners in img_dest (TL, TR, BR, BL)
    :param time: solving time in seconds, printed on the output frame
    :return: img_dest with the re-projected puzzle and timing text
    """
    pts = np.array(pts)

    height, width = img_src.shape[0], img_src.shape[1]
    # Source quad in TL, TR, BR, BL order. The bottom-left corner is
    # (0, height - 1); the previous code used width there, which skewed
    # the homography whenever the source was not perfectly square.
    pts_source = np.array([[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]],
                          dtype='float32')

    matrix, status = cv2.findHomography(pts_source, pts)
    # Re-project the solved puzzle into the original perspective.
    warped = cv2.warpPerspective(img_src, matrix, (img_dest.shape[1], img_dest.shape[0]))
    # Black out the puzzle area so the re-projected pixels replace it cleanly.
    cv2.fillConvexPoly(img_dest, pts, 0, 16)
    dst_img = cv2.add(img_dest, warped)
    dst_img_height, dst_img_width = dst_img.shape[0], dst_img.shape[1]
    cv2.putText(dst_img, "Time solved: {} s".format(str(np.round(time,4))), (dst_img_width - 360, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

    return dst_img
|
| 241 |
+
|
realtime_solver.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Real-time Sudoku solver: grabs frames from the webcam, detects a puzzle,
# recognizes the digits, solves the board, and projects the solution back
# onto the live frame.
import numpy as np
import torch
from utils import *
from processing import *
from threshold import preprocess
import time
import cv2
from sudoku_solve import Sudoku_solver
from PIL import Image
from helper_number_page import valid_board


# Load the pre-trained digit classifier on CPU and switch it to eval mode.
classifier = torch.load('digit_classifier.h5',map_location ='cpu')
classifier.eval()


frameWidth = 960
frameHeight = 720

cap = cv2.VideoCapture(0)
# Maximum processing rate; frames arriving faster than this are skipped.
frame_rate = 60

# width is id number 3, height is id 4
cap.set(3, frameWidth)
cap.set(4, frameHeight)

# change brightness
cap.set(10, 150)
prev = 0

while cap.isOpened():
    time_elapsed = time.time() - prev
    success, img = cap.read()
    # Throttle: only process a frame when enough time has passed.
    if time_elapsed > 1. / frame_rate:
        prev = time.time()
        final_img = img.copy()
        to_process_img = img.copy()
        # Processing pipeline: threshold, then locate the puzzle contour.
        thresholded_img = preprocess(to_process_img) # Gray-scale img
        corners_img, corners_list, org_img = find_contours(thresholded_img, to_process_img)

        # corners_list is empty when no square-ish puzzle was detected.
        if corners_list:
            # Warp the puzzle to a top-down square view.
            warped, matrix = warp_image(corners_list, corners_img)
            # Threshold warped img
            warped_processed = preprocess(warped) # warped_processed is gray-scaled img

            # Extract grid lines along both axes.
            horizontal = grid_line_helper(warped_processed, shape_location=0)
            vertical = grid_line_helper(warped_processed, shape_location=1)

            # Create a mask of the grid so it can be removed from the image.
            grid_mask = create_grid_mask(horizontal, vertical)
            # Resize will get better result ??
            grid_mask = cv2.resize(grid_mask,(600,600), cv2.INTER_AREA)
            # Keep only the digits: AND the warped image with the grid mask.
            number_img = cv2.bitwise_and(cv2.resize(warped_processed, (600,600), cv2.INTER_AREA), grid_mask)
            # number_img = cv2.bitwise_and(warped_processed, grid_mask)
            # Split into 81 cell images.
            squares = split_squares(number_img)
            cleaned_squares = clean_square_all_images(squares)

            # Resize to 28x28 and scale pixels to [0, 1] for the classifier.
            resized_list = resize_square(cleaned_squares)
            norm_resized = normalize(resized_list)

            # Recognize digits and build the board (blank cells become 0).
            rec_str = recognize_digits(classifier, norm_resized, org_img)
            board = convert_str_to_board(rec_str)

            # Solve by backtracking; keep the unsolved copy so only newly
            # filled cells are drawn later.
            unsolved_board = board.copy()
            sudoku = Sudoku_solver(board, 9)
            start_time = time.time()
            sudoku.solve()
            solved_board = sudoku.board

            # Draw the solution on the warped view, then project it back
            # into the original camera perspective.
            _, warp_with_nums = draw_digits_on_warped(warped, solved_board, unsolved_board)
            final_img = unwarp_image(warp_with_nums, corners_img, corners_list, time.time() - start_time)
            cv2.imshow("Result", final_img)
            # Hold a correctly solved frame on screen for a second.
            if valid_board(solved_board):
                cv2.waitKey(1000)
        else:
            # No puzzle found: show the raw frame.
            cv2.imshow("Result", final_img)
        # Press 'q' to quit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
cv2.destroyAllWindows()
cap.release()
|
sudoku_solve.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Sudoku_solver():
    """
    Backtracking Sudoku solver.

    The board is a 9x9 nested sequence where 0 marks an empty cell;
    solve() fills it in place and returns True when a solution exists.
    """

    def __init__(self, board, size):
        self.board = board
        self.size = size

    def print_board(self):
        """
        Pretty-print the board with 3x3 box separators.
        """
        for row_idx in range(len(self.board)):
            if row_idx % 3 == 0 and row_idx != 0:
                print("- - - - - - - - - - - - - ")

            for col_idx in range(len(self.board[0])):
                if col_idx % 3 == 0 and col_idx != 0:
                    print(" | ", end="")

                if col_idx == 8:
                    print(self.board[row_idx][col_idx])
                else:
                    print(str(self.board[row_idx][col_idx]) + " ", end="")

    def valid(self, num, pos):
        """
        Return True if placing num at pos keeps its row, column and box valid.

        :param num: candidate digit 1-9
        :param pos: (row, col) tuple of the target cell
        """
        row, col = pos

        # Row: num may not already appear elsewhere in this row.
        if any(self.board[row][c] == num and c != col
               for c in range(len(self.board[0]))):
            return False

        # Column: num may not already appear elsewhere in this column.
        if any(self.board[r][col] == num and r != row
               for r in range(len(self.board))):
            return False

        # 3x3 box containing pos.
        box_row = 3 * (row // 3)
        box_col = 3 * (col // 3)
        for r in range(box_row, box_row + 3):
            for c in range(box_col, box_col + 3):
                if self.board[r][c] == num and (r, c) != pos:
                    return False

        return True

    def find_empty_cell(self):
        """
        Return (row, col) of the first empty cell, or None when the board is full.
        """
        for r in range(len(self.board)):
            for c in range(len(self.board[0])):
                if self.board[r][c] == 0:
                    return (r, c)
        return None

    def solve(self):
        """
        Fill the board in place by backtracking; return True on success.
        """
        cell = self.find_empty_cell()
        # Base case: no empty cell left, the board is complete.
        if cell is None:
            return True

        row, col = cell
        for candidate in range(1, 10):
            if not self.valid(candidate, (row, col)):
                continue
            self.board[row][col] = candidate
            if self.solve():
                return True
            # Dead end: undo and try the next candidate.
            self.board[row][col] = 0
        return False
|
threshold.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def preprocess(img):
    """
    Turn a BGR (or already grayscale) image into a clean binary mask.

    :param img: numpy image, either (h, w, c) colour or (h, w) grayscale
    :return: inverted adaptive-threshold image where grid lines and digits
             are white on black, with small noise removed
    """
    # Accept both colour and grayscale input. The previous check indexed
    # shape[2] unconditionally, which raised IndexError on 2-D images and
    # left gray_img undefined for single-channel input.
    if img.ndim == 3 and img.shape[2] > 1:
        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    elif img.ndim == 3:
        gray_img = img[:, :, 0]  # single channel stored as (h, w, 1)
    else:
        gray_img = img
    # Gaussian blur suppresses sensor noise before thresholding.
    blured = cv2.GaussianBlur(gray_img, (9,9), 0)
    # Adaptive threshold copes with uneven lighting across the page.
    thresh = cv2.adaptiveThreshold(blured, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
    # Invert so that the grid lines and text are white, the rest black.
    inverted = cv2.bitwise_not(thresh, 0)
    morphy_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2,2))
    # Opening morphology removes small white specks.
    morph = cv2.morphologyEx(inverted, cv2.MORPH_OPEN, morphy_kernel)
    # Dilation thickens the remaining strokes slightly.
    result = cv2.dilate(morph, morphy_kernel, iterations=1)
    return result
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if __name__ == "__main__":
    # Quick manual check: threshold a sample image and display the original.
    img = "testimg\sudoku_real_4.jpeg"
    img = cv2.imread(img)
    # NOTE(review): 'processed' is computed but never shown -- presumably
    # the imshow below was meant to display it; confirm intent.
    processed = preprocess(img)
    cv2.imshow("img", cv2.resize(img, (600,600), cv2.INTER_AREA))
    cv2.waitKey(0)
|
utils.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import operator
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def find_corners(polygon, limit_func, compare_func):
    """
    Locate one corner of a contour polygon.

    Every point is scored with compare_func applied to its (x, y)
    coordinates, and the extreme point is picked with limit_func:

        top-left:  min of (x + y)
        top-right: max of (x - y)
        bot-left:  min of (x - y)
        bot-right: max of (x + y)

    (The origin (0, 0) is the image's top-left.)

    :param polygon: contour array of shape (N, 1, 2)
    :param limit_func: min or max
    :param compare_func: np.add or np.subtract
    :return: (x, y) of the selected corner
    """
    scores = [compare_func(point[0][0], point[0][1]) for point in polygon]
    best_idx, _ = limit_func(enumerate(scores), key=operator.itemgetter(1))
    return polygon[best_idx][0][0], polygon[best_idx][0][1]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def draw_circle_at_corners(original, ptr):
    """
    Mark a corner point with a small filled green circle (modifies `original` in place).

    :param original: image to draw on
    :param ptr: (x, y) centre of the circle
    """
    # Radius 5, BGR green, filled.
    cv2.circle(original, ptr, 5, (0,255,0), cv2.FILLED)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def grid_line_helper(img, shape_location, length = 10):
    """
    Extract the grid lines running along one axis of a binary image.

    shape_location selects which image dimension sizes the morphology
    kernel: 0 uses the row count (kernel (size, 1)), 1 uses the column
    count (kernel (1, size)). An erode/dilate pair keeps only strokes at
    least `size` pixels long in that direction.
    NOTE(review): callers disagree on which value means "horizontal" --
    verify against usage before relying on the label.

    :param img: binary image
    :param shape_location: 0 or 1, axis whose extent sizes the kernel
    :param length: divisor controlling the minimum kept line length
    :return: image containing only the extracted lines
    """
    work = img.copy()
    extent = work.shape[shape_location]

    # Minimum stroke length that survives the morphology below.
    size = extent // length

    # Build a 1-pixel-thick rectangular kernel along the chosen axis.
    if shape_location == 0:
        line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (size, 1))
    else:
        line_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, size))

    # Erosion removes everything shorter than the kernel; dilation restores
    # the surviving lines to their original thickness.
    work = cv2.erode(work, line_kernel)
    work = cv2.dilate(work, line_kernel)

    return work
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def draw_line(img, lines):
    """
    Draw every (rho, theta) line from cv2.HoughLines onto a copy of img.

    :param img: image to draw on (not modified)
    :param lines: result of cv2.HoughLines -- an (N, 1, 2) array or None
    :return: copy of img with the lines drawn in white
    """
    clone_img = img.copy()
    # HoughLines returns None when no line is detected; nothing to draw.
    # Previously np.squeeze(None) raised here.
    if lines is None:
        return clone_img
    # Reshape to (N, 2) instead of squeezing: np.squeeze on a single line
    # collapses to shape (2,) and would break the unpacking below.
    for rho, theta in np.asarray(lines).reshape(-1, 2):
        a = np.cos(theta)
        b = np.sin(theta)
        x0 = a*rho
        y0 = b*rho
        # Extend the line 1000 px both ways so it crosses the whole image.
        x1 = int(x0 + 1000 * (-b))
        y1 = int(y0 + 1000 * a)
        x2 = int(x0 - 1000 * (-b))
        y2 = int(y0 - 1000 * a)
        # Draw line every loop
        cv2.line(clone_img, (x1,y1), (x2,y2), (255,255,255), 4)
    return clone_img
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def clean_square_helper(img):
    """
    Clean and centre the content of one Sudoku cell.

    :param img: binary image of a single cell
    :return: (cleaned image, True) when the cell holds a digit,
             (black image, False) otherwise
    """
    # Almost entirely black: the cell is empty.
    if np.isclose(img, 0).sum() / (img.shape[0] * img.shape[1]) >= 0.96:
        return np.zeros_like(img), False

    # if there is very little white in the region around the center, this means we got an edge accidently
    height, width = img.shape
    mid = width // 2
    if np.isclose(img[:, int(mid - width * 0.38):int(mid + width * 0.38)], 0).sum() / (2 * width * 0.38 * height) >= 0.98:
        return np.zeros_like(img), False

    # Centre the digit: move the largest contour's bounding box to the middle.
    contours, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        # No contour found (e.g. faint speckle passed the density checks):
        # previously this raised IndexError on contours[0].
        return np.zeros_like(img), False
    contours = sorted(contours, key=cv2.contourArea, reverse=True)
    x, y, w, h = cv2.boundingRect(contours[0])

    start_x = (width - w) // 2
    start_y = (height - h) // 2
    new_img = np.zeros_like(img)
    new_img[start_y:start_y + h, start_x:start_x + w] = img[y:y + h, x:x + w]

    return new_img, True
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def resize_square(clean_square_list):
    """
    Resize every cleaned cell to 28x28 for the digit classifier.

    :param clean_square_list: list of cell images
    :return: list of 28x28 images
    """
    return [cv2.resize(cell, (28,28), interpolation=cv2.INTER_AREA)
            for cell in clean_square_list]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def resize_square32(clean_square_list):
    """
    Resize every cleaned cell to 32x32 for the TensorFlow classifier.

    :param clean_square_list: list of cell images
    :return: list of 32x32 images
    """
    return [cv2.resize(cell, (32,32), interpolation=cv2.INTER_AREA)
            for cell in clean_square_list]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def classify_one_digit(model, resize_square, org_image):
    """
    Classify a single 28x28 cell image.

    A cell is treated as blank when it has fewer non-background pixels than
    a threshold (looser for large source images), in which case "0" is
    returned without invoking the model.

    :param model: PyTorch classifier accepting a (1, 1, 28, 28) tensor
    :param resize_square: 28x28 numpy array, already normalized
    :param org_image: the original frame, used only to pick the threshold
    :return: predicted digit as a string ("0" for a blank cell)
    """
    # Threshold on non-background pixel count. The previous condition also
    # indexed org_image.shape[2], which crashed on 2-D grayscale frames;
    # only the spatial dimensions matter here.
    if max(org_image.shape[0], org_image.shape[1]) > 600:
        threshold = 40
    else:
        threshold = 60

    # Blank square: too few pixels differ from the background value.
    if (resize_square != resize_square.min()).sum() < threshold:
        return str(0)

    model.eval()
    # Shape (1, 1, 28, 28): batch and channel dims expected by the model.
    iin = torch.Tensor(resize_square).unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        out = model(iin)
        # Index of the highest logit is the predicted digit.
        _, index = torch.max(out, 1)

    pred_digit = index.item()

    return str(pred_digit)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def normalize(resized_list):
    """
    Scale pixel values from [0, 255] down to [0, 1] for recognition.

    :param resized_list: list of images
    :return: list of scaled images
    """
    scaled = []
    for cell in resized_list:
        scaled.append(cell / 255)
    return scaled
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def convert_str_to_board(string, step = 9):
    """
    Turn a flat digit string into a 2-D numpy board for sudoku solving.

    :param string: digit string read row by row (81 chars for a 9x9 board)
    :param step: row length, default 9
    :return: 2-D numpy array of ints
    """
    rows = [
        [int(ch) for ch in string[start:start + step]]
        for start in range(0, len(string), step)
    ]
    return np.array(rows)
|