|
|
import torch |
|
|
from models.networks.base_network import BaseNetwork |
|
|
from models.networks.loss import * |
|
|
from models.networks.discriminator import * |
|
|
from models.networks.encoder import ConvEncoder |
|
|
from models.networks.generator import TSITGenerator, Pix2PixHDGenerator |
|
|
from models.networks.rafael_generator import RafaelGenerator |
|
|
import torch.nn as nn |
|
|
from .style_encoder import SimpleStyleEncoder |
|
|
|
|
|
# Optional Swin-Transformer conditional discriminator. Importing it is
# best-effort: if the import fails, the rest of this module still loads and the
# Swin discriminator is simply unavailable.
# SWIN_COND_D_IMPORTED stays False and SwinTransformerConditionalDiscriminator
# stays None unless the import below succeeds; the netD selection code checks
# the flag before using the class.
SWIN_COND_D_IMPORTED = False
SwinTransformerConditionalDiscriminator = None
try:
    from models.networks.swin_discriminator import SwinTransformerConditionalDiscriminator
    SWIN_COND_D_IMPORTED = True
    print("Successfully imported SwinTransformerConditionalDiscriminator.")
except ImportError as e:
    # Missing module (or a missing import inside it): warn and continue without Swin D.
    print(f"Warning: Could not import SwinTransformerConditionalDiscriminator: {e}. Swin D will not be available.")
except Exception as e:
    # Any other error raised while executing the Swin module is also downgraded
    # to a warning so it cannot break importing this module.
    # NOTE(review): this can hide genuine bugs in swin_discriminator; consider
    # re-raising anything that is not an optional-dependency failure.
    print(f"Warning: An unexpected error occurred while importing SwinTransformerConditionalDiscriminator: {e}")
|
|
|
|
|
|
|
|
def find_class_in_module(target_cls_name, module_globals):
    """Look up a class in a namespace mapping using a forgiving name match.

    Both the requested name and each candidate name are lower-cased and have
    underscores removed before comparison, so ``'multiscale_discriminator'``
    matches ``MultiscaleDiscriminator``. Only values that are classes
    (``isinstance(obj, type)``) are considered, and the first match in the
    mapping's iteration order is returned.

    Args:
        target_cls_name: Requested class name (e.g. the value of ``--netG``).
        module_globals: Mapping of names to objects, typically ``globals()``.

    Returns:
        The matching class object.

    Raises:
        ValueError: If no class in ``module_globals`` matches. Diagnostic
            listings of the available class names are printed beforehand.
    """
    def _normalize(name):
        return name.lower().replace('_', '')

    wanted = _normalize(target_cls_name)

    # isinstance() is tested first so non-class entries are never normalized.
    found = next(
        (obj for key, obj in module_globals.items()
         if isinstance(obj, type) and _normalize(key) == wanted),
        None,
    )
    if found is not None:
        return found

    class_names = [key for key, obj in module_globals.items() if isinstance(obj, type)]
    print(
        f"Error: Class '{wanted}' (derived from '{target_cls_name}') not found in module_globals.")
    print(f"Available classes (case-insensitive, no underscore): {[_normalize(key) for key in class_names]}")
    print(f"Original available classes: {class_names}")
    raise ValueError(
        f"In current module, there should be a class named '{target_cls_name}' (comparison is case-insensitive and ignores underscores)."
    )
|
|
|
|
|
|
|
|
def _resolve_netD_cls_for_options(netD_name):
    """Pick the discriminator class whose ``modify_commandline_options`` should run.

    The Swin conditional discriminator is chosen when its import succeeded and
    the normalized name is either its full name or starts with ``'swin'``. Any
    other name is looked up with ``find_class_in_module`` and accepted only if
    it subclasses ``BaseNetwork``.

    Args:
        netD_name: Value of ``--netD``.

    Returns:
        The selected class, or ``None`` when no suitable class was found.
    """
    netD_key = netD_name.lower().replace('_', '')

    if netD_key == 'swintransformerconditionaldiscriminator' and SWIN_COND_D_IMPORTED:
        print(f"DEBUG: networks.modify_commandline_options trying SwinTransformerConditionalDiscriminator for netD '{netD_name}'.")
        return SwinTransformerConditionalDiscriminator
    if netD_key.startswith('swin') and SWIN_COND_D_IMPORTED:
        print(f"DEBUG: networks.modify_commandline_options trying SwinTransformerConditionalDiscriminator for netD '{netD_name}' by prefix.")
        return SwinTransformerConditionalDiscriminator

    try:
        cls_cand = find_class_in_module(netD_name, globals())
    except ValueError:
        # find_class_in_module already printed diagnostics; the caller emits
        # the user-facing warning when None is returned.
        return None
    if issubclass(cls_cand, BaseNetwork):
        print(f"DEBUG: networks.modify_commandline_options trying {cls_cand.__name__} for netD '{netD_name}'.")
        return cls_cand
    return None


def modify_commandline_options(parser, is_train):
    """Let the selected generator, discriminator and encoder add their CLI options.

    The parser is parsed once with ``parse_known_args`` to learn which networks
    were requested. For each selected class that defines a
    ``modify_commandline_options`` hook, the hook is called and the parser it
    returns is used from then on.

    Args:
        parser: The ``argparse.ArgumentParser`` being built.
        is_train: Whether training options are being built. Discriminator
            options are only added when this is true.

    Returns:
        The (possibly extended) parser.
    """
    opt, _ = parser.parse_known_args()

    # --- Generator options -------------------------------------------------
    netG_name = getattr(opt, 'netG', None)
    if netG_name:
        try:
            netG_cls = find_class_in_module(netG_name, globals())
        except ValueError as e:
            print(f"Warning: Could not find class for netG '{opt.netG}' to call modify_commandline_options. {e}")
        else:
            # Only the lookup is guarded: a ValueError raised by the hook itself
            # is a genuine error and must not be reported as a missing class.
            if hasattr(netG_cls, 'modify_commandline_options'):
                parser = netG_cls.modify_commandline_options(parser, is_train)
                print(f"DEBUG: Called modify_commandline_options for netG: {opt.netG}")

    # --- Discriminator options (training only) -----------------------------
    netD_name = getattr(opt, 'netD', None)
    if is_train and netD_name:
        netD_cls_to_modify = _resolve_netD_cls_for_options(netD_name)
        looks_like_swin = netD_name.lower().replace('_', '').startswith('swin')

        if netD_cls_to_modify and hasattr(netD_cls_to_modify, 'modify_commandline_options'):
            parser = netD_cls_to_modify.modify_commandline_options(parser, is_train)
            print(f"DEBUG: Called modify_commandline_options for netD: {netD_cls_to_modify.__name__}")
        elif looks_like_swin and not netD_cls_to_modify:
            print(f"Warning: netD '{opt.netD}' looks like Swin, but class/modify_commandline_options not found/called.")
        elif not netD_cls_to_modify:
            print(f"Warning: Could not find a specific class for netD '{opt.netD}' to call modify_commandline_options. Standard D options might not be fully parsed if this D type has specific ones.")

    # --- VAE encoder options -----------------------------------------------
    # RafaelGenerator is excluded from the ConvEncoder options. The name is
    # normalized the same way find_class_in_module matches class names, so e.g.
    # 'rafael_generator' (which selects RafaelGenerator) is excluded as well.
    if getattr(opt, 'use_vae', False):
        netG_key = (netG_name or '').lower().replace('_', '')
        if netG_key != 'rafaelgenerator' and hasattr(ConvEncoder, 'modify_commandline_options'):
            parser = ConvEncoder.modify_commandline_options(parser, is_train)
            print("DEBUG: Called modify_commandline_options for ConvEncoder (netE).")

    return parser
|
|
|
|
|
|
|
|
def define_G(opt):
    """Build the generator selected by ``opt.netG``.

    The class is resolved with ``find_class_in_module`` against this module's
    globals and constructed with ``opt``. Its parameter count is printed. When
    ``opt.gpu_ids`` is non-empty, CUDA must be available (asserted) and the
    network is moved to the first listed GPU. Finally
    ``init_weights(opt.init_type, opt.init_variance)`` is called on it.

    Args:
        opt: Parsed options; reads ``netG``, ``gpu_ids``, ``init_type`` and
            ``init_variance``.

    Returns:
        The initialized generator.

    Raises:
        ValueError: If no class matches ``opt.netG``.
    """
    print(f"Attempting to define Generator: {opt.netG}")
    generator_cls = find_class_in_module(opt.netG, globals())
    generator = generator_cls(opt)
    print_network(generator)

    gpu_ids = opt.gpu_ids
    if len(gpu_ids):
        assert (torch.cuda.is_available())
        generator.cuda(gpu_ids[0])

    # Weights are initialized after the move to the target device.
    generator.init_weights(opt.init_type, opt.init_variance)
    return generator
|
|
|
|
|
|
|
|
def define_D(opt):
    """Build the discriminator selected by ``opt.netD``.

    Selection order:
      1. The Swin conditional discriminator, if its import succeeded and the
         normalized name equals its full name or starts with ``'swin'``.
      2. ``'MultiscaleDiscriminator'`` / ``'NLayerDiscriminator'`` (exact names),
         resolved with ``find_class_in_module``.
      3. Any other name, resolved with ``find_class_in_module``. A failed lookup
         is re-raised as a ``ValueError`` that names ``--netD``.

    The instance's parameter count is printed. It is moved to the first GPU in
    ``opt.gpu_ids`` only when CUDA is actually available (otherwise it silently
    stays on CPU). Finally ``init_weights(opt.init_type, opt.init_variance)`` is
    called on it.

    Args:
        opt: Parsed options; reads ``netD``, ``gpu_ids``, ``init_type`` and
            ``init_variance``.

    Returns:
        The initialized discriminator.

    Raises:
        ValueError: If no discriminator class can be resolved.
    """
    requested = opt.netD
    requested_key = requested.lower().replace('_', '')

    print(f"Attempting to define Discriminator for opt.netD = '{requested}'")

    if SWIN_COND_D_IMPORTED and requested_key == 'swintransformerconditionaldiscriminator':
        print(f"Selected SwinTransformerConditionalDiscriminator for netD by full name: '{requested}'.")
        discriminator_cls = SwinTransformerConditionalDiscriminator
    elif SWIN_COND_D_IMPORTED and requested_key.startswith("swin"):
        print(f"Selected SwinTransformerConditionalDiscriminator for netD by prefix: '{requested}'.")
        discriminator_cls = SwinTransformerConditionalDiscriminator
    elif requested in ('MultiscaleDiscriminator', 'NLayerDiscriminator'):
        print(f"Selected {requested} for netD (classic type).")
        discriminator_cls = find_class_in_module(requested, globals())
    else:
        print(f"Attempting to find class for netD '{requested}' by general search in globals()...")
        try:
            discriminator_cls = find_class_in_module(requested, globals())
        except ValueError as e:
            raise ValueError(f"Unknown or unimported Discriminator type specified by --netD: {requested}. Error: {e}")
        print(f"Found and selected {discriminator_cls.__name__} for netD via general search.")

    # Defensive guard in case a selection branch ever yields None.
    if discriminator_cls is None:
        raise ValueError(f"Could not assign a class for Discriminator type: {requested}. Check imports and naming.")

    discriminator = discriminator_cls(opt)
    print_network(discriminator)

    gpu_ids = opt.gpu_ids
    if len(gpu_ids) > 0 and torch.cuda.is_available():
        discriminator.cuda(gpu_ids[0])

    discriminator.init_weights(opt.init_type, opt.init_variance)
    return discriminator
|
|
|
|
|
|
|
|
def define_E(opt):
    """Build the ``ConvEncoder`` used as the VAE encoder.

    The encoder is constructed with ``opt`` and its parameter count is printed.
    When ``opt.gpu_ids`` is non-empty, CUDA must be available (asserted) and the
    encoder is moved to the first listed GPU. Finally
    ``init_weights(opt.init_type, opt.init_variance)`` is called on it.

    Args:
        opt: Parsed options; reads ``gpu_ids``, ``init_type`` and
            ``init_variance`` here (``ConvEncoder`` may read more).

    Returns:
        The initialized encoder.
    """
    print("Attempting to define Encoder (ConvEncoder for VAE)")
    encoder = ConvEncoder(opt)
    print_network(encoder)

    gpu_ids = opt.gpu_ids
    if len(gpu_ids):
        assert (torch.cuda.is_available())
        encoder.cuda(gpu_ids[0])

    encoder.init_weights(opt.init_type, opt.init_variance)
    return encoder
|
|
|
|
|
|
|
|
def print_network(net):
    """Print the class name and parameter count (in millions) of a network.

    If ``net`` is a list, only its first element is reported, preceded by a
    note giving the list length.

    Args:
        net: A module exposing ``parameters()``, or a list of such modules.
    """
    target = net
    if isinstance(net, list):
        target = net[0]
        print(f"Printing network info for the first model in a list of {len(net)} models.")

    total_params = sum(p.numel() for p in target.parameters())
    millions = total_params / 1000000
    print('Network [%s] was created. Total number of parameters: %.1f million. '
          'To see the architecture, do print(network).'
          % (type(target).__name__, millions))
|
|
|
|
|
|
|
|
class Identity(torch.nn.Module):
    """No-op module: ``forward`` hands back exactly the object it receives."""

    def forward(self, x):
        # No copy or transformation; the caller gets the same object back.
        return x
|
|
|
|
|
|
|
|
def get_norm_layer(opt, norm_type='instance'):
    """Return a function that wraps a layer with the normalization named by ``norm_type``.

    ``norm_type`` may start with ``'spectral'``, in which case
    ``torch.nn.utils.spectral_norm`` is applied to the layer first. The
    remainder (the "subnorm") selects the follow-up normalization:

    * ``''`` or ``'none'``: no extra layer; the (possibly spectral) layer is returned.
    * ``'batch'``: affine ``nn.BatchNorm2d``.
    * ``'syncbatch'``: ``SynchronizedBatchNorm2d`` if it can be imported,
      otherwise ``nn.BatchNorm2d`` (a warning is printed).
    * ``'instance'``: non-affine ``nn.InstanceNorm2d``.
    * anything starting with ``'fade'``: a warning is printed and the
      (possibly spectral) layer is returned without extra normalization.

    When a norm layer is appended, the wrapped layer's bias is removed. The norm
    follows immediately, so the bias would be redundant.

    Args:
        opt: Not read by this function; kept for interface compatibility.
        norm_type: Normalization spec, e.g. ``'spectralinstance'``.

    Returns:
        ``add_norm_layer(layer)``. It returns either the (possibly
        spectral-normalized) layer itself or ``nn.Sequential(layer, norm)``.
        ``add_norm_layer`` raises ``ValueError`` for an unrecognized subnorm, or
        when the layer's output channel count cannot be determined. In both
        cases the layer is left unmodified.
    """
    known_subnorms = ('', 'none', 'batch', 'syncbatch', 'instance')

    def get_out_channel(layer):
        # Output width of the wrapped layer: conv-style `out_channels`,
        # linear-style `out_features`, else the leading dim of a >=2-D weight.
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        if hasattr(layer, 'out_features'):
            return getattr(layer, 'out_features')
        if hasattr(layer, 'weight') and layer.weight.dim() > 1:
            return layer.weight.size(0)
        raise ValueError(f"Cannot get out_channel for layer {type(layer)}")

    def add_norm_layer(layer_instance):
        # Split "spectral<subnorm>" into the spectral flag and the subnorm name.
        if norm_type.startswith('spectral'):
            apply_spectral = True
            subnorm_type = norm_type[len('spectral'):]
        else:
            apply_spectral = False
            subnorm_type = norm_type
        is_fade = subnorm_type.startswith('fade')

        # Validate everything before mutating the layer, so that a bad
        # spec leaves the caller's module untouched.
        if subnorm_type not in known_subnorms and not is_fade:
            raise ValueError('normalization layer %s (from %s) is not recognized' % (subnorm_type, norm_type))
        appends_norm = subnorm_type not in ('', 'none') and not is_fade
        out_channels = get_out_channel(layer_instance) if appends_norm else None

        if apply_spectral:
            layer = torch.nn.utils.spectral_norm(layer_instance)
        else:
            layer = layer_instance

        if subnorm_type in ('', 'none'):
            return layer

        if is_fade:
            # No normalization is appended here, so the bias is kept.
            print(
                f"Warning: Norm type '{subnorm_type}' looks like FADE. "
                "This generic get_norm_layer is not for FADE. FADE should be part of specific blocks. "
                "Returning layer without additional FADE-like normalization here.")
            return layer

        # A bias directly followed by normalization has no effect, so drop it.
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer_instance = nn.BatchNorm2d(out_channels, affine=True)
        elif subnorm_type == 'syncbatch':
            try:
                from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
                norm_layer_instance = SynchronizedBatchNorm2d(out_channels, affine=True)
            except ImportError:
                print("Warning: SynchronizedBatchNorm2d not found, falling back to BatchNorm2d for 'syncbatch'.")
                norm_layer_instance = nn.BatchNorm2d(out_channels, affine=True)
        else:  # 'instance' -- the only remaining validated subnorm
            norm_layer_instance = nn.InstanceNorm2d(out_channels, affine=False)

        return nn.Sequential(layer, norm_layer_instance)

    return add_norm_layer