Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import argparse | |
| import os | |
| import warnings | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from DocScanner.model import DocScanner | |
| from DocScanner.seg import U2NETP | |
| from PIL import Image | |
| warnings.filterwarnings("ignore") | |
class Net(nn.Module):
    """Mask the input with a segmentation net, then run DocScanner on the result.

    Sub-modules:
        msk: ``U2NETP(3, 1)`` segmentation network. Its call result is
            unpacked into exactly seven values; only the first is used.
        bm: ``DocScanner()`` rectification network.

    The attribute names ``msk`` and ``bm`` become the ``state_dict`` key
    prefixes, so they must not be renamed or existing checkpoints stop loading.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): presumably 3 input channels -> 1 mask channel; confirm
        # against U2NETP's constructor signature.
        self.msk = U2NETP(3, 1)
        # Rectification network.
        self.bm = DocScanner()

    def forward(self, x):
        """Segment ``x``, zero out unmasked pixels, and rectify.

        Args:
            x: input batch. NOTE(review): presumably (B, 3, H, W) image
                tensors given ``U2NETP(3, 1)`` -- confirm against the caller.

        Returns:
            ``(bm, msk)`` where ``bm`` is DocScanner's output affinely
            rescaled so that 0 maps to -0.99 and 286.8 maps to 0.99, and
            ``msk`` is the thresholded float mask holding only 0.0 / 1.0.
        """
        # The segmentation net returns seven outputs; keep just the first.
        seg_map, _, _, _, _, _, _ = self.msk(x)
        # Hard threshold at 0.5 into a {0.0, 1.0} float mask.
        mask = (seg_map > 0.5).float()
        # Zero the input wherever the mask is 0 before rectification.
        masked_input = mask * x
        raw_map = self.bm(masked_input, iters=12, test_mode=True)
        # Affine rescale [0, 286.8] -> [-1, 1], then shrink by 0.99.
        # NOTE(review): presumably converts pixel coordinates into the
        # normalized [-1, 1] grid used by F.grid_sample, with 286.8 tied to
        # DocScanner's working resolution -- confirm.
        normalized_map = (2 * (raw_map / 286.8) - 1) * 0.99
        return normalized_map, mask
def reload_seg_model(cuda, model, path=""):
    """Partially load segmentation weights from ``path`` into ``model``.

    Every checkpoint key has its first 6 characters dropped, and only the
    resulting names that already exist in ``model.state_dict()`` are copied
    in. All other parameters keep their current values.

    NOTE(review): the 6-character strip is unconditional; presumably the
    checkpoint keys share a fixed 6-character prefix such as ``"model."``.
    Confirm this against the checkpoint. When two stripped names collide,
    the later checkpoint entry wins.

    Args:
        cuda: forwarded to ``torch.load`` as ``map_location``.
        model: module to update in place.
        path: checkpoint file. A falsy value (``""`` or ``None``) skips loading.

    Returns:
        The same ``model`` object.
    """
    if not path:
        return model

    current = model.state_dict()
    # NOTE: torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(path, map_location=cuda)
    matched = {}
    for name, tensor in checkpoint.items():
        stripped = name[6:]
        if stripped in current:
            matched[stripped] = tensor
    current.update(matched)
    model.load_state_dict(current)
    return model
def reload_rec_model(cuda, model, path=""):
    """Partially load rectification weights from ``path`` into ``model``.

    Only checkpoint entries whose keys exactly match a key in
    ``model.state_dict()`` are copied in. Unknown checkpoint keys are
    ignored, and model parameters missing from the checkpoint keep their
    current values.

    Args:
        cuda: forwarded to ``torch.load`` as ``map_location``.
        model: module to update in place.
        path: checkpoint file. A falsy value (``""`` or ``None``) skips loading.

    Returns:
        The same ``model`` object.
    """
    if not path:
        return model

    current = model.state_dict()
    # NOTE: torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(path, map_location=cuda)
    for name, tensor in checkpoint.items():
        # Only overwrite existing keys, so the key set is never changed.
        if name in current:
            current[name] = tensor
    model.load_state_dict(current)
    return model