import os
import pickle
import ssl

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as T

from src import config
from src.hybrid_model import SimpleCNN

# Side length of the square canvas produced by preprocess_digit.
_CANVAS_SIZE = 28
# Longest side (in pixels) of the digit's bounding box after rescaling; the
# rest of the canvas is left blank as a margin around the digit.
_DIGIT_BOX = 20
# Minimum summed grayscale intensity for an image to count as a drawn digit.
_MIN_INK = 1000


def load_data():
    """Download (if needed) and load the MNIST training split.

    Returns:
        An ``(images, labels)`` pair: ``images`` is the dataset's ``data``
        tensor converted to float and divided by 255 (so 0-255 pixels map to
        ``[0, 1]``); ``labels`` is the dataset's ``targets`` tensor.
    """
    # The download runs with TLS certificate verification disabled (kept from
    # the original code, presumably to work around a broken local certificate
    # store). That override is process-wide, so put the previous factory back
    # afterwards instead of leaving every later HTTPS request unverified.
    previous_context = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        ds = torchvision.datasets.MNIST(
            config.DATA_DIR, train=True, download=True, transform=T.ToTensor()
        )
    finally:
        ssl._create_default_https_context = previous_context
    # ``ds.data`` is read directly (the ``transform`` is not applied to it),
    # so the scaling is done here; presumably ``data`` is uint8 (N, 28, 28).
    return ds.data.float() / 255.0, ds.targets


def _to_grayscale(img):
    """Return a single-channel view of ``img`` (2-D, RGB or RGBA uint8 array)."""
    if img.ndim == 2:
        # Already grayscale.
        return img
    if img.ndim == 3 and img.shape[2] == 3:
        # Channel order assumed RGB, consistent with the RGBA path below.
        return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # Original behaviour: treat input as RGBA. Anything OpenCV cannot convert
    # raises cv2.error here, exactly as before.
    return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)


def preprocess_digit(img_data):
    """Turn a drawn digit image into a centred 28x28 float tensor.

    The digit is cropped to the bounding box of its non-zero pixels, rescaled
    so its longest side is ``_DIGIT_BOX`` pixels (aspect ratio preserved) and
    pasted into the middle of a blank ``_CANVAS_SIZE`` x ``_CANVAS_SIZE``
    canvas.

    Args:
        img_data: array-like image. RGBA ``(H, W, 4)`` input is supported as
            before; ``(H, W, 3)`` RGB and 2-D grayscale arrays are also
            accepted. Values are cast to ``uint8``, so they are expected to be
            in the 0-255 range already. Non-zero (bright) pixels are ink.

    Returns:
        A ``(28, 28)`` float tensor with values in ``[0, 1]``, or ``None`` if
        the summed grayscale intensity is below ``_MIN_INK`` (treated as an
        empty drawing).
    """
    img = _to_grayscale(np.asarray(img_data).astype(np.uint8))
    if np.sum(img) < _MIN_INK:
        return None

    # Non-empty image => findNonZero returns at least one point.
    x, y, w, h = cv2.boundingRect(cv2.findNonZero(img))
    crop = img[y:y + h, x:x + w]

    scale = _DIGIT_BOX / max(w, h)
    # Clamp each side to >= 1 px: a very thin stroke (e.g. a 1-px-wide "1")
    # would otherwise truncate to a zero-sized target, which cv2.resize
    # rejects with an assertion error.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    resized = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_AREA)

    canvas = np.zeros((_CANVAS_SIZE, _CANVAS_SIZE), dtype=np.uint8)
    rh, rw = resized.shape
    top = (_CANVAS_SIZE - rh) // 2
    left = (_CANVAS_SIZE - rw) // 2
    canvas[top:top + rh, left:left + rw] = resized
    return torch.tensor(canvas).float() / 255.0


def load_models():
    """Load the pickled SVD model and the trained CNN from disk.

    Returns:
        ``(svd, cnn)``: ``svd`` is whatever object was pickled at
        ``config.SVD_MODEL_PATH``; ``cnn`` is a ``SimpleCNN`` with the state
        dict from ``config.CNN_MODEL_PATH`` loaded onto the CPU and switched to
        eval mode. Returns ``(None, None)`` if either file does not exist.
    """
    if not (os.path.exists(config.SVD_MODEL_PATH)
            and os.path.exists(config.CNN_MODEL_PATH)):
        return None, None

    # SECURITY: pickle.load can execute arbitrary code, and torch.load may
    # unpickle arbitrary objects depending on the torch version and its
    # weights_only setting. Only load model files produced by this project's
    # own training code, never user-supplied files.
    with open(config.SVD_MODEL_PATH, "rb") as f:
        svd = pickle.load(f)

    cnn = SimpleCNN()
    state = torch.load(config.CNN_MODEL_PATH, map_location="cpu")
    cnn.load_state_dict(state)
    # Evaluation mode (matters for dropout/batch-norm if SimpleCNN has them).
    cnn.eval()
    return svd, cnn