File size: 7,831 Bytes
fd5c0a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import torch
import torch.nn as nn
import math
import os

# --- Generator building blocks: ResidualBlock, Upsampler, and the Generator itself ---
class ResidualBlock(nn.Module):
    """
    EDSR-style residual block.

    Computes ``x + res_scale * body(x)``, where ``body`` is
    conv -> [BatchNorm2d] -> act -> conv -> [BatchNorm2d].

    Args:
        num_features (int): Channel count of both the input and the output.
        kernel_size (int): Convolution kernel size; padding is ``kernel_size // 2``,
            which preserves spatial size for odd kernel sizes.
        bn (bool): If truthy, a BatchNorm2d follows each convolution.
        act (nn.Module): Activation placed between the two convolutions.
            NOTE: the default ReLU is created once, when the class is defined,
            so every block relying on the default shares that same instance.
        res_scale (float): Multiplier applied to the body output before the skip add.
    """
    def __init__(self, num_features, kernel_size=3, bn=False, act=nn.ReLU(True), res_scale=1.0):
        super().__init__()
        pad = kernel_size // 2

        def conv_stage():
            # One convolution, optionally followed by batch norm.
            stage = [nn.Conv2d(num_features, num_features, kernel_size, padding=pad)]
            if bn:
                stage.append(nn.BatchNorm2d(num_features))
            return stage

        # Arguments are evaluated left to right, so layers are created (and their
        # weights randomly initialised) in the order conv, [bn], act, conv, [bn].
        self.body = nn.Sequential(*conv_stage(), act, *conv_stage())
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``x + res_scale * body(x)``."""
        scaled = self.body(x) * self.res_scale
        return scaled.add_(x)

class Upsampler(nn.Module):
    """
    Sub-pixel upsampling stage.

    A 3x3 conv expands ``num_features`` to ``num_features * scale_factor**2`` channels.
    ``nn.PixelShuffle(scale_factor)`` then rearranges those channels into an image
    ``scale_factor`` times larger in H and W, with ``num_features`` channels.

    Args:
        scale_factor (int): Spatial upscaling factor.
        num_features (int): Channels entering and leaving the stage.
        act (nn.Module | None): Optional activation appended after the shuffle.
            Any falsy value (e.g. None or False) omits it. The default ReLU is
            created once, when the class is defined, and is shared by every
            instance relying on it.
    """
    def __init__(self, scale_factor, num_features, act=nn.ReLU(True)):
        super().__init__()
        expanded_channels = num_features * scale_factor ** 2
        stages = [
            nn.Conv2d(num_features, expanded_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor),
        ]
        # A truthiness test, not `is not None`: any falsy `act` means "no activation".
        stages += [act] if act else []
        self.body = nn.Sequential(*stages)

    def forward(self, x):
        """Upsample ``x`` by ``scale_factor`` in both spatial dimensions."""
        upsampled = self.body(x)
        return upsampled

class Generator(nn.Module):
    """
    EDSR-style super-resolution generator.

    Pipeline: head conv -> residual trunk (``num_res_blocks`` ResidualBlocks followed
    by a conv, wrapped in a long skip connection) -> upsampling tail -> final conv
    to ``out_channels``.

    Args:
        scale_factor (int): Upscaling factor. Powers of two are built as a chain of
            x2 Upsamplers; 3 uses a single x3 Upsampler. scale_factor=1 yields an
            empty tail (no upsampling).
        in_channels (int): Channels of the low-resolution input.
        out_channels (int): Channels of the super-resolved output.
        num_features (int): Width of the internal feature maps.
        num_res_blocks (int): Number of ResidualBlocks in the trunk.
        res_scale (float): Residual scaling passed to every ResidualBlock.

    Raises:
        NotImplementedError: if scale_factor is neither a power of two nor 3.
        NOTE(review): scale_factor=0 passes the power-of-two bit test and then
        fails inside ``math.log2`` with a ValueError.
    """
    def __init__(self, scale_factor=4, in_channels=3, out_channels=3, num_features=64, num_res_blocks=16, res_scale=1.0):
        super().__init__()
        self.scale_factor = scale_factor
        # A single activation instance shared by every residual block.
        shared_act = nn.ReLU(True)

        self.head = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)

        trunk = [
            ResidualBlock(num_features, kernel_size=3, act=shared_act, res_scale=res_scale)
            for _ in range(num_res_blocks)
        ]
        trunk.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
        self.body = nn.Sequential(*trunk)

        # n & (n - 1) clears the lowest set bit; zero means at most one bit was set.
        is_power_of_two = (scale_factor & (scale_factor - 1)) == 0
        if is_power_of_two:
            stages = [
                Upsampler(scale_factor=2, num_features=num_features, act=None)
                for _ in range(int(math.log2(scale_factor)))
            ]
        elif scale_factor == 3:
            stages = [Upsampler(scale_factor=3, num_features=num_features, act=None)]
        else:
            raise NotImplementedError(f"Scale factor {scale_factor} not directly supported by this simple upsampler.")
        self.tail = nn.Sequential(*stages)
        self.final_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, padding=1)

    def forward(self, lr_img):
        """Map a low-resolution image batch to a ``scale_factor``-times larger output."""
        shallow = self.head(lr_img)
        deep = self.body(shallow)
        deep += shallow  # long skip around the whole residual trunk
        return self.final_conv(self.tail(deep))

# --- Discriminator ---
class Discriminator(nn.Module):
    """
    Simple CNN discriminator that scores an image with a single logit.

    A PatchGAN emits a grid of per-patch scores; this network instead
    global-average-pools its final feature map and returns one logit per image.
    Because of the adaptive pooling, the classifier head does not depend on the
    input's spatial size.

    Architecture:
        conv(in_channels -> num_features_start) + LeakyReLU(0.2), then
        num_blocks x [conv 3x3 + BatchNorm2d + LeakyReLU(0.2)], where odd-indexed
        blocks use stride 2 and double the channel count, then
        AdaptiveAvgPool2d(1) -> Linear -> LeakyReLU(0.2) -> Linear(-> 1).

    Args:
        in_channels (int): Channels of the input image.
        num_features_start (int): Channels produced by the first convolution.
        num_blocks (int): Number of conv/BN/LeakyReLU blocks after the first conv.
        hidden_features (int): Width of the classifier's hidden layer (default 100).
    """
    def __init__(self, in_channels=3, num_features_start=64, num_blocks=4, hidden_features=100):
        super(Discriminator, self).__init__()

        # Initial block (no BatchNorm on the first layer).
        layers = [
            nn.Conv2d(in_channels, num_features_start, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        current_features = num_features_start
        for i in range(num_blocks):
            stride = 1 if i % 2 == 0 else 2  # downsample on every other (odd-indexed) block
            next_features = current_features * 2 if stride == 2 else current_features
            layers.extend([
                nn.Conv2d(current_features, next_features, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(next_features),
                nn.LeakyReLU(0.2, inplace=True)
            ])
            current_features = next_features

        self.features = nn.Sequential(*layers)

        # Pooling to 1x1 means the Linear input size is just the final channel count,
        # regardless of input H/W.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(current_features, hidden_features),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_features, 1)  # raw logit; no sigmoid here
        )

    def forward(self, img):
        """
        Args:
            img (torch.Tensor): Input image tensor (B, C, H, W), either real HR or fake SR.
        Returns:
            torch.Tensor: Output logits (B, 1). Higher values -> more likely "real".
        """
        features = self.features(img)
        pooled = self.avgpool(features)
        # flatten(…, 1) instead of view(batch_size, -1): view cannot infer the -1
        # dimension when the batch is empty (0 elements) and raises.
        pooled = torch.flatten(pooled, 1)
        return self.classifier(pooled)

# --- Main block for testing and saving ---
if __name__ == '__main__':
    # Smoke test: build both networks, run one forward pass each on random input,
    # and check the output shapes.

    # --- Generator Test ---
    SCALE = 4
    GEN_FEATURES = 64
    GEN_RES_BLOCKS = 8
    save_dir = "saved_models"
    # NOTE(review): save_dir is created but nothing in this block writes to it; the
    # model-saving code appears to have been elided. TODO: restore it or drop save_dir.
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dummy LR input for Generator
    gen_batch_size = 1
    lr_height = 32
    lr_width = 32
    in_channels = 3
    dummy_lr = torch.randn(gen_batch_size, in_channels, lr_height, lr_width).to(device)
    print(f"Dummy LR input shape (Generator): {dummy_lr.shape}")

    generator = Generator(scale_factor=SCALE, in_channels=in_channels, out_channels=in_channels,
                          num_features=GEN_FEATURES, num_res_blocks=GEN_RES_BLOCKS).to(device)
    generator.eval()
    with torch.no_grad():
        output_sr = generator(dummy_lr)
    print(f"Output SR shape (Generator): {output_sr.shape}")

    # Verify the generator upscaled H and W by SCALE before reporting success.
    expected_gen_shape = (gen_batch_size, in_channels, lr_height * SCALE, lr_width * SCALE)
    assert output_sr.shape == expected_gen_shape, \
        f"Generator output shape mismatch! Expected {expected_gen_shape}, got {output_sr.shape}"
    print("\nGenerator definition test successful!")
    num_params_gen = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Generator - Number of trainable parameters: {num_params_gen:,}")

    print("\n--- Testing Discriminator ---")
    # --- Discriminator Test ---
    DISC_FEATURES = 64  # Starting features for discriminator
    DISC_BLOCKS = 3     # Number of conv blocks in discriminator

    # Dummy HR/SR input for Discriminator (must match Generator's output size)
    disc_batch_size = 4  # Can be different from generator test batch size
    hr_height = output_sr.shape[2]  # Use the calculated HR height
    hr_width = output_sr.shape[3]   # Use the calculated HR width
    dummy_hr = torch.randn(disc_batch_size, in_channels, hr_height, hr_width).to(device)
    print(f"Dummy HR/SR input shape (Discriminator): {dummy_hr.shape}")

    # Instantiate the Discriminator
    discriminator = Discriminator(in_channels=in_channels,
                                  num_features_start=DISC_FEATURES,
                                  num_blocks=DISC_BLOCKS).to(device)
    discriminator.eval()  # Set to evaluation mode for testing

    # Perform a forward pass
    with torch.no_grad():
        output_logits = discriminator(dummy_hr)

    print(f"Output Logits shape (Discriminator): {output_logits.shape}")

    # Verify output shape
    expected_disc_shape = (disc_batch_size, 1)
    assert output_logits.shape == expected_disc_shape, \
        f"Discriminator output shape mismatch! Expected {expected_disc_shape}, got {output_logits.shape}"

    print("Discriminator definition test successful!")

    # Optional: Count parameters
    num_params_disc = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Discriminator - Number of trainable parameters: {num_params_disc:,}")