File size: 3,466 Bytes
c954f09
a15cce2
a8eef7d
a15cce2
c954f09
 
 
 
 
 
 
 
 
 
 
 
 
a8eef7d
c954f09
a15cce2
a8eef7d
 
 
 
 
 
c954f09
 
 
 
 
a8eef7d
c954f09
 
a15cce2
 
a8eef7d
 
 
c954f09
 
 
a8eef7d
 
 
c954f09
a8eef7d
c954f09
 
a15cce2
 
a8eef7d
 
 
c954f09
a15cce2
 
c954f09
 
 
62b9889
a15cce2
 
62b9889
 
 
 
 
 
 
 
a15cce2
 
a8eef7d
 
 
 
a15cce2
a8eef7d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from models import create_model
from util.get_transform import get_transform
from util.util import tensor2im
from PIL import Image
import os

# Checkpoint folder resolved next to this source file (not the current working directory).
ckp_path = os.path.join(os.path.split(__file__)[0], 'checkpoints')

class Options(object):
    """Plain attribute bag: every key of every given mapping becomes an attribute.

    Positional mappings are applied left to right, then keyword arguments, so a
    later source overwrites an earlier one when keys collide.
    """

    def __init__(self, *initial_data, **kwargs):
        for source in (*initial_data, kwargs):
            for attr_name in source:
                setattr(self, attr_name, source[attr_name])

class ModelLoader:
    """Lazily build the translation model and run single-image inference.

    The option bag mirrors a test-time configuration (``isTrain`` False,
    batch size 1). The network itself is only created on the first
    ``inference`` call, or by an explicit ``load()``.
    """

    def __init__(self, gpu_ids='', max_img_wh=512) -> None:
        """Build the option bag and the input transform.

        Args:
            gpu_ids: comma-separated GPU indices, e.g. ``'0'`` or ``'0,1'``.
                An empty string, or only negative ids such as ``'-1'``,
                selects the CPU (empty id list).
            max_img_wh: stored as ``yarflam_img_wh``. Per the option comment,
                this is the max width/height before automatic down-scaling.
        """
        self.opt = Options({
            'isGradio': True, # Custom
            'name': 'original', # Checkpoints name
            'checkpoints_dir': ckp_path, # Checkpoint folder
            'gpu_ids': self._parse_gpu_ids(gpu_ids), # '0,1' -> [0, 1]; [] means CPU
            'init_gain': 0.02, # Scaling Factor
            'init_type': 'xavier', # list: 'normal', 'xavier', 'kaiming', 'orthogonal'
            'input_nc': 3, # 3 -> RGB, 1 -> Grayscale
            'output_nc': 3,
            'isTrain': False,
            'model': 'cwr',
            'nce_idt': False,
            'nce_layers': '0',
            'ndf': 64, # Nb of discrim filters in the first conv layer
            'netD': 'basic',
            'netG': 'resnet_9blocks',
            'netF': 'mlp_sample',
            'netF_nc': 256,
            'ngf': 64, # Nb of gen filters in the last conv layer
            'no_antialias_up': False,
            'no_antialias': False,
            'no_dropout': True,
            'normD': 'instance',
            'normG': 'instance',
            'preprocess': 'yarflam_auto', # see more: util.get_transform
            'dataroot': 'placeholder',
            'num_threads': 1, # test code only supports num_threads = 1
            'batch_size': 1, # test code only supports batch_size = 1
            'serial_batches': False, # no data loader is built here; only True would disable shuffling
            'no_flip': True, # no flip
            'display_id': -1, # no visdom display
            'direction': 'AtoB', # inference
            'flip_equivariance': False,
            'load_size': 1680, # not used
            'crop_size': 512, # not used
            'yarflam_img_wh': max_img_wh, # max width|height + auto scale down
        })
        self.transform = get_transform(self.opt, grayscale=False)
        self.model = None  # created lazily by load()

    @staticmethod
    def _parse_gpu_ids(gpu_ids):
        """Parse ``'0,1'`` into ``[0, 1]``; blank entries and negative ids are dropped.

        Integer ids are what the CUT-style option parser this code derives from
        produces (assumed; verify against models/). Raw strings such as ``'0'``
        cannot be passed to ``Module.to`` or ``DataParallel``.
        """
        if not gpu_ids:
            return []
        ids = (int(part) for part in gpu_ids.split(',') if part.strip())
        return [i for i in ids if i >= 0]

    def load(self) -> None:
        """Create the network from ``self.opt`` and load the networks saved as 'latest'."""
        self.model = create_model(self.opt)
        self.model.load_networks('latest')

    def inference(self, src='', image_pil=None):
        """Translate one image and return the result as a PIL image.

        Args:
            src: path to an image file. Used only when ``image_pil`` is not a
                ``PIL.Image.Image``.
            image_pil: an in-memory PIL image. Takes precedence over ``src``.

        Returns:
            PIL.Image.Image built from the model output via ``tensor2im``.

        Raises:
            FileNotFoundError: ``image_pil`` is not a PIL image and ``src`` is
                not an existing file. It subclasses ``Exception``, so existing
                ``except Exception`` handlers still catch it.
        """
        import torch  # local import: torch is only needed here, for no_grad

        if self.model is None:
            self.load()

        # Loading
        if isinstance(image_pil, Image.Image):
            source = image_pil.convert('RGB')
        else:
            if not os.path.isfile(src):
                raise FileNotFoundError('The image %s is not found!' % src)
            print('Loading the image %s' % src)
            # Image.open is lazy and may keep the file handle open; the context
            # manager closes it once the RGB copy has been made.
            with Image.open(src) as opened:
                source = opened.convert('RGB')
        img = self.transform(source).unsqueeze(0)  # add a leading batch dimension of 1

        # Inference: the same tensor is given as both A and B. Presumably B is
        # only a placeholder required by set_input for 'AtoB' -- verify in the model.
        self.model.set_input({
            'A': img, 'A_paths': src,
            'B': img, 'B_paths': src
        })
        with torch.no_grad():  # no backward pass at inference: skip building the autograd graph
            self.model.forward()
            visuals = self.model.get_current_visuals()
        # Assumes the second visual is the translated image (e.g. 'fake_B');
        # TODO confirm against the model's visual_names ordering.
        out_data = list(visuals.values())[1]
        return Image.fromarray(tensor2im(out_data))