|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import os |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Root folder with the labeled training tiles, split by piece color.
TRAIN_DIR = "/mnt/c/Users/krzsa/IdeaProjects/Agents-Course-Assignment/train-data"
TRAIN_DIR_BLACK = f"{TRAIN_DIR}/black"
TRAIN_DIR_WHITE = f"{TRAIN_DIR}/white"
TRAIN_DIR_EMPTY = f"{TRAIN_DIR}/empty"

# (path, label) pairs: two sample tiles per class. Labels follow FEN
# convention — '1' for an empty square, lowercase for black pieces,
# uppercase for white pieces.
TRAIN_DATA = (
    [(f"{TRAIN_DIR_EMPTY}/1_{i:03d}.png", "1") for i in (1, 2)]
    + [(f"{TRAIN_DIR_BLACK}/{piece}_{i:03d}.png", piece) for piece in "bknpqr" for i in (1, 2)]
    + [(f"{TRAIN_DIR_WHITE}/{piece}_{i:03d}.png", piece) for piece in "BKNPQR" for i in (1, 2)]
)

# No separate held-out set yet — evaluation reuses the training images.
TEST_DATA = TRAIN_DATA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CNNModel(nn.Module):
    """Small CNN classifier for 32x32 grayscale chess-tile images.

    Architecture: two conv+max-pool stages followed by two fully connected
    layers with dropout; the final layer emits 13 logits (12 piece classes
    plus the empty square).
    """

    def __init__(self, _model_name, _model_dir):
        """Build the layers and load pre-trained weights when available.

        :param _model_name: file name of a saved state dict inside _model_dir
        :param _model_dir:  directory searched for the pre-trained weights
        """
        super(CNNModel, self).__init__()
        self.name = _model_name
        self.model_dir = _model_dir
        print("***KS*** Model: Creating layers")

        # 1x32x32 -> 32x32x32 (kernel 5 with padding 2 keeps the spatial size)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2)

        # 32x16x16 -> 64x16x16 (input is halved by the first max-pool)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)

        # After two 2x2 max-pools the 32x32 input is reduced to 8x8 with 64 channels.
        self.fc1 = nn.Linear(8 * 8 * 64, 1024)
        self.dropout = nn.Dropout(p=0.5)

        # 13 output classes: '1KQRBNPkqrbnp' (empty square + 12 piece types).
        self.fc2 = nn.Linear(1024, 13)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._initialize_weights()

    def _initialize_weights(self):
        """Load weights from disk when the file exists, else randomly initialize.

        Either way the model is moved to ``self.device`` before returning.
        """
        model_name = os.path.join(self.model_dir, self.name)
        print(f"***KS*** Checking pre-trained model: '{model_name}'")
        if os.path.exists(model_name):
            print(f"***KS*** Model '{model_name}' exists, loading weights ...")
            self.load_state_dict(torch.load(model_name, map_location=self.device))
            print("*** KS *** Model loaded.")
        else:
            print(f"*** KS *** Model file '{model_name}' not found. Initializing weights with random values")

            # Truncated-normal init with small positive biases (mirrors the
            # classic TensorFlow tutorial initialization this net follows).
            nn.init.trunc_normal_(self.conv1.weight, std=0.1)
            nn.init.constant_(self.conv1.bias, 0.1)

            nn.init.trunc_normal_(self.conv2.weight, std=0.1)
            nn.init.constant_(self.conv2.bias, 0.1)

            nn.init.trunc_normal_(self.fc1.weight, std=0.1)
            nn.init.constant_(self.fc1.bias, 0.1)

            nn.init.trunc_normal_(self.fc2.weight, std=0.1)
            nn.init.constant_(self.fc2.bias, 0.1)

        self.to(self.device)

    def save_weights(self):
        """Persist the current state dict to ``../saved_models/<name>.pth``."""
        print(f"***KS*** Saving model ...")

        # BUG FIX: the directory that gets created must be the one the
        # checkpoint is written to ('../saved_models', not 'saved_models');
        # otherwise torch.save fails whenever '../saved_models' is missing.
        # NOTE(review): this save path differs from the load path used by
        # _initialize_weights (model_dir/name, no '.pth' suffix) — confirm
        # which location is intended for round-tripping weights.
        save_dir = '../saved_models'
        os.makedirs(save_dir, exist_ok=True)
        model_save_path = f"{save_dir}/{self.name}.pth"
        torch.save(self.state_dict(), model_save_path)
        print(f'*** KS *** Model saved in file: {model_save_path}')

    def forward(self, x):
        """Run the forward pass; expects x of shape (N, 1, 32, 32), returns (N, 13) logits."""
        print("***KS*** Model: Executing forward calculations")

        print(f"***KS*** [0] {x.shape}")

        x = self.conv1(x)
        print(f"***KS*** [1] {x.shape}")

        x = F.relu(x)
        print(f"***KS*** [2] {x.shape}")

        # 32x32 -> 16x16
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        print(f"***KS*** [3] {x.shape}")

        x = F.relu(self.conv2(x))
        print(f"***KS*** [4] {x.shape}")

        # 16x16 -> 8x8
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        print(f"***KS*** [5] {x.shape}")

        # Flatten the 64x8x8 feature maps for the fully connected layers.
        x = x.view(-1, 8 * 8 * 64)
        print(f"***KS*** [6] {x.shape}")

        x = self.fc1(x)
        print(f"***KS*** [7] {x.shape}")

        x = F.relu(x)
        print(f"***KS*** [8] {x.shape}")

        # Dropout is only active in train() mode; a no-op under eval().
        x = self.dropout(x)
        print(f"***KS*** [9] {x.shape}")

        # Raw logits — CrossEntropyLoss applies log-softmax itself.
        x = self.fc2(x)
        print(f"***KS*** [10] {x.shape}")

        return x

    def get_device(self):
        """Return the torch.device this model was placed on."""
        return self.device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChessDataset(Dataset):
    """Labeled dataset of chess-tile images loaded from disk as 32x32 grayscale."""

    # Class alphabet; a label's position in this string is its class index.
    CHESS_PIECES = '1KQRBNPkqrbnp'

    def __init__(self, image_train_date):
        """Load every (path, label) pair into pre-allocated numpy arrays.

        :param image_train_date: iterable of (image_file_path, piece_label) tuples
        """
        self.num_images = len(image_train_date)

        self.images = np.zeros([self.num_images, 32, 32], dtype=np.uint8)
        self.labels = np.zeros([self.num_images], dtype=np.int64)

        for position, (file_path, piece) in enumerate(image_train_date):
            with Image.open(file_path) as img:
                grayscale = img.convert('L')
                self.images[position, :, :] = np.array(grayscale, dtype=np.uint8)
            self.labels[position] = self.__get_piece_index_from_label__(piece)

        print("***KS*** Done loading training data")

    def __get_piece_index_from_label__(self, label) -> int:
        """Map a piece label to its class index (-1 when unknown)."""
        return self.CHESS_PIECES.find(label)

    def get_piece_label(self, idx) -> str:
        """Map a class index back to its piece label."""
        return self.CHESS_PIECES[idx]

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        # Scale uint8 pixels into [0, 1] and add the single channel axis.
        scaled = self.images[idx].astype('float32') / 255.0
        with_channel = scaled[np.newaxis, ...]
        return torch.tensor(with_channel, dtype=torch.float32), self.labels[idx]
|
|
|
|
|
|
|
|
class ChessImagesDataset(Dataset):
    """Wraps already-loaded tile images for inference; labels are unknown (empty strings)."""

    # Class alphabet; a class index maps to its position in this string.
    CHESS_PIECES = '1KQRBNPkqrbnp'

    def __init__(self, images):
        """:param images: sequence of 32x32 uint8 arrays, one per board square."""
        self.images = images
        self.num_images = len(images)

    def __len__(self):
        return self.num_images

    def get_piece_label(self, idx) -> str:
        """Map a class index back to its piece label."""
        return self.CHESS_PIECES[idx]

    def __getitem__(self, idx):
        # Scale uint8 pixels into [0, 1] and add the single channel axis;
        # there is no ground truth, so the label is an empty string.
        normalized = self.images[idx].astype('float32') / 255.0
        with_channel = normalized[np.newaxis, ...]
        return torch.tensor(with_channel, dtype=torch.float32), ""
|
|
|
|
|
|
|
|
class ChessPiecesRecognition:
    """End-to-end trainer, evaluator and classifier for chess-tile images."""

    def __init__(self, _model_name, _model_dir):
        """Create the CNN (loading weights if present) and the data loaders.

        :param _model_name: weight-file name passed through to CNNModel
        :param _model_dir:  directory searched for pre-trained weights
        """
        print(f"***KS*** Chess pieces recognition initializing ...")
        self.model = CNNModel(_model_name, _model_dir)
        self.__load_train_data__()

    def __load_train_data__(self):
        """Build DataLoaders over the module-level TRAIN_DATA / TEST_DATA file lists."""
        print(f"*** KS *** loading training data")

        print(f"Loading {len(TRAIN_DATA)} Training tiles", end='')
        train_dataset = ChessDataset(TRAIN_DATA)

        print(f"\n*** KS *** Loading {len(TEST_DATA)} Testing tiles", end='')
        test_dataset = ChessDataset(TEST_DATA)
        print()

        batch_size = 64

        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size)

    def train(self):
        """Train the CNN with Adam + cross-entropy, then save the weights."""
        print(f"***KS*** Training chess pieces recognition")

        # CrossEntropyLoss expects the raw logits that forward() returns.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=1e-4)

        do_training = True
        epochs = 100

        if do_training:
            self.model.train()
            print(f"*** KS *** Starting training for {epochs} epochs...")
            for epoch in range(epochs):
                running_loss = 0.0
                print(f"***KS*** Epoch: {epoch}")
                for i, (inputs, labels) in enumerate(self.train_loader):
                    inputs = inputs.to(self.model.get_device())
                    labels = labels.to(self.model.get_device())

                    optimizer.zero_grad()

                    outputs = self.model(inputs)
                    loss = criterion(outputs, labels)

                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    # Report the mean loss over each window of 10 steps.
                    if (i + 1) % 10 == 0:
                        print(f'*** KS *** Epoch [{epoch +1}/{epochs}], Step [{i +1}/{len(self.train_loader)}], '
                              f'Loss: {running_loss / 10:.4f}')
                        running_loss = 0.0

            print('Finished Training')

            self.model.save_weights()

    def eval(self):
        """Measure and print classification accuracy on the test loader."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in self.test_loader:
                inputs = inputs.to(self.model.get_device())
                labels = labels.to(self.model.get_device())

                outputs = self.model(inputs)
                print(f"***KS*** Got model outputs: \nshape: {outputs.shape}\n{outputs}")

                labels_detected = np.argmax(outputs.cpu(), axis=1)
                print(f"***KS*** Got labels idx detected: \nshape: {labels_detected.shape}\n{labels_detected}")

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard against an empty test loader to avoid ZeroDivisionError.
        test_accuracy = correct / total if total else 0.0
        print(f'Accuracy on test set: {test_accuracy * 100:.2f}%\n')

    def classify_pieces(self, images):
        """Classify tile images and return their labels concatenated into one string.

        :param images: sequence of 32x32 uint8 arrays (one per board square)
        :return: string with one piece character per input image, in order
        """
        dataset = ChessImagesDataset(images)
        loader = DataLoader(dataset, batch_size=64)

        labels_str = ""
        self.model.eval()
        with torch.no_grad():
            for inputs, _ in loader:
                inputs = inputs.to(self.model.get_device())

                outputs = self.model(inputs)
                print(f"***KS*** Got model outputs: \nshape: {outputs.shape}")

                labels_detected = np.argmax(outputs.cpu(), axis=1)
                print(f"***KS*** Got labels idx detected: \nshape: {labels_detected.shape}\n{labels_detected}")

                labels = [dataset.get_piece_label(ix) for ix in labels_detected]
                # BUG FIX: accumulate across batches — previously each batch
                # overwrote labels_str, silently dropping all but the last
                # (at most 64) results when more than one batch was produced.
                labels_str += ''.join(labels)

        return labels_str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|