Spaces:
Running
Running
File size: 1,400 Bytes
2a1cdb8 b25b9cb 2a1cdb8 b25b9cb 2a1cdb8 b25b9cb 2a1cdb8 b25b9cb 2a1cdb8 b25b9cb 2a1cdb8 b25b9cb 2a1cdb8 b25b9cb 2a1cdb8 b25b9cb 2a1cdb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import torch
import torchvision.transforms as T
import torchvision
import numpy as np
import cv2
import ssl
import os
import pickle
from src.hybrid_model import SimpleCNN
from src import config
def load_data():
    """Load the MNIST training split (downloading it if necessary).

    Returns:
        (images, labels): ``images`` is ``ds.data`` converted to float and
        divided by 255.0; ``labels`` is ``ds.targets``.
    """
    # The download is done with certificate verification disabled (as the
    # original code did), but only for the duration of the download: the
    # previous HTTPS context factory is restored afterwards so the rest of
    # the process keeps verifying certificates.
    previous_context_factory = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        ds = torchvision.datasets.MNIST(
            config.DATA_DIR, train=True, download=True, transform=T.ToTensor()
        )
    finally:
        ssl._create_default_https_context = previous_context_factory
    # NOTE: the transform only applies to indexed access (ds[i]); ds.data is
    # read directly here, hence the manual float conversion and scaling.
    return ds.data.float() / 255.0, ds.targets
def preprocess_digit(img_data):
    """Convert a drawn image into a centred 28x28 tensor.

    The ink is cropped to its bounding box, resized so the longer side is
    20 px (aspect ratio preserved), and pasted into the centre of a 28x28
    black canvas.

    Args:
        img_data: image array. A 4-channel image is treated as RGBA (as in
            the original code), a 3-channel image as RGB, and a 2-D or
            single-channel image as grayscale. Values are cast to uint8, so
            they are expected to lie in [0, 255].

    Returns:
        A float tensor of shape (28, 28) with values divided by 255.0, or
        None when the grayscale pixel sum is below 1000 (too little ink).
    """
    arr = np.asarray(img_data).astype(np.uint8)
    if arr.ndim == 3 and arr.shape[2] == 1:
        # (H, W, 1) -> (H, W); reshape keeps the array contiguous for OpenCV.
        arr = arr.reshape(arr.shape[:2])
    if arr.ndim == 2:
        img = arr
    elif arr.shape[2] == 3:
        img = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
    else:
        img = cv2.cvtColor(arr, cv2.COLOR_RGBA2GRAY)

    if np.sum(img) < 1000:
        return None

    x, y, w, h = cv2.boundingRect(cv2.findNonZero(img))
    img_crop = img[y:y + h, x:x + w]

    f = 20 / max(w, h)
    # Round instead of truncating and clamp to >= 1: a thin stroke such as
    # w=2, h=100 would otherwise produce a 0-pixel width, which cv2.resize
    # rejects, and float error could turn 19.999... into 19.
    new_w = max(1, round(w * f))
    new_h = max(1, round(h * f))
    img_res = cv2.resize(img_crop, (new_w, new_h), interpolation=cv2.INTER_AREA)

    final = np.zeros((28, 28), dtype=np.uint8)
    rows, cols = img_res.shape[:2]
    py, px = (28 - rows) // 2, (28 - cols) // 2
    final[py:py + rows, px:px + cols] = img_res
    return torch.tensor(final).float() / 255.0
def load_models():
    """Load the pickled SVD model and the SimpleCNN weights.

    Returns:
        ``(svd, cnn)`` when both ``config.SVD_MODEL_PATH`` and
        ``config.CNN_MODEL_PATH`` exist, otherwise ``(None, None)``. The CNN
        state dict is loaded with ``map_location="cpu"`` and the model is
        switched to eval mode before being returned.
    """
    required = (config.SVD_MODEL_PATH, config.CNN_MODEL_PATH)
    if not all(os.path.exists(path) for path in required):
        return None, None

    # NOTE(review): pickle.load executes arbitrary code from the file; this
    # assumes the model files are produced locally/trusted — confirm.
    with open(config.SVD_MODEL_PATH, "rb") as svd_file:
        svd_model = pickle.load(svd_file)

    cnn_model = SimpleCNN()
    state_dict = torch.load(config.CNN_MODEL_PATH, map_location="cpu")
    cnn_model.load_state_dict(state_dict)
    cnn_model.eval()

    return svd_model, cnn_model
|