# Spaces: Running
| import argparse | |
| from pathlib import Path | |
| from typing import Sequence, Union | |
| from PIL import Image | |
| import torch | |
| import torchvision | |
| import numpy as np | |
| import re | |
| import chess | |
| from fentoimage.board import BoardImage | |
| import image_to_fen.util as util | |
# Directory that holds the staged model artifact, resolved relative to this file.
STAGED_MODEL_DIRNAME = Path(__file__).resolve().parent.joinpath("artifacts", "image-to-fen")
# File name of the serialized model inside STAGED_MODEL_DIRNAME.
MODEL_FILE = "model.pt"
class ImageToFen:
    """Takes image of chess board and returns FEN string.

    The prediction is produced by a TorchScript model loaded with
    ``torch.jit.load``; its detections are de-duplicated with NMS and then
    mapped onto an 8x8 grid.
    """

    def __init__(self, model_path=None):
        """Load the TorchScript model.

        Args:
            model_path: Path to the serialized model. When ``None``, the
                staged artifact ``STAGED_MODEL_DIRNAME / MODEL_FILE`` is used.
        """
        if model_path is None:
            model_path = STAGED_MODEL_DIRNAME / MODEL_FILE
        self.model = torch.jit.load(model_path)

    def predict(self, image: Union[str, Path, Image.Image]) -> str:
        """Predict FEN string for image of chess board.

        Args:
            image: A PIL image, or a path/URI that ``util.read_image_pil``
                can open.

        Returns:
            The piece-placement string produced by ``boxes_labels_to_fen``
            (ranks are separated by ``-``, not ``/``).
        """
        if isinstance(image, Image.Image):
            # The path branch below loads the image as grayscale, so bring
            # directly supplied images to the same single-channel layout.
            # NOTE(review): assumes read_image_pil(grayscale=True) yields a
            # mode "L" image - confirm against image_to_fen.util.
            if image.mode != "L":
                image = image.convert("L")
        else:
            image = util.read_image_pil(image, grayscale=True)

        # The grid mapping downstream uses 25px squares, i.e. a 200x200 board.
        image = image.resize((200, 200))
        # PILToTensor yields uint8 values; scale them into [0, 1].
        tensor = torchvision.transforms.PILToTensor()(image) / 255

        # Inference only: no autograd bookkeeping is needed.
        with torch.no_grad():
            # NOTE(review): assumes the model returns a pair whose second
            # element is a list with one prediction dict per input image.
            pred = self.model([tensor])[1][0]

        nms_pred = apply_nms(pred, iou_thresh=0.2)
        return boxes_labels_to_fen(nms_pred['boxes'], nms_pred['labels'])
def apply_nms(orig_prediction, iou_thresh=0.3):
    """Filter a detection result with non-maximum suppression.

    Args:
        orig_prediction: Mapping with at least the keys ``'boxes'``,
            ``'scores'`` and ``'labels'``, each indexable by the tensor of
            indices returned from ``torchvision.ops.nms``.
        iou_thresh: IoU threshold passed to ``torchvision.ops.nms``.

    Returns:
        A new dict holding the same keys as ``orig_prediction`` with
        ``'boxes'``, ``'scores'`` and ``'labels'`` reduced to the kept
        detections. The input mapping is left unmodified.
    """
    # torchvision returns the indices of the bboxes to keep
    keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)

    # Shallow copy so the caller's prediction is not mutated in place.
    final_prediction = dict(orig_prediction)
    for key in ('boxes', 'scores', 'labels'):
        final_prediction[key] = orig_prediction[key][keep]
    return final_prediction
def boxes_labels_to_fen(boxes, labels, square_size=25):
    """Place detected pieces on an 8x8 grid and serialize it.

    Args:
        boxes: Tensor of boxes; only the first two values of each box are
            used, as the x and y position of the piece.
            NOTE(review): presumably shape (N, 4) with (x1, y1, x2, y2) -
            confirm against the model output.
        labels: Per-box labels. Label ``l`` is stored as class ``12 - l`` of
            the 13-class one-hot encoding (class 12 is the empty square).
        square_size: Edge length of one board square in pixels.

    Returns:
        The string produced by ``fen_from_onehot`` for the filled board.
    """
    # Snap every coordinate to the nearest grid line, measured in squares.
    cells = torch.round(boxes / square_size)
    eye = np.eye(13)
    one_hot = onehot_from_fen("8-8-8-8-8-8-8-8")
    for i, cell in enumerate(cells):
        col = int(cell[0])
        row = int(cell[1])
        # Check row and column separately: a flat-index test would let a
        # column of 8 wrap onto the next rank and a negative coordinate
        # index the board from its end.
        if not (0 <= col < 8 and 0 <= row < 8):
            continue
        # NOTE: a later box landing on the same square overwrites an
        # earlier one.
        one_hot[row * 8 + col] = eye[12 - int(labels[i])]
    return fen_from_onehot(one_hot)
def onehot_from_fen(fen):
    """Encode a piece-placement string as one one-hot row per square.

    Args:
        fen: Piece placement using the symbols ``prbnkqPRBNKQ`` and the
            digits ``1``-``8`` for runs of empty squares. Ranks may be
            separated by ``-`` or ``/``; separators are ignored.

    Returns:
        A float array of shape ``(squares, 13)``. Columns 0-11 follow the
        order of ``prbnkqPRBNKQ`` and column 12 marks an empty square. An
        empty string yields shape ``(0, 13)``.

    Raises:
        ValueError: If ``fen`` contains a character that is neither a piece
            symbol, a digit ``1``-``8`` nor a separator.
    """
    piece_symbols = 'prbnkqPRBNKQ'
    empty_class = 12

    class_ids = []
    for char in fen:
        if char in '-/':
            continue
        if char in '12345678':
            class_ids.extend([empty_class] * int(char))
        else:
            class_ids.append(piece_symbols.index(char))

    # One indexing operation builds all rows at once; the explicit integer
    # dtype keeps the (0, 13) result shape for an empty input.
    return np.eye(13)[np.asarray(class_ids, dtype=int)]
def fen_from_onehot(one_hot):
    """Turn 64 one-hot rows back into a ``-``-separated placement string.

    Args:
        one_hot: Array with one 13-class row per square, ordered rank by
            rank. Columns 0-11 follow ``prbnkqPRBNKQ``; column 12 is the
            empty square.

    Returns:
        Eight ranks joined by ``-``, with runs of empty squares written as
        their length.
    """
    piece_symbols = 'prbnkqPRBNKQ'

    ranks = []
    for rank in range(8):
        squares = []
        for file in range(8):
            # Position of the first entry equal to 1 in this square's row.
            cls = np.where(one_hot[rank * 8 + file] == 1)[0][0]
            squares.append(' ' if cls == 12 else piece_symbols[cls])
        ranks.append(''.join(squares))
    placement = '-'.join(ranks)

    # Collapse blank runs longest-first so shorter runs are not matched
    # inside longer ones.
    for run in range(8, 0, -1):
        placement = placement.replace(' ' * run, str(run))
    return placement
# Lazily created predictor shared by every fen_and_image call.
_DEFAULT_PREDICTOR = None


def _get_default_predictor():
    """Return the shared ImageToFen instance, loading the model on first use."""
    global _DEFAULT_PREDICTOR
    if _DEFAULT_PREDICTOR is None:
        _DEFAULT_PREDICTOR = ImageToFen()
    return _DEFAULT_PREDICTOR


def fen_and_image(input):
    """Predict the position shown in ``input`` and render it.

    The default model is loaded once and reused, instead of being read from
    disk again on every call.

    Args:
        input: Anything accepted by ``ImageToFen.predict``. The name shadows
            the builtin but is kept so keyword callers continue to work.

    Returns:
        A ``(fen, image)`` tuple: the prediction with ``-`` replaced by
        ``/``, and the result of ``BoardImage(fen).render()``.
    """
    output = _get_default_predictor().predict(input)
    fen = output.replace('-', '/')
    image = BoardImage(fen).render()
    return fen, image
def main():
    """Parse the command line, run one prediction and print the result."""
    cli = argparse.ArgumentParser(description="Predict FEN string for image of chess board.")
    cli.add_argument("image", type=Path, help="Path to image file.")
    cli.add_argument("--model-path", type=Path, help="Path to model file.")
    options = cli.parse_args()

    # Example input: image_to_fen/tests/support/boards/phpSrRLQ1.png
    predictor = ImageToFen(options.model_path)
    pred = predictor.predict(options.image)
    print(f"Prediction: {pred}")


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()