# UnderWater / ModelLoader.py
# Commit 62b9889 — "Fix Gradio components" (Yarflam)
from models import create_model
from util.get_transform import get_transform
from util.util import tensor2im
from PIL import Image
import os
# Directory holding model checkpoints, resolved relative to this file.
ckp_path = os.path.join(os.path.dirname(__file__), 'checkpoints')
class Options(object):
    """Lightweight attribute bag.

    Copies every key/value pair from the given dictionaries and keyword
    arguments onto the instance as attributes. Later sources win: keyword
    arguments override positional dictionaries, and later dictionaries
    override earlier ones.
    """

    def __init__(self, *initial_data, **kwargs):
        for source in (*initial_data, kwargs):
            self.__dict__.update(source)
class ModelLoader:
    """Wraps a CWR generator for single-image underwater restoration.

    Builds the full option set expected by the project's model factory,
    lazily loads the 'latest' checkpoint on first inference, and converts
    PIL images to and from model tensors.
    """

    def __init__(self, gpu_ids='', max_img_wh=512) -> None:
        """Prepare inference options and the input transform.

        :param gpu_ids: comma-separated GPU id string (e.g. '0' or '0,1');
            empty string selects CPU. NOTE(review): ids are kept as strings
            after split — presumably create_model tolerates that; confirm.
        :param max_img_wh: maximum input width/height; larger images are
            scaled down by the 'yarflam_auto' preprocess.
        """
        self.opt = Options({
            'isGradio': True, # Custom
            'name': 'original', # Checkpoints name
            'checkpoints_dir': ckp_path, # Checkpoint folder
            'gpu_ids': gpu_ids.split(',') if gpu_ids else [],
            'init_gain': 0.02, # Scaling Factor
            'init_type': 'xavier', # list: 'normal', 'xavier', 'kaiming', 'orthogonal'
            'input_nc': 3, # 3 -> RGB, 1 -> Grayscale
            'output_nc': 3,
            'isTrain': False,
            'model': 'cwr',
            'nce_idt': False,
            'nce_layers': '0',
            'ndf': 64, # Nb of discrim filters in the first conv layer
            'netD': 'basic',
            'netG': 'resnet_9blocks',
            'netF': 'mlp_sample',
            'netF_nc': 256,
            'ngf': 64, # Nb of gen filters in the last conv layer
            'no_antialias_up': False,
            'no_antialias': False,
            'no_dropout': True,
            'normD': 'instance',
            'normG': 'instance',
            'preprocess': 'yarflam_auto', # see more: util.get_transform
            'dataroot': 'placeholder',
            'num_threads': 1, # test code only supports num_threads = 1
            'batch_size': 1, # test code only supports batch_size = 1
            'serial_batches': False, # disable data shuffling; comment this line if results on randomly chosen images are needed.
            'no_flip': True, # no flip; comment this line if results on flipped images are needed.
            'display_id': -1, # no visdom display; the test code saves the results to a HTML file.
            'direction': 'AtoB', # inference
            'flip_equivariance': False,
            'load_size': 1680, # not used
            'crop_size': 512, # not used
            'yarflam_img_wh': max_img_wh, # max width|height + auto scale down
        })
        self.transform = get_transform(self.opt, grayscale=False)
        self.model = None # lazily created by load() on first inference

    def load(self) -> None:
        """Instantiate the network and load the 'latest' checkpoint weights."""
        self.model = create_model(self.opt)
        self.model.load_networks('latest')

    def inference(self, src='', image_pil=None):
        """Run the model on a single image and return the restored PIL image.

        :param src: path to the input image; used when image_pil is None.
        :param image_pil: optional in-memory PIL image; takes precedence
            over src when provided.
        :return: restored image as a PIL Image.
        :raises FileNotFoundError: if no image_pil is given and src does not
            exist (subclass of Exception, so existing callers still catch it).
        """
        if self.model is None: # PEP 8: identity check for the None singleton
            self.load()
        # Loading: prefer the in-memory image, else read from disk.
        if isinstance(image_pil, Image.Image):
            img = self.transform(image_pil.convert('RGB')).unsqueeze(0)
        else:
            if not os.path.isfile(src):
                raise FileNotFoundError('The image %s is not found!' % src)
            print('Loading the image %s' % src)
            source = Image.open(src).convert('RGB')
            img = self.transform(source).unsqueeze(0)
        print(img.shape)
        # Inference: the model expects an A/B pair; feed the same tensor for
        # both since only the AtoB direction is used here.
        self.model.set_input({
            'A': img, 'A_paths': src,
            'B': img, 'B_paths': src
        })
        self.model.forward()
        # Second entry of the visuals dict — presumably the generated
        # ('fake') output; TODO confirm ordering against get_current_visuals.
        out_data = list(self.model.get_current_visuals().items())[1][1]
        out_img = Image.fromarray(tensor2im(out_data))
        return out_img