File size: 12,091 Bytes
0f52c9d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 | import torch
from models.networks.base_network import BaseNetwork
from models.networks.loss import *
from models.networks.discriminator import *
from models.networks.encoder import ConvEncoder
from models.networks.generator import TSITGenerator, Pix2PixHDGenerator
from models.networks.rafael_generator import RafaelGenerator
import torch.nn as nn
from .style_encoder import SimpleStyleEncoder
# Optional dependency: the Swin-based conditional discriminator is only made
# available when its module imports cleanly. SWIN_COND_D_IMPORTED records the
# outcome; SwinTransformerConditionalDiscriminator stays None on failure.
SWIN_COND_D_IMPORTED = False
SwinTransformerConditionalDiscriminator = None
try:
    from models.networks.swin_discriminator import SwinTransformerConditionalDiscriminator
except ImportError as import_err:
    print(f"Warning: Could not import SwinTransformerConditionalDiscriminator: {import_err}. Swin D will not be available.")
except Exception as unexpected_err:
    # Any other failure while executing the module is reported, not raised.
    print(f"Warning: An unexpected error occurred while importing SwinTransformerConditionalDiscriminator: {unexpected_err}")
else:
    SWIN_COND_D_IMPORTED = True
    print("Successfully imported SwinTransformerConditionalDiscriminator.")
def find_class_in_module(target_cls_name, module_globals):
    """Look up a class in *module_globals* by a loosely matched name.

    Both the requested name and each key of *module_globals* are lower-cased
    and stripped of underscores before comparison. The first entry, in the
    mapping's iteration order, whose value is a class and whose normalised
    key matches is returned.

    Args:
        target_cls_name: requested class name, e.g. "rafael_generator".
        module_globals: mapping of names to objects, typically ``globals()``.

    Returns:
        The matching class object.

    Raises:
        ValueError: no class matches. Diagnostic lines listing the available
            class names are printed first.
    """
    def _squash(text):
        # Lower-case and drop underscores so "Foo_Bar" and "foobar" compare equal.
        return text.lower().replace('_', '')

    wanted = _squash(target_cls_name)
    found = next(
        (obj for key, obj in module_globals.items()
         if isinstance(obj, type) and _squash(key) == wanted),
        None,
    )
    if found is not None:
        return found

    class_names = [key for key, obj in module_globals.items() if isinstance(obj, type)]
    print(
        f"Error: Class '{wanted}' (derived from '{target_cls_name}') not found in module_globals.")
    print(f"Available classes (case-insensitive, no underscore): {[_squash(name) for name in class_names]}")
    print(f"Original available classes: {class_names}")
    raise ValueError(
        f"In current module, there should be a class named '{target_cls_name}' (comparison is case-insensitive and ignores underscores)."
    )
def modify_commandline_options(parser, is_train):
    """Let the selected generator/discriminator/encoder classes add their options.

    Parses the options known so far with ``parser.parse_known_args()`` to read
    ``netG``, ``netD`` and ``use_vae``. For each matching class that defines a
    ``modify_commandline_options`` hook, it then forwards *parser* to that hook.
    Class names are matched the way ``find_class_in_module`` matches them:
    case-insensitively, with underscores ignored.

    Args:
        parser: the argument parser being built.
        is_train: discriminator options are only added when this is truthy.

    Returns:
        The parser returned by the last hook that was called, or *parser*
        itself if no hook ran.
    """
    def _normalize(name):
        # Same normalisation rule find_class_in_module applies.
        return name.lower().replace('_', '')

    opt, _ = parser.parse_known_args()

    # Generator options
    if getattr(opt, 'netG', None):
        # Only the lookup is guarded: a ValueError raised by the class's own
        # hook is a real error and must not be reported as "class not found".
        try:
            netG_cls = find_class_in_module(opt.netG, globals())
        except ValueError as e:
            print(f"Warning: Could not find class for netG '{opt.netG}' to call modify_commandline_options. {e}")
        else:
            if hasattr(netG_cls, 'modify_commandline_options'):
                parser = netG_cls.modify_commandline_options(parser, is_train)
                print(f"DEBUG: Called modify_commandline_options for netG: {opt.netG}")

    # Discriminator options
    if is_train and getattr(opt, 'netD', None):
        netD_cls_to_modify = None
        netD_normalized = _normalize(opt.netD)
        looks_like_swin = netD_normalized.startswith('swin')
        if looks_like_swin and SWIN_COND_D_IMPORTED:
            # Any "swin*" name (including the full class name) selects the Swin D;
            # the log distinguishes an exact-name match from a prefix match.
            netD_cls_to_modify = SwinTransformerConditionalDiscriminator
            suffix = "" if netD_normalized == 'swintransformerconditionaldiscriminator' else " by prefix"
            print(f"DEBUG: networks.modify_commandline_options trying SwinTransformerConditionalDiscriminator for netD '{opt.netD}'{suffix}.")
        else:
            try:
                cls_cand = find_class_in_module(opt.netD, globals())
            except ValueError:
                cls_cand = None  # reported by the warnings below
            if cls_cand is not None and issubclass(cls_cand, BaseNetwork):
                netD_cls_to_modify = cls_cand
                print(f"DEBUG: networks.modify_commandline_options trying {cls_cand.__name__} for netD '{opt.netD}'.")

        if netD_cls_to_modify is not None:
            if hasattr(netD_cls_to_modify, 'modify_commandline_options'):
                parser = netD_cls_to_modify.modify_commandline_options(parser, is_train)
                print(f"DEBUG: Called modify_commandline_options for netD: {netD_cls_to_modify.__name__}")
        elif looks_like_swin:
            print(f"Warning: netD '{opt.netD}' looks like Swin, but class/modify_commandline_options not found/called.")
        else:
            print(f"Warning: Could not find a specific class for netD '{opt.netD}' to call modify_commandline_options. Standard D options might not be fully parsed if this D type has specific ones.")

    # VAE Encoder options (original TSIT)
    if getattr(opt, 'use_vae', False):
        netG_name = getattr(opt, 'netG', None)
        # VAE is not used with RafaelGenerator in the current setup. Compare
        # normalised names so every spelling that find_class_in_module resolves
        # to RafaelGenerator (e.g. "rafael_generator") is excluded too.
        if not netG_name or _normalize(netG_name) != 'rafaelgenerator':
            if hasattr(ConvEncoder, 'modify_commandline_options'):
                parser = ConvEncoder.modify_commandline_options(parser, is_train)
                print("DEBUG: Called modify_commandline_options for ConvEncoder (netE).")
    return parser
def define_G(opt):
    """Build, report, place and initialise the generator named by ``opt.netG``.

    ``opt.netG`` is the generator's class name (e.g. "RafaelGenerator"). It is
    resolved with ``find_class_in_module`` against this module's globals, so
    matching is case-insensitive and ignores underscores.

    Args:
        opt: options namespace; reads ``netG``, ``gpu_ids``, ``init_type`` and
            ``init_variance``.

    Returns:
        The generator instance. When ``opt.gpu_ids`` is non-empty, it has been
        moved to ``opt.gpu_ids[0]``.

    Raises:
        ValueError: ``opt.netG`` does not name a class in this module.
        AssertionError: GPUs were requested but CUDA is not available.
    """
    print(f"Attempting to define Generator: {opt.netG}")
    netG_cls = find_class_in_module(opt.netG, globals())
    netG = netG_cls(opt)
    print_network(netG)
    if len(opt.gpu_ids) > 0:
        # Explicit raise instead of a bare `assert`, so the check survives
        # `python -O`. AssertionError is kept for compatibility with the
        # original assert.
        if not torch.cuda.is_available():
            raise AssertionError(
                f"gpu_ids={opt.gpu_ids} requested for the Generator but CUDA is not available.")
        netG.cuda(opt.gpu_ids[0])
    netG.init_weights(opt.init_type, opt.init_variance)
    return netG
def define_D(opt):
    """Build, report, place and initialise the discriminator named by ``opt.netD``.

    Selection rule:
      * If ``opt.netD`` (lower-cased, underscores removed) starts with "swin"
        and the optional Swin module was imported, use
        ``SwinTransformerConditionalDiscriminator``.
      * Otherwise resolve ``opt.netD`` by name among this module's globals via
        ``find_class_in_module``. This also covers classic types such as
        MultiscaleDiscriminator and NLayerDiscriminator.

    Args:
        opt: options namespace; reads ``netD``, ``gpu_ids``, ``init_type`` and
            ``init_variance``.

    Returns:
        The discriminator instance. It is moved to ``opt.gpu_ids[0]`` when
        GPUs are requested and CUDA is available. If CUDA is unavailable, a
        warning is printed and the model stays on CPU.

    Raises:
        ValueError: ``opt.netD`` does not resolve to a known class.
    """
    netD_name_lower_no_underscore = opt.netD.lower().replace('_', '')
    print(f"Attempting to define Discriminator for opt.netD = '{opt.netD}'")
    if netD_name_lower_no_underscore.startswith("swin") and SWIN_COND_D_IMPORTED:
        how = ("by full name" if netD_name_lower_no_underscore == 'swintransformerconditionaldiscriminator'
               else "by prefix")
        print(f"Selected SwinTransformerConditionalDiscriminator for netD {how}: '{opt.netD}'.")
        netD_cls = SwinTransformerConditionalDiscriminator
    else:
        print(f"Attempting to find class for netD '{opt.netD}' by general search in globals()...")
        try:
            netD_cls = find_class_in_module(opt.netD, globals())
        except ValueError as e:
            raise ValueError(
                f"Unknown or unimported Discriminator type specified by --netD: {opt.netD}. Error: {e}") from e
        print(f"Found and selected {netD_cls.__name__} for netD via general search.")

    netD = netD_cls(opt)
    print_network(netD)
    if len(opt.gpu_ids) > 0:
        if torch.cuda.is_available():
            netD.cuda(opt.gpu_ids[0])
        else:
            # Keep the original best-effort CPU fallback, but make it visible:
            # a CPU discriminator next to a GPU generator otherwise fails later
            # with a far less obvious device-mismatch error.
            print(f"Warning: gpu_ids={opt.gpu_ids} requested but CUDA is not available; "
                  "the Discriminator stays on CPU.")
    netD.init_weights(opt.init_type, opt.init_variance)
    return netD
def define_E(opt):
    """Build, report, place and initialise the ``ConvEncoder`` used for the VAE.

    Args:
        opt: options namespace; passed to ``ConvEncoder`` and read for
            ``gpu_ids``, ``init_type`` and ``init_variance``.

    Returns:
        The encoder instance. When ``opt.gpu_ids`` is non-empty, it has been
        moved to ``opt.gpu_ids[0]``.

    Raises:
        AssertionError: GPUs were requested but CUDA is not available.
    """
    print("Attempting to define Encoder (ConvEncoder for VAE)")
    netE = ConvEncoder(opt)
    print_network(netE)
    if len(opt.gpu_ids) > 0:
        # Explicit raise instead of a bare `assert`, so the check survives
        # `python -O`. AssertionError is kept for compatibility.
        if not torch.cuda.is_available():
            raise AssertionError(
                f"gpu_ids={opt.gpu_ids} requested for the Encoder but CUDA is not available.")
        netE.cuda(opt.gpu_ids[0])
    netE.init_weights(opt.init_type, opt.init_variance)
    return netE
def print_network(net):
    """Print the class name and total parameter count (in millions) of *net*.

    If *net* is a list, only its first element is reported, preceded by a
    line stating how many models the list holds.
    """
    given_list = isinstance(net, list)
    model = net[0] if given_list else net
    if given_list:
        print(f"Printing network info for the first model in a list of {len(net)} models.")
    total_params = sum(p.numel() for p in model.parameters())
    print('Network [%s] was created. Total number of parameters: %.1f million. '
          'To see the architecture, do print(network).'
          % (type(model).__name__, total_params / 1000000))
class Identity(torch.nn.Module):
    """Pass-through module: ``forward`` returns the very object it receives."""

    def forward(self, x):
        passthrough = x
        return passthrough
def get_norm_layer(opt, norm_type='instance'):
    """Return ``add_norm_layer``, a function that attaches normalization to a layer.

    *norm_type* is a string such as ``'instance'``, ``'batch'``,
    ``'syncbatch'`` or ``'none'``, optionally prefixed with ``'spectral'``
    (e.g. ``'spectralinstance'``).
      * The ``'spectral'`` prefix wraps the layer with
        ``torch.nn.utils.spectral_norm``.
      * The remainder (the "sub-norm") selects the normalization appended
        after the layer.
      * ``'none'`` or an empty sub-norm appends nothing.
      * A sub-norm starting with ``'fade'`` also appends nothing and prints a
        warning, because FADE belongs inside specific blocks.

    Args:
        opt: not used by this function; kept for interface compatibility.
        norm_type: normalization spec as described above.

    Returns:
        ``add_norm_layer(layer_instance)``. It returns either the (possibly
        spectral-normed) layer with its bias intact, or
        ``nn.Sequential(layer, norm)`` with the layer's bias removed.

    ``add_norm_layer`` raises ValueError for an unrecognized sub-norm, before
    touching the layer, or when the layer's output channel count cannot be
    determined.
    """
    known_subnorms = ('batch', 'syncbatch', 'instance')

    def get_out_channel(layer):
        """Return the number of output channels/features of *layer*."""
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        if hasattr(layer, 'out_features'):  # For Linear layers
            return getattr(layer, 'out_features')
        if hasattr(layer, 'weight') and layer.weight.dim() > 1:  # Conv, Linear
            return layer.weight.size(0)
        raise ValueError(f"Cannot get out_channel for layer {type(layer)}")

    def add_norm_layer(layer_instance):
        """Apply the configured spectral norm / normalization to *layer_instance*."""
        use_spectral = norm_type.startswith('spectral')
        subnorm_type = norm_type[len('spectral'):] if use_spectral else norm_type
        no_norm = subnorm_type == 'none' or len(subnorm_type) == 0
        is_fade = subnorm_type.startswith('fade')

        # Validate before mutating anything, so a bad spec leaves the layer
        # untouched (no spectral norm applied, bias still present).
        if not (no_norm or is_fade or subnorm_type in known_subnorms):
            raise ValueError('normalization layer %s (from %s) is not recognized' % (subnorm_type, norm_type))

        # spectral_norm registers its hook on the module and returns that same module.
        layer = torch.nn.utils.spectral_norm(layer_instance) if use_spectral else layer_instance

        if no_norm:
            return layer  # Only spectral norm (if any) or just the layer
        if is_fade:
            print(
                f"Warning: Norm type '{subnorm_type}' looks like FADE. "
                "This generic get_norm_layer is not for FADE. FADE should be part of specific blocks. "
                "Returning layer without additional FADE-like normalization here.")
            # No normalization follows here, so the bias is kept (it is only
            # removed below, where a norm layer cancels it).
            return layer

        out_channels = get_out_channel(layer_instance)
        # A per-channel bias is cancelled by the following norm's mean
        # subtraction, so drop it rather than keep a dead parameter.
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer_instance = nn.BatchNorm2d(out_channels, affine=True)
        elif subnorm_type == 'syncbatch':
            try:
                from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
                norm_layer_instance = SynchronizedBatchNorm2d(out_channels, affine=True)
            except ImportError:
                print("Warning: SynchronizedBatchNorm2d not found, falling back to BatchNorm2d for 'syncbatch'.")
                norm_layer_instance = nn.BatchNorm2d(out_channels, affine=True)
        else:  # 'instance' (the only remaining validated option)
            norm_layer_instance = nn.InstanceNorm2d(out_channels, affine=False)  # Original TSIT used affine=False
        return nn.Sequential(layer, norm_layer_instance)

    return add_norm_layer