File size: 2,868 Bytes
61d360d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import numpy as np
import logging
from logging_utils import setup_logger

# Module-level logger, built by the project's logging helper.
logger = setup_logger(__name__)

# FPNInception is optional. If its module (or anything it imports) is missing,
# remember that in INCEPTION_AVAILABLE so get_generator() can hand back
# FallbackDeblurModel instead of this module failing at import time.
INCEPTION_AVAILABLE = False
try:
    from models.fpn_inception import FPNInception
except ImportError as import_err:
    logger.error("Error importing FPNInception: %s", import_err)
else:
    INCEPTION_AVAILABLE = True
    logger.info("Successfully imported FPNInception model")

# Simple fallback model for testing purposes
class FallbackDeblurModel(nn.Module):
    """Tiny residual conv autoencoder used when FPNInception is unavailable.

    ``get_generator`` returns this model when FPNInception cannot be imported
    or constructed. Nothing in this class loads weights, so the layers keep
    PyTorch's default initialisation. The output is ``x`` plus an untrained
    residual, clamped to ``[-1, 1]``. It exists so the surrounding pipeline
    can run without the real model.
    """

    def __init__(self):
        super().__init__()
        logger.info("Initializing fallback model for testing")
        # 3 -> 64 channels at full resolution, then halve H and W.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

        # Double H and W back, then 64 -> 3 channels. Tanh bounds the
        # residual to (-1, 1) before it is added to the input.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Return ``clamp(x + decoder(encoder(x)), -1, 1)``, shaped like *x*.

        Args:
            x: 3-channel image batch, presumably ``(N, 3, H, W)`` scaled to
                ``[-1, 1]`` (TODO confirm with callers). H and W may be odd.

        Returns:
            Tensor with the same shape as ``x``, values clamped to ``[-1, 1]``.
        """
        height, width = x.shape[-2:]
        # MaxPool2d(2) floors odd sizes and the stride-2 transpose conv only
        # restores even ones. Without this, an odd H or W makes `decoded + x`
        # fail with a shape mismatch. So pad up to even dims (replicate works
        # even for size-1 dims), run the autoencoder, and crop back.
        pad_h, pad_w = height % 2, width % 2
        padded = x
        if pad_h or pad_w:
            padded = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        decoded = self.decoder(self.encoder(padded))[..., :height, :width]
        return torch.clamp(decoded + x, min=-1, max=1)

def get_norm_layer(norm_type='instance'):
    """Return a factory (``functools.partial``) for a 2-D normalization layer.

    Args:
        norm_type: ``'batch'`` for ``nn.BatchNorm2d(affine=True)``, or
            ``'instance'`` (the default) for
            ``nn.InstanceNorm2d(affine=False, track_running_stats=True)``.

    Returns:
        A partial that still needs the channel count, e.g. ``factory(64)``.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

def get_generator(model_config):
    """Build the generator network named by *model_config*.

    Args:
        model_config: Either the generator name itself (a ``str``) or a
            mapping whose ``'g_name'`` entry holds that name. Only
            ``'fpn_inception'`` is recognised.

    Returns:
        An ``nn.DataParallel``-wrapped ``FPNInception`` built with instance
        norm (``get_norm_layer('instance')``) when that model was importable
        and constructs without error. Otherwise it returns an unwrapped,
        untrained ``FallbackDeblurModel``.

        NOTE(review): the two paths return differently-wrapped modules.
        DataParallel prefixes state_dict keys with ``module.``, while the
        fallback does not. Confirm callers that load checkpoints or access
        ``.module`` cope with the fallback.

    Raises:
        KeyError: if *model_config* is a mapping without ``'g_name'``.
        ValueError: if the generator name is not ``'fpn_inception'``.
    """
    if isinstance(model_config, str):
        generator_name = model_config
    else:
        generator_name = model_config['g_name']

    if generator_name != 'fpn_inception':
        raise ValueError("Generator Network [%s] not recognized." % generator_name)

    if not INCEPTION_AVAILABLE:
        logger.warning("FPNInception not available, using fallback model")
        return FallbackDeblurModel()

    try:
        logger.info("Creating FPNInception model")
        model_g = FPNInception(norm_layer=get_norm_layer(norm_type='instance'))
        return nn.DataParallel(model_g)
    except Exception as e:
        # Deliberate best-effort downgrade to the fallback model. Use
        # logger.exception so the traceback of why the real model failed is
        # kept; otherwise the run silently continues with an untrained
        # network and no trace of the cause.
        logger.exception("Error creating FPNInception model: %s", e)
        logger.warning("Falling back to simple model for testing")
        return FallbackDeblurModel()