Spaces:
Running
Running
| import torch | |
| import torchvision.transforms as T | |
| import torchvision | |
| import numpy as np | |
| import cv2 | |
| import ssl | |
| import os | |
| import pickle | |
| from src.hybrid_model import SimpleCNN | |
| from src import config | |
def load_data():
    """Download (if needed) and return the MNIST training split as tensors.

    Returns:
        tuple: ``(images, labels)`` where ``images`` is the dataset's raw
        ``data`` tensor cast to float and divided by 255.0, and ``labels`` is
        the dataset's ``targets`` tensor.
    """
    # Certificate verification is disabled only for the duration of the
    # download (presumably to work around certificate errors when fetching
    # MNIST). The original code replaced the process-wide default and never
    # restored it, silently disabling TLS verification for every later HTTPS
    # request in the process.
    previous_context = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        ds = torchvision.datasets.MNIST(
            config.DATA_DIR, train=True, download=True, transform=T.ToTensor()
        )
    finally:
        ssl._create_default_https_context = previous_context
    # `transform` is only applied via __getitem__; here the raw `data` tensor
    # is read directly and scaled manually.
    return ds.data.float() / 255.0, ds.targets
def preprocess_digit(img_data):
    """Convert a drawn digit image into a centred 28x28 MNIST-style tensor.

    The image is converted to grayscale. It is then cropped to the bounding
    box of its non-zero pixels and resized so that its longer side is 20 px.
    Finally, it is pasted into the centre of a 28x28 black canvas.

    Args:
        img_data: Image array. RGBA (H x W x 4) is the originally supported
            layout; 2-D grayscale, H x W x 1 and RGB (H x W x 3) are also
            accepted. Values are cast to uint8, so presumably in 0..255.

    Returns:
        A float ``torch.Tensor`` of shape (28, 28) with values in [0, 1], or
        ``None`` when the image is (nearly) blank, i.e. its total grayscale
        intensity is below 1000.
    """
    img = np.asarray(img_data).astype(np.uint8)

    # Reduce to a single grayscale channel.
    if img.ndim == 2:
        gray = img
    elif img.ndim == 3 and img.shape[2] == 1:
        gray = np.ascontiguousarray(img[:, :, 0])
    elif img.ndim == 3 and img.shape[2] == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    else:
        # Original behaviour: treat input as RGBA (cv2 raises on other shapes).
        gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)

    # Treat an almost-empty canvas as "no digit drawn".
    if np.sum(gray) < 1000:
        return None

    # Crop to the bounding box of the strokes.
    coords = cv2.findNonZero(gray)
    x, y, w, h = cv2.boundingRect(coords)
    crop = gray[y:y + h, x:x + w]

    # Scale the longer side to 20 px, preserving aspect ratio. Each side is
    # clamped to at least 1 px: a very thin stroke (e.g. w=1, h=100) would
    # otherwise truncate to a zero-sized target, which cv2.resize rejects.
    scale = 20 / max(w, h)
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    resized = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # Centre the resized digit on a 28x28 black canvas.
    canvas = np.zeros((28, 28), dtype=np.uint8)
    top = (28 - resized.shape[0]) // 2
    left = (28 - resized.shape[1]) // 2
    canvas[top:top + resized.shape[0], left:left + resized.shape[1]] = resized
    return torch.tensor(canvas).float() / 255.0
def load_models():
    """Load the saved SVD model and the SimpleCNN weights from disk.

    Returns:
        tuple: ``(svd, cnn)``.
            - ``svd`` is the object unpickled from ``config.SVD_MODEL_PATH``.
            - ``cnn`` is a ``SimpleCNN`` whose state dict is read from
              ``config.CNN_MODEL_PATH`` (mapped to CPU) and which has been
              switched to eval mode.
        Returns ``(None, None)`` if either file does not exist.
    """
    required = (config.SVD_MODEL_PATH, config.CNN_MODEL_PATH)
    if not all(os.path.exists(path) for path in required):
        return None, None

    # NOTE(review): pickle.load (and torch.load without weights_only=True on
    # older torch versions) can execute arbitrary code embedded in the file;
    # only point these paths at trusted artifacts.
    with open(config.SVD_MODEL_PATH, "rb") as fh:
        svd_model = pickle.load(fh)

    net = SimpleCNN()
    weights = torch.load(config.CNN_MODEL_PATH, map_location="cpu")
    net.load_state_dict(weights)
    net.eval()  # inference mode for any dropout / batch-norm layers

    return svd_model, net