# sayed99's picture
# initialized both deblurer
# 61d360d
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import numpy as np
import logging
from logging_utils import setup_logger
# Module-level logger used by the definitions below.
logger = setup_logger(__name__)

# FPNInception is an optional dependency: record whether the import worked so
# get_generator can substitute the fallback model when it did not.
try:
    from models.fpn_inception import FPNInception
except ImportError as import_error:
    logger.error(f"Error importing FPNInception: {str(import_error)}")
    INCEPTION_AVAILABLE = False
else:
    INCEPTION_AVAILABLE = True
    logger.info("Successfully imported FPNInception model")
# Simple fallback model for testing purposes
class FallbackDeblurModel(nn.Module):
    """Small residual convolutional encoder/decoder used as a stand-in model.

    The layers are created with their default initialisation; this class does
    not load any weights itself. The encoder halves the spatial resolution
    with a 2x2 max-pool and the decoder restores it with a stride-2
    transposed convolution. The decoder output is added to the input and the
    sum is clamped to ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()
        logger.info("Initializing fallback model for testing")

        # Encoder: two 3x3 convolutions on a 3-channel input, then 2x downsampling.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

        # Decoder: 2x upsampling back to the (even) input size, then a
        # projection to 3 channels squashed into (-1, 1) by tanh.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return ``clamp(decoder(encoder(x)) + x, -1, 1)`` at the size of ``x``.

        Args:
            x: Image tensor with 3 channels; the last two dimensions are
                treated as height and width. Assumed to be batched
                ``(N, 3, H, W)`` -- TODO confirm against callers.

        Returns:
            Tensor with the same shape as ``x`` and values in ``[-1, 1]``.
        """
        height, width = x.shape[-2:]

        # MaxPool2d(2) floors odd sizes, so the stride-2 decoder would come
        # back one pixel short and the residual add below would fail. Pad odd
        # dimensions up to even by repeating the last row/column, then crop.
        pad_h, pad_w = height % 2, width % 2
        padded = x
        if pad_h or pad_w:
            padded = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        decoded = self.decoder(self.encoder(padded))
        # No-op for even sizes; removes the padding row/column otherwise.
        decoded = decoded[..., :height, :width]

        return torch.clamp(decoded + x, min=-1, max=1)
def get_norm_layer(norm_type='instance'):
    """Return a factory for the requested 2-D normalisation layer.

    Args:
        norm_type: ``'batch'`` or ``'instance'``.

    Returns:
        A ``functools.partial`` wrapping ``nn.BatchNorm2d`` (``affine=True``)
        or ``nn.InstanceNorm2d`` (``affine=False``,
        ``track_running_stats=True``); the caller supplies the remaining
        constructor arguments.

    Raises:
        NotImplementedError: If ``norm_type`` is neither of the two names.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)

    if norm_type == 'instance':
        # NOTE(review): track_running_stats=True presumably matches the layout
        # of the checkpoints this is used with -- confirm before changing.
        return functools.partial(
            nn.InstanceNorm2d,
            affine=False,
            track_running_stats=True,
        )

    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def get_generator(model_config):
    """Build the generator network named by ``model_config``.

    Args:
        model_config: Either the generator name as a string, or a mapping
            whose ``'g_name'`` entry holds that name.

    Returns:
        For ``'fpn_inception'``: an ``FPNInception`` wrapped in
        ``nn.DataParallel`` when the import succeeded and construction
        works; otherwise an unwrapped ``FallbackDeblurModel``.

    Raises:
        ValueError: If the generator name is anything other than
            ``'fpn_inception'``.
    """
    generator_name = (
        model_config if isinstance(model_config, str) else model_config['g_name']
    )

    # Only one architecture is supported; reject everything else up front.
    if generator_name != 'fpn_inception':
        raise ValueError("Generator Network [%s] not recognized." % generator_name)

    if not INCEPTION_AVAILABLE:
        logger.warning("FPNInception not available, using fallback model")
        return FallbackDeblurModel()

    # Try to use FPNInception; any construction failure is logged and
    # answered with the fallback model rather than propagated.
    try:
        logger.info("Creating FPNInception model")
        instance_norm = get_norm_layer(norm_type='instance')
        return nn.DataParallel(FPNInception(norm_layer=instance_norm))
    except Exception as exc:
        logger.error(f"Error creating FPNInception model: {str(exc)}")
        logger.warning("Falling back to simple model for testing")
        return FallbackDeblurModel()