File size: 1,400 Bytes
2a1cdb8
b25b9cb
2a1cdb8
 
 
 
 
 
 
 
 
 
 
b25b9cb
 
2a1cdb8
 
b25b9cb
 
2a1cdb8
 
 
 
 
b25b9cb
 
2a1cdb8
b25b9cb
 
 
 
2a1cdb8
b25b9cb
 
2a1cdb8
b25b9cb
2a1cdb8
 
b25b9cb
2a1cdb8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torchvision.transforms as T
import torchvision
import numpy as np
import cv2
import ssl
import os
import pickle
from src.hybrid_model import SimpleCNN
from src import config

def load_data():
    """Fetch the MNIST training split and return its images and labels.

    Returns:
        A pair ``(images, labels)``: ``images`` is the dataset's ``data``
        attribute cast to float and divided by 255.0, ``labels`` is the
        dataset's ``targets`` attribute.
    """
    # NOTE(review): this replaces the default HTTPS context factory for the
    # whole process, so certificate verification stays disabled for later
    # HTTPS requests too, not only for the dataset download below.
    ssl._create_default_https_context = ssl._create_unverified_context

    mnist_train = torchvision.datasets.MNIST(
        config.DATA_DIR,
        train=True,
        download=True,
        transform=T.ToTensor(),
    )

    images = mnist_train.data.float() / 255.0
    labels = mnist_train.targets
    return images, labels

def preprocess_digit(img_data):
    """Turn an RGBA drawing into a centred 28x28 float tensor.

    The image is converted to grayscale, cropped to the bounding box of its
    non-zero pixels, scaled so its longer side is 20 pixels, and pasted into
    the middle of a black 28x28 canvas.

    Args:
        img_data: NumPy array that, after a cast to ``uint8``, is valid input
            for ``cv2.cvtColor(..., cv2.COLOR_RGBA2GRAY)`` (a 4-channel
            image; presumably drawing-canvas output, confirm with caller).

    Returns:
        A float ``torch`` tensor of shape (28, 28) holding the canvas divided
        by 255.0, or ``None`` when the grayscale pixel sum is below 1000
        (the drawing is treated as empty).
    """
    gray = cv2.cvtColor(img_data.astype(np.uint8), cv2.COLOR_RGBA2GRAY)
    if np.sum(gray) < 1000:
        return None

    # The sum check above guarantees at least one non-zero pixel, so the
    # bounding rectangle is well defined.
    x, y, w, h = cv2.boundingRect(cv2.findNonZero(gray))
    glyph = gray[y:y + h, x:x + w]

    scale = 20 / max(w, h)
    # Clamp each side to at least one pixel: for a very thin stroke (e.g.
    # 1 px wide, 30 px tall) the truncated short side would be 0, and
    # cv2.resize rejects a zero-sized target.
    target_w = max(1, int(w * scale))
    target_h = max(1, int(h * scale))
    resized = cv2.resize(glyph, (target_w, target_h), interpolation=cv2.INTER_AREA)

    canvas = np.zeros((28, 28), dtype=np.uint8)
    res_h, res_w = resized.shape[:2]
    top = (28 - res_h) // 2
    left = (28 - res_w) // 2
    canvas[top:top + res_h, left:left + res_w] = resized
    return torch.tensor(canvas).float() / 255.0

def load_models():
    """Load the saved SVD object and CNN weights from disk.

    Returns:
        ``(svd, cnn)`` where ``svd`` is the object unpickled from
        ``config.SVD_MODEL_PATH`` and ``cnn`` is a ``SimpleCNN`` whose state
        dict was loaded from ``config.CNN_MODEL_PATH`` onto the CPU and which
        has been switched to eval mode. ``(None, None)`` is returned when
        either file does not exist.
    """
    if not os.path.exists(config.SVD_MODEL_PATH) or not os.path.exists(config.CNN_MODEL_PATH):
        return None, None

    # NOTE(review): unpickling executes arbitrary code from the file; only
    # safe as long as the model file comes from a trusted source.
    with open(config.SVD_MODEL_PATH, "rb") as svd_file:
        svd_model = pickle.load(svd_file)

    network = SimpleCNN()
    state = torch.load(config.CNN_MODEL_PATH, map_location="cpu")
    network.load_state_dict(state)
    network.eval()

    return svd_model, network