| import torch |
| import torch.nn as nn |
| from torch.nn import init |
| import functools |
| from torch.autograd import Variable |
| import numpy as np |
| import logging |
| from logging_utils import setup_logger |
|
|
| |
# Module-wide logger, configured by the project's logging helper.
logger = setup_logger(__name__)

# FPNInception is optional: if its module (or one of its own imports) cannot be
# imported, record that in INCEPTION_AVAILABLE so get_generator can fall back.
try:
    from models.fpn_inception import FPNInception
except ImportError as e:
    logger.error(f"Error importing FPNInception: {str(e)}")
    INCEPTION_AVAILABLE = False
else:
    INCEPTION_AVAILABLE = True
    logger.info("Successfully imported FPNInception model")
|
|
| |
class FallbackDeblurModel(nn.Module):
    """Small residual encoder/decoder used when FPNInception cannot be built.

    The network predicts a correction that is added to the input, and the sum
    is clamped to [-1, 1]. Presumably inputs are 3-channel images normalized
    to [-1, 1] (TODO confirm against the preprocessing code). Weights use
    PyTorch's default initialization; nothing in this class loads trained
    weights.
    """

    def __init__(self):
        super().__init__()
        logger.info("Initializing fallback model for testing")

        # Two 3x3 convs at full resolution, then a 2x spatial downsample.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

        # 2x spatial upsample back to input resolution, then project to 3
        # channels; Tanh bounds the predicted correction to (-1, 1).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return ``clamp(x + decoder(encoder(x)), -1, 1)``.

        Args:
            x: Image batch whose last two dims are height and width; assumes
                (N, 3, H, W) -- TODO confirm with callers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        height, width = x.shape[-2:]
        # MaxPool2d(2) floors odd sizes and ConvTranspose2d(stride=2) doubles
        # them, so an odd H or W would come back one pixel short and the
        # residual add would fail. Pad to even sizes, then crop back.
        pad_h, pad_w = height % 2, width % 2
        if pad_h or pad_w:
            # Replicate padding also works for size-1 dims (reflect would not).
            x_even = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        else:
            x_even = x
        decoded = self.decoder(self.encoder(x_even))[..., :height, :width]
        return torch.clamp(decoded + x, min=-1, max=1)
|
|
def get_norm_layer(norm_type='instance'):
    """Return a factory for 2D normalization layers.

    Args:
        norm_type: ``'batch'`` for ``BatchNorm2d(affine=True)`` or
            ``'instance'`` (default) for
            ``InstanceNorm2d(affine=False, track_running_stats=True)``.

    Returns:
        A ``functools.partial`` that takes the channel count and builds the layer.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        # NOTE: with track_running_stats=True, InstanceNorm2d keeps running
        # mean/var buffers and uses them in eval mode. Presumably kept for
        # compatibility with existing checkpoints -- do not change lightly.
        return functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=True
        )
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
|
def get_generator(model_config):
    """Build the generator network named by *model_config*.

    Args:
        model_config: Either the generator name as a string, or a mapping
            with a ``'g_name'`` entry holding that name.

    Returns:
        For ``'fpn_inception'``: an ``nn.DataParallel``-wrapped ``FPNInception``
        (built with instance normalization) when it can be imported and
        constructed; otherwise a bare ``FallbackDeblurModel``.

    Raises:
        ValueError: if the generator name is not ``'fpn_inception'``.
        KeyError: if a mapping without a ``'g_name'`` key is passed.
    """
    if isinstance(model_config, str):
        generator_name = model_config
    else:
        generator_name = model_config['g_name']

    if generator_name != 'fpn_inception':
        raise ValueError("Generator Network [%s] not recognized." % generator_name)

    # NOTE(review): the fallback is returned without the DataParallel wrapper
    # used on the FPNInception path, so `.module` access and "module."-prefixed
    # state-dict keys differ between the two -- confirm callers expect this.
    if not INCEPTION_AVAILABLE:
        logger.warning("FPNInception not available, using fallback model")
        return FallbackDeblurModel()

    logger.info("Creating FPNInception model")
    try:
        model_g = FPNInception(norm_layer=get_norm_layer(norm_type='instance'))
        wrapped = nn.DataParallel(model_g)
    except Exception as e:
        # Deliberate best-effort fallback; logger.exception (ERROR level, like
        # before) also records the traceback, which str(e) alone discarded.
        logger.exception("Error creating FPNInception model: %s", e)
        logger.warning("Falling back to simple model for testing")
        return FallbackDeblurModel()
    return wrapped
|
|