Spaces:
Sleeping
| import os | |
| import sys | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import argparse | |
| import glob | |
| import os | |
| import warnings | |
| import cv2 | |
| import numpy as np | |
| import skimage.io as io | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from .GeoTr import U2NETP, GeoTr | |
| warnings.filterwarnings("ignore") | |
class GeoTrP(nn.Module):
    """Wrap ``GeoTr`` and upsample its backward map to a square output resolution.

    The wrapped network's output is treated as a 2-channel coordinate map
    expressed on a ``bm_scale``-pixel grid. ``forward`` rescales those
    coordinates to an ``out_size``-pixel grid and bilinearly resizes the map
    to ``(out_size, out_size)``.

    Args:
        out_size: Side length, in pixels, of the returned map and of the
            coordinate range it is expressed in. Defaults to 2560, the value
            this module previously hard-coded.
        bm_scale: Coordinate scale of the raw ``GeoTr`` output. Defaults to 288,
            the value this module previously hard-coded.
            NOTE(review): assumed to match GeoTr's working resolution —
            confirm against the GeoTr definition.
    """

    def __init__(self, out_size=2560, bm_scale=288):
        super().__init__()
        self.GeoTr = GeoTr()
        # Plain ints are not registered as parameters/buffers, so
        # state_dict() (and checkpoints loaded into it) is unchanged.
        self.out_size = out_size
        self.bm_scale = bm_scale

    def forward(self, x):
        """Run GeoTr on ``x`` and return the rescaled, resized backward map.

        Args:
            x: Input tensor forwarded unchanged to ``GeoTr``.

        Returns:
            A 4-D tensor resized to ``(out_size, out_size)`` in its last two
            dims, with values multiplied by ``out_size / bm_scale``.
        """
        bm = self.GeoTr(x)
        # The former normalise-to-[-1, 1] / denormalise-to-[0, out_size] pair
        # (2*(bm/288)-1, then (bm+1)/2*2560) reduces to a single scale factor.
        bm = bm * (self.out_size / self.bm_scale)
        return F.interpolate(
            bm,
            size=(self.out_size, self.out_size),
            mode="bilinear",
            align_corners=True,
        )
def reload_model(model, path="", map_location=None):
    """Load a checkpoint's state dict into ``model`` and return ``model``.

    The checkpoint entries are merged over the model's current state dict
    before loading, so keys missing from the checkpoint keep their current
    values.

    Args:
        model: The ``nn.Module`` to update in place.
        path: Checkpoint file path. If empty or ``None``, ``model`` is
            returned unchanged.
        map_location: Forwarded to ``torch.load``. When ``None`` (the
            default), ``"cuda:0"`` is used if CUDA is available, otherwise
            ``"cpu"``. The previously hard-coded ``"cuda:0"`` failed on
            CPU-only machines.

    Returns:
        The same ``model`` object.

    Raises:
        RuntimeError: From ``load_state_dict`` if the checkpoint contains keys
            the model does not have (strict loading), or shapes mismatch.
    """
    if not path:
        return model
    if map_location is None:
        map_location = "cuda:0" if torch.cuda.is_available() else "cpu"
    model_dict = model.state_dict()
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    pretrained_dict = torch.load(path, map_location=map_location)
    print(len(pretrained_dict.keys()))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model