import torch
from models.networks.base_network import BaseNetwork
from models.networks.loss import *
from models.networks.discriminator import *
from models.networks.encoder import ConvEncoder
from models.networks.generator import TSITGenerator, Pix2PixHDGenerator
from models.networks.rafael_generator import RafaelGenerator
import torch.nn as nn
from .style_encoder import SimpleStyleEncoder

SWIN_COND_D_IMPORTED = False
SwinTransformerConditionalDiscriminator = None
try:
    from models.networks.swin_discriminator import SwinTransformerConditionalDiscriminator
    SWIN_COND_D_IMPORTED = True
    print("Successfully imported SwinTransformerConditionalDiscriminator.")
except ImportError as e:
    print(f"Warning: Could not import SwinTransformerConditionalDiscriminator: {e}. Swin D will not be available.")
except Exception as e:
    print(f"Warning: An unexpected error occurred while importing SwinTransformerConditionalDiscriminator: {e}")


def find_class_in_module(target_cls_name, module_globals):
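    """Look up a class in `module_globals` by name, ignoring case and underscores.

    For example, 'multiscale_discriminator' and 'MultiscaleDiscriminator' resolve
    to the same class. Raises ValueError if no matching class is found.
    """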
    target_cls_name_lower_no_underscore = target_cls_name.lower().replace('_', '')
    for name, cls_obj in module_globals.items():
        if isinstance(cls_obj, type) and name.lower().replace('_', '') == target_cls_name_lower_no_underscore:
            return cls_obj

    available_classes = [name for name, obj in module_globals.items() if isinstance(obj, type)]
    print(
        f"Error: no class matching '{target_cls_name}' (normalized to '{target_cls_name_lower_no_underscore}') was found in module_globals.")
    print(f"Available classes (normalized, lowercase, no underscores): {[ac.lower().replace('_', '') for ac in available_classes]}")
    print(f"Available classes (original names): {available_classes}")
    raise ValueError(
        f"No class matching '{target_cls_name}' was found in this module "
        "(comparison is case-insensitive and ignores underscores)."
    )


def modify_commandline_options(parser, is_train):
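    """Register type-specific command-line options for the selected networks.

    Based on the --netG, --netD and --use_vae values already known to the parser,
    this calls modify_commandline_options() on the corresponding generator,
    discriminator, and encoder classes when they define it.
    """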
    opt, _ = parser.parse_known_args()

    # Generator options
    if hasattr(opt, 'netG') and opt.netG:
        try:
            netG_cls = find_class_in_module(opt.netG, globals())
            if hasattr(netG_cls, 'modify_commandline_options'):
                parser = netG_cls.modify_commandline_options(parser, is_train)
                print(f"DEBUG: Called modify_commandline_options for netG: {opt.netG}")
        except ValueError as e:
            print(f"Warning: Could not find class for netG '{opt.netG}' to call modify_commandline_options. {e}")


    # Discriminator options
    if is_train and hasattr(opt, 'netD') and opt.netD:
        netD_cls_to_modify = None
        netD_lower_no_underscore = opt.netD.lower().replace('_','')

        if netD_lower_no_underscore.startswith('swin') and SWIN_COND_D_IMPORTED:
            netD_cls_to_modify = SwinTransformerConditionalDiscriminator
            print(f"DEBUG: networks.modify_commandline_options using SwinTransformerConditionalDiscriminator for netD '{opt.netD}'.")
        else:
            try:
                cls_cand = find_class_in_module(opt.netD, globals())
                if issubclass(cls_cand, BaseNetwork):
                    netD_cls_to_modify = cls_cand
                    print(f"DEBUG: networks.modify_commandline_options trying {cls_cand.__name__} for netD '{opt.netD}'.")
            except ValueError:
                 pass

        if netD_cls_to_modify and hasattr(netD_cls_to_modify, 'modify_commandline_options'):
            parser = netD_cls_to_modify.modify_commandline_options(parser, is_train)
            print(f"DEBUG: Called modify_commandline_options for netD: {netD_cls_to_modify.__name__}")
        elif netD_lower_no_underscore.startswith('swin') and not netD_cls_to_modify:
            print(f"Warning: netD '{opt.netD}' looks like a Swin discriminator, but its class or modify_commandline_options could not be resolved; Swin-specific options were not added.")
        elif not netD_cls_to_modify:
            print(f"Warning: Could not find a class for netD '{opt.netD}' to call modify_commandline_options; any options specific to this discriminator type will not be registered.")


    # VAE Encoder options (original TSIT)
    if hasattr(opt, 'use_vae') and opt.use_vae:
        if not hasattr(opt, 'netG') or opt.netG != 'RafaelGenerator': # VAE not used with RafaelGenerator in current setup
            if hasattr(ConvEncoder, 'modify_commandline_options'):
                parser = ConvEncoder.modify_commandline_options(parser, is_train)
                print("DEBUG: Called modify_commandline_options for ConvEncoder (netE).")
    return parser


def define_G(opt):
    """Instantiate the generator named by opt.netG (e.g. "RafaelGenerator"),
    resolving the class from this file's globals() via find_class_in_module(),
    then move it to the GPU (if requested) and initialize its weights."""
    print(f"Attempting to define Generator: {opt.netG}")
    netG_cls = find_class_in_module(opt.netG, globals())
    netG = netG_cls(opt)

    print_network(netG)
    if len(opt.gpu_ids) > 0:
        assert torch.cuda.is_available()
        netG.cuda(opt.gpu_ids[0])  # or: netG.to(torch.device(f'cuda:{opt.gpu_ids[0]}'))
    netG.init_weights(opt.init_type, opt.init_variance)
    return netG


def define_D(opt):
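    """Instantiate the discriminator named by opt.netD.

    Names starting with 'swin' map to SwinTransformerConditionalDiscriminator
    (when its import succeeded); all other names are resolved against this
    module's globals() via find_class_in_module().
    """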
    netD_cls = None
    netD_name_lower_no_underscore = opt.netD.lower().replace('_', '')

    print(f"Attempting to define Discriminator for opt.netD = '{opt.netD}'")

    if netD_name_lower_no_underscore.startswith('swin') and SWIN_COND_D_IMPORTED:
        print(f"Selected SwinTransformerConditionalDiscriminator for netD: '{opt.netD}'.")
        netD_cls = SwinTransformerConditionalDiscriminator
    elif opt.netD in ['MultiscaleDiscriminator', 'NLayerDiscriminator']: # Add other known D types here
        print(f"Selected {opt.netD} for netD (classic type).")
        netD_cls = find_class_in_module(opt.netD, globals())
    else:
        # Fallback for any other D specified, or if the above specific checks failed
        try:
            print(f"Attempting to find class for netD '{opt.netD}' by general search in globals()...")
            netD_cls = find_class_in_module(opt.netD, globals()) # This will raise ValueError if not found
            print(f"Found and selected {netD_cls.__name__} for netD via general search.")
        except ValueError as e:
            # This will now properly raise an error if opt.netD is not recognized.
             raise ValueError(f"Unknown or unimported Discriminator type specified by --netD: {opt.netD}. Error: {e}")


    if netD_cls is None: # Should ideally not be reached if find_class_in_module or the specific checks work
        raise ValueError(f"Could not assign a class for Discriminator type: {opt.netD}. Check imports and naming.")

    netD = netD_cls(opt)
    print_network(netD)
    if len(opt.gpu_ids) > 0 and torch.cuda.is_available():
        netD.cuda(opt.gpu_ids[0])  # or: netD.to(torch.device(f'cuda:{opt.gpu_ids[0]}'))

    netD.init_weights(opt.init_type, opt.init_variance)
    return netD


def define_E(opt):
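    """Instantiate the VAE encoder (ConvEncoder), move it to the GPU if requested,
    and initialize its weights."""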
    print("Attempting to define Encoder (ConvEncoder for VAE)")
    netE = ConvEncoder(opt)
    print_network(netE)
    if len(opt.gpu_ids) > 0:
        assert torch.cuda.is_available()
        netE.cuda(opt.gpu_ids[0])  # or: netE.to(torch.device(f'cuda:{opt.gpu_ids[0]}'))
    netE.init_weights(opt.init_type, opt.init_variance)
    return netE


def print_network(net):
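    """Print the class name and parameter count (in millions) of a network.

    Accepts either a single module or a list of modules; for a list, only the
    first entry is summarized.
    """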
    if isinstance(net, list):
        actual_net_to_print = net[0]
        print(f"Printing network info for the first model in a list of {len(net)} models.")
    else:
        actual_net_to_print = net

    num_params = 0
    for param in actual_net_to_print.parameters():
        num_params += param.numel()

    print(f'Network [{type(actual_net_to_print).__name__}] was created. '
          f'Total number of parameters: {num_params / 1e6:.1f} million. '
          'To see the architecture, do print(network).')


class Identity(torch.nn.Module):
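    """Pass-through module that returns its input unchanged."""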
    def forward(self, x):
        return x


def get_norm_layer(opt, norm_type='instance'):
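    """Return a callable that wraps a conv/linear layer with the requested
    normalization, optionally applying spectral norm first when `norm_type`
    carries a 'spectral' prefix. The `opt` argument is not used here."""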
    # norm_type is a string like 'spectralinstance' or 'instance' or 'batch'
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        if hasattr(layer, 'out_features'): # For Linear layers
            return getattr(layer, 'out_features')
        if hasattr(layer, 'weight') and layer.weight.dim() > 1: # Conv, Linear
            return layer.weight.size(0)
        raise ValueError(f"Cannot get out_channel for layer {type(layer)}")

    def add_norm_layer(layer_instance): # layer_instance is e.g. nn.Conv2d(...)
        # norm_type is read from the enclosing get_norm_layer() scope via the closure.
        current_norm_type = norm_type

        # Spectral normalization part
        if current_norm_type.startswith('spectral'):
            # Apply spectral norm to the layer (e.g., nn.Conv2d)
            # torch.nn.utils.spectral_norm returns the spectrally_normed_layer
            layer_with_spec_norm = torch.nn.utils.spectral_norm(layer_instance)
            # The rest of the norm_type string is the actual normalization (e.g., 'instance')
            subnorm_type = current_norm_type[len('spectral'):]
        else:
            layer_with_spec_norm = layer_instance # No spectral norm
            subnorm_type = current_norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer_with_spec_norm # Only spectral norm (if any) or just the layer

        # For other normalizations, remove bias from the layer_with_spec_norm
        # as it's cancelled by the normalization layer's learnable/non-learnable beta.
        if getattr(layer_with_spec_norm, 'bias', None) is not None:
            delattr(layer_with_spec_norm, 'bias')
            layer_with_spec_norm.register_parameter('bias', None)

        out_channels = get_out_channel(layer_instance) # Get channels from original layer

        if subnorm_type == 'batch':
            norm_layer_instance = nn.BatchNorm2d(out_channels, affine=True)
        elif subnorm_type == 'syncbatch':
            # SynchronizedBatchNorm2d is imported lazily; fall back to BatchNorm2d if unavailable.
            try:
                from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
                norm_layer_instance = SynchronizedBatchNorm2d(out_channels, affine=True)
            except ImportError:
                print("Warning: SynchronizedBatchNorm2d not found, falling back to BatchNorm2d for 'syncbatch'.")
                norm_layer_instance = nn.BatchNorm2d(out_channels, affine=True)
        elif subnorm_type == 'instance':
            norm_layer_instance = nn.InstanceNorm2d(out_channels, affine=False) # Original TSIT used affine=False
        elif subnorm_type.startswith('fade'): # FADE is handled differently, usually within the block
            print(
                f"Warning: Norm type '{subnorm_type}' looks like FADE. "
                "This generic get_norm_layer is not for FADE. FADE should be part of specific blocks. "
                "Returning layer without additional FADE-like normalization here.")
            return layer_with_spec_norm # Return layer possibly with spectral norm
        else:
            raise ValueError('normalization layer %s (from %s) is not recognized' % (subnorm_type, norm_type))

        return nn.Sequential(layer_with_spec_norm, norm_layer_instance)

    return add_norm_layer
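

# Illustrative usage of get_norm_layer (a sketch; `opt` and the layer sizes below are
# placeholders, not values defined in this module):
#
#   norm_layer = get_norm_layer(opt, norm_type='spectralinstance')
#   conv = norm_layer(nn.Conv2d(3, 64, kernel_size=3, padding=1))
#   # `conv` is now nn.Sequential(spectral_norm(Conv2d(...)), InstanceNorm2d(64))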