File size: 5,622 Bytes
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
# 1. SRCNN
class SRCNN(nn.Module):
    """Three-layer SRCNN operating on single-channel images.

    The network itself does not change spatial resolution (every convolution
    is padded to keep the input size), so the input is bicubically enlarged
    by ``scale`` before the convolutions run.

    Args:
        scale: Bicubic upsampling factor applied in ``forward``. The default
            of 4 reproduces the previous hard-wired behaviour. Pass 1 when
            the caller feeds an image that has already been upscaled.
    """

    def __init__(self, scale=4):
        super().__init__()
        self.scale = scale
        # Attribute names define the state_dict keys; keep them stable so
        # existing checkpoints continue to load.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, 1, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upscale ``x`` by ``self.scale`` and refine it with the three convs.

        The earlier implementation compared the input's spatial shape with
        four times itself, which is true for every non-empty image, so the
        upsampling always ran. The decision is now explicit via ``scale``.
        """
        if self.scale != 1:
            x = F.interpolate(
                x, scale_factor=self.scale, mode='bicubic', align_corners=False
            )
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return self.conv3(x)
# 3. Satlas (Placeholder architecture)
class SatlasSR(nn.Module):
    """Placeholder for a Satlas-based super-resolution model.

    This module owns no layers and therefore no parameters. According to the
    original author's note, the satlaspretrain checkpoints are Swin feature
    backbones rather than super-resolution heads, and wrapping them with
    randomly initialised convolutions produced heavily corrupted output.
    Until a trained SR head exists, the module only performs a x4 bicubic
    upsampling of its input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` enlarged four times with bicubic interpolation."""
        upscaled = F.interpolate(
            x,
            scale_factor=4,
            mode='bicubic',
            align_corners=False,
        )
        return upscaled
# 4. ESRGAN (RRDBNet)
class ResidualDenseBlock(nn.Module):
    """Densely connected block of five 3x3 convolutions with a scaled skip.

    Each of the first four convolutions sees the block input concatenated
    with the outputs of all earlier convolutions and emits ``num_grow_ch``
    channels. The fifth maps everything back to ``num_feat`` channels.

    Args:
        num_feat: Channel count of the block input and output.
        num_grow_ch: Channels added by each intermediate convolution.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        # Attribute names (conv1..conv5) define the state_dict keys.
        self.conv1 = nn.Conv2d(
            num_feat, num_grow_ch, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            num_feat + num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1
        )
        self.conv3 = nn.Conv2d(
            num_feat + 2 * num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1
        )
        self.conv4 = nn.Conv2d(
            num_feat + 3 * num_grow_ch, num_grow_ch, kernel_size=3, stride=1, padding=1
        )
        self.conv5 = nn.Conv2d(
            num_feat + 4 * num_grow_ch, num_feat, kernel_size=3, stride=1, padding=1
        )
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the dense-branch output."""
        # ``dense`` accumulates the input and every intermediate activation
        # along the channel axis, so each conv sees all earlier features.
        dense = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            grown = self.lrelu(conv(dense))
            dense = torch.cat((dense, grown), 1)
        fused = self.conv5(dense)
        return fused * 0.2 + x
class RRDB(nn.Module):
    """Residual-in-residual block: three dense blocks inside a scaled skip.

    Args:
        num_feat: Channel count of the block input and output.
        num_grow_ch: Growth channels forwarded to each ResidualDenseBlock.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        # Attribute names (rdb1..rdb3) define the state_dict keys.
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        """Chain the three dense blocks and add 0.2 times the result to ``x``."""
        branch = self.rdb3(self.rdb2(self.rdb1(x)))
        return branch * 0.2 + x
class RRDBNet(nn.Module):
    """ESRGAN-style generator: RRDB trunk followed by two x2 upsampling stages.

    The hyper-parameters were previously fixed inside ``__init__``. They are
    now constructor arguments whose defaults reproduce the former values, so
    ``RRDBNet()`` builds exactly the same network as before.

    Args:
        num_in_ch: Channels of the input image.
        num_out_ch: Channels of the output image.
        num_feat: Width of the feature maps throughout the network.
        num_block: Number of RRDB blocks in the trunk.
        num_grow_ch: Growth channels inside every residual dense block.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
                 num_grow_ch=32):
        super().__init__()
        # Attribute names define the state_dict keys and must stay stable
        # for existing checkpoints to load.
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = nn.Sequential(
            *[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch)
              for _ in range(num_block)]
        )
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``x`` super-resolved by a factor of four."""
        feat = self.conv_first(x)
        # Long skip connection around the whole RRDB trunk.
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # Two nearest-neighbour x2 enlargements, each followed by a conv.
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
def load_model(model_name, model_path, device):
    """Build the named architecture and load its checkpoint from disk.

    Args:
        model_name: One of ``"srcnn"``, ``"satlas"`` or ``"esrgan"``.
        model_path: Path of the checkpoint file to read.
        device: Passed to ``torch.load`` as ``map_location`` and to
            ``model.to``.

    Returns:
        The model in eval mode on ``device``, or ``None`` when the path does
        not exist, the name is unknown, or anything raises while loading.
    """
    if not os.path.exists(model_path):
        return None

    if model_name == "srcnn":
        model = SRCNN()
    elif model_name == "satlas":
        model = SatlasSR()
    elif model_name == "esrgan":
        model = RRDBNet()
    else:
        return None

    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load trusted checkpoints, or pass
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(model_path, map_location=device)

        # Some checkpoints wrap the weights under 'params_ema' or 'params';
        # 'params_ema' takes precedence when both are present.
        for wrapper_key in ('params_ema', 'params'):
            if wrapper_key in state_dict:
                state_dict = state_dict[wrapper_key]
                break

        # strict=False tolerates key-name mismatches between the checkpoint
        # and these architectures. The returned report is inspected so that
        # a checkpoint that matches nothing is not accepted silently, which
        # would leave the model (partly) randomly initialised.
        report = model.load_state_dict(state_dict, strict=False)
        if report.missing_keys or report.unexpected_keys:
            print(
                f"Warning: checkpoint for {model_name} does not fully match the "
                f"architecture ({len(report.missing_keys)} missing, "
                f"{len(report.unexpected_keys)} unexpected keys)."
            )

        model.eval()
        model.to(device)
        return model
    except Exception as e:
        # Best-effort loader: report the failure and let the caller skip it.
        print(f"Error loading {model_name}: {e}")
        return None
def get_available_models(model_dir="models", device="cpu"):
    """Load every known model whose checkpoint file exists in ``model_dir``.

    Args:
        model_dir: Directory searched for the checkpoint files.
        device: Forwarded to ``load_model``.

    Returns:
        Dict mapping model name to the loaded model. Names whose file is
        missing, or for which ``load_model`` returns ``None``, are omitted.
    """
    # (model name, expected checkpoint filename), in lookup order.
    checkpoint_files = (
        ("srcnn", "srcnn_x4.pth"),
        ("satlas", "aerial_swinb_si.pth"),
        ("esrgan", "RealESRGAN_x4plus.pth"),
    )

    loaded = {}
    for name, filename in checkpoint_files:
        path = os.path.join(model_dir, filename)
        if not os.path.exists(path):
            print(f"Model file for {name} not found at {path}. Skipping.")
            continue

        print(f"Loading {name}...")
        model = load_model(name, path, device)
        if model is not None:
            loaded[name] = model
    return loaded
|