Spaces:
Paused
Paused
File size: 3,466 Bytes
c954f09 a15cce2 a8eef7d a15cce2 c954f09 a8eef7d c954f09 a15cce2 a8eef7d c954f09 a8eef7d c954f09 a15cce2 a8eef7d c954f09 a8eef7d c954f09 a8eef7d c954f09 a15cce2 a8eef7d c954f09 a15cce2 c954f09 62b9889 a15cce2 62b9889 a15cce2 a8eef7d a15cce2 a8eef7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from models import create_model
from util.get_transform import get_transform
from util.util import tensor2im
from PIL import Image
import os
# Checkpoint directory: the 'checkpoints' folder located next to this file.
ckp_path = os.path.join(os.path.dirname(__file__), 'checkpoints')
class Options:
    """Plain attribute bag filled from mappings and keyword arguments.

    Each key of every positional mapping becomes an instance attribute,
    followed by each keyword argument. Sources are applied in order, so a
    later mapping (and finally the keyword arguments) overrides an earlier
    value stored under the same name.
    """

    def __init__(self, *initial_data, **kwargs):
        # Treat the keyword arguments as one more mapping applied last.
        for source in initial_data + (kwargs,):
            for name in source:
                setattr(self, name, source[name])
class ModelLoader:
    """Lazy wrapper around the 'cwr' model for single-image inference.

    The constructor only assembles the option set and the preprocessing
    transform. The network is created, and its 'latest' checkpoint loaded,
    on the first call to ``inference`` (or an explicit call to ``load``).
    """

    def __init__(self, gpu_ids='', max_img_wh=512) -> None:
        """Build the inference options and the input transform.

        Args:
            gpu_ids: Comma-separated GPU indices such as ``'0'`` or ``'0,1'``.
                An empty string (the default) produces an empty id list.
            max_img_wh: Stored as the ``yarflam_img_wh`` option: maximum
                width/height, larger images are scaled down automatically.

        Raises:
            ValueError: If ``gpu_ids`` contains a non-numeric entry.
        """
        self.opt = Options({
            'isGradio': True,  # Custom
            'name': 'original',  # Checkpoints name
            'checkpoints_dir': ckp_path,  # Checkpoint folder
            # Integer device ordinals (PyTorch addresses GPUs by int index).
            # NOTE(review): assumes the model code accepts ints here, as the
            # usual option parsers for this family of models do - confirm.
            'gpu_ids': self._parse_gpu_ids(gpu_ids),
            'init_gain': 0.02,  # Scaling Factor
            'init_type': 'xavier',  # list: 'normal', 'xavier', 'kaiming', 'orthogonal'
            'input_nc': 3,  # 3 -> RGB, 1 -> Grayscale
            'output_nc': 3,
            'isTrain': False,
            'model': 'cwr',
            'nce_idt': False,
            'nce_layers': '0',
            'ndf': 64,  # Nb of discrim filters in the first conv layer
            'netD': 'basic',
            'netG': 'resnet_9blocks',
            'netF': 'mlp_sample',
            'netF_nc': 256,
            'ngf': 64,  # Nb of gen filters in the last conv layer
            'no_antialias_up': False,
            'no_antialias': False,
            'no_dropout': True,
            'normD': 'instance',
            'normG': 'instance',
            'preprocess': 'yarflam_auto',  # see more: util.get_transform
            'dataroot': 'placeholder',
            'num_threads': 1,  # test code only supports num_threads = 1
            'batch_size': 1,  # test code only supports batch_size = 1
            # NOTE(review): no data loader is created in this class, so this
            # flag looks unused here; the value is kept as the author set it.
            'serial_batches': False,
            'no_flip': True,  # no flip
            'display_id': -1,  # no visdom display
            'direction': 'AtoB',  # inference
            'flip_equivariance': False,
            'load_size': 1680,  # not used
            'crop_size': 512,  # not used
            'yarflam_img_wh': max_img_wh,  # max width|height + auto scale down
        })
        self.transform = get_transform(self.opt, grayscale=False)
        # Created lazily by load() so constructing the loader stays cheap.
        self.model = None

    @staticmethod
    def _parse_gpu_ids(gpu_ids):
        """Convert a comma-separated id string into a list of non-negative ints.

        Blank entries (e.g. from a trailing comma) and negative ids are
        skipped; a non-numeric entry raises ``ValueError``.
        """
        parsed = []
        for token in (gpu_ids or '').split(','):
            token = token.strip()
            if not token:
                continue
            index = int(token)
            if index >= 0:
                parsed.append(index)
        return parsed

    def load(self) -> None:
        """Create the model from the stored options and load the 'latest' networks."""
        self.model = create_model(self.opt)
        self.model.load_networks('latest')

    def inference(self, src='', image_pil=None):
        """Run the model on one image and return the result as a PIL image.

        Args:
            src: Path of the image file to read. Only used when ``image_pil``
                is not a PIL image; it is also passed to the model as the
                input path.
            image_pil: Optional ``PIL.Image.Image``; when given it takes
                precedence over ``src``.

        Returns:
            A ``PIL.Image.Image`` built from the second entry of the model's
            current visuals.

        Raises:
            FileNotFoundError: If no PIL image is given and ``src`` is not an
                existing file. This is a subclass of ``Exception``, so callers
                catching ``Exception`` keep working.
        """
        # Local import: torch is only needed here to disable autograd.
        import torch

        if self.model is None:
            self.load()

        # Loading
        if isinstance(image_pil, Image.Image):
            img = self.transform(image_pil.convert('RGB')).unsqueeze(0)
        else:
            if not os.path.isfile(src):
                raise FileNotFoundError('The image %s is not found!' % src)
            print('Loading the image %s' % src)
            # convert() returns an independent copy, so the file handle can
            # be released as soon as the conversion is done.
            with Image.open(src) as handle:
                source = handle.convert('RGB')
            img = self.transform(source).unsqueeze(0)
        print(img.shape)

        # Inference: no gradients are needed, so skip autograd bookkeeping.
        with torch.no_grad():
            self.model.set_input({
                'A': img, 'A_paths': src,
                'B': img, 'B_paths': src
            })
            self.model.forward()

        # The second visual is used as the output.
        # NOTE(review): presumably the generated image - verify against the
        # model's list of visuals.
        visuals = self.model.get_current_visuals()
        out_data = list(visuals.values())[1]
        return Image.fromarray(tensor2im(out_data))