Upload inference_runner.py
inference_runner.py | ADDED | +125 -0
# inference_runner.py
import cv2
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Import model architecture from the repository ---
from basicsr.archs.mair_arch import MaIR
from basicsr.utils.img_util import tensor2img


class MaIR_Upsampler:
    """
    A self-contained wrapper around the MaIR model for inference.
    Handles model loading, pre-processing, and tiling for large images.
    """

    def __init__(self, model_name, device=None):
        self.model_name = model_name

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        print(f"Using device: {self.device} for model {self.model_name}")

        self.MODEL_CONFIGS = self._get_model_configs()

        if model_name not in self.MODEL_CONFIGS:
            raise ValueError(f"Model '{model_name}' not recognized. "
                             f"Available: {list(self.MODEL_CONFIGS.keys())}")

        self.model, self.scale = self._load_model()
        self.model.eval()
        self.model.to(self.device)
    def _get_model_configs(self):
        """Returns a dictionary of all supported model configurations."""
        mair_sr_base_params = {
            'img_size': 64, 'patch_size': 1, 'in_chans': 3, 'embed_dim': 180,
            'depths': (6, 6, 6, 6, 6, 6), 'drop_rate': 0., 'd_state': 16,
            'ssm_ratio': 2.0, 'mlp_ratio': 2.5, 'drop_path_rate': 0.1,
            'norm_layer': nn.LayerNorm, 'patch_norm': True, 'use_checkpoint': False,
            'img_range': 1., 'upsampler': 'pixelshuffle', 'resi_connection': '1conv',
            'dynamic_ids': True, 'scan_len': 4,
        }
        mair_cdn_base_params = mair_sr_base_params.copy()
        mair_cdn_base_params.update({'upscale': 1, 'upsampler': ''})

        return {
            'MaIR-SRx4': {'task': 'SR', 'scale': 4, 'filename': 'MaIR_SR_x4.pth', 'params': {**mair_sr_base_params, 'upscale': 4}},
            'MaIR-SRx2': {'task': 'SR', 'scale': 2, 'filename': 'MaIR_SR_x2.pth', 'params': {**mair_sr_base_params, 'upscale': 2}},
            'MaIR-CDN-s50': {'task': 'DN', 'scale': 1, 'filename': 'MaIR_CDN_s50.pth', 'params': mair_cdn_base_params},
        }
    def _load_model(self):
        """Loads the pretrained model weights from the local 'checkpoints' folder."""
        config = self.MODEL_CONFIGS[self.model_name]
        params = config['params']
        filename = config['filename']
        scale = config['scale']

        model_path = os.path.join('checkpoints', filename)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Checkpoint not found: {model_path}. Ensure it's in a 'checkpoints' folder.")

        model = MaIR(**params)
        load_net = torch.load(model_path, map_location=self.device)
        param_key = 'params_ema' if 'params_ema' in load_net else 'params'
        load_net = load_net[param_key]

        # Strip the 'module.' prefix left over from DataParallel training, if present.
        for k, v in list(load_net.items()):
            if k.startswith('module.'):
                load_net[k[7:]] = v
                del load_net[k]

        model.load_state_dict(load_net, strict=True)
        print(f"Model {self.model_name} loaded successfully from {model_path}.")
        return model, scale
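    # Note on the checkpoint format (as saved by basicsr training): the .pth file
    # is a dict that wraps the actual state_dict under 'params' (or 'params_ema'
    # for the EMA weights), which is why _load_model() unwraps that key above.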
    def _tile_inference(self, img_tensor):
        """Performs inference using a tiling strategy to handle large images."""
        b, c, h, w = img_tensor.size()
        tile_size, tile_pad = 200, 20
        num_tiles_h = int(np.ceil(h / tile_size))
        num_tiles_w = int(np.ceil(w / tile_size))
        pad_h, pad_w = num_tiles_h * tile_size - h, num_tiles_w * tile_size - w

        # 'reflect' padding requires the pad to be smaller than the input size;
        # fall back to 'replicate' for inputs smaller than the tile-grid remainder.
        pad_mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
        img_padded = F.pad(img_tensor, (0, pad_w, 0, pad_h), pad_mode)
        output_padded = img_padded.new_zeros(b, c, img_padded.shape[2] * self.scale, img_padded.shape[3] * self.scale)

        with torch.no_grad():
            for i in range(num_tiles_h):
                for j in range(num_tiles_w):
                    # Tile bounds in the padded input, plus a context margin of tile_pad pixels.
                    h_start, h_end = i * tile_size, (i + 1) * tile_size
                    w_start, w_end = j * tile_size, (j + 1) * tile_size
                    h_start_pad, h_end_pad = max(0, h_start - tile_pad), min(img_padded.shape[2], h_end + tile_pad)
                    w_start_pad, w_end_pad = max(0, w_start - tile_pad), min(img_padded.shape[3], w_end + tile_pad)

                    tile_input = img_padded[:, :, h_start_pad:h_end_pad, w_start_pad:w_end_pad]
                    tile_output = self.model(tile_input)

                    # Crop the context margin off the upscaled tile and write the
                    # core region into the output canvas.
                    out_h_start, out_h_end = h_start * self.scale, h_end * self.scale
                    out_w_start, out_w_end = w_start * self.scale, w_end * self.scale
                    cut_h_start = (h_start - h_start_pad) * self.scale
                    cut_h_end = cut_h_start + tile_size * self.scale
                    cut_w_start = (w_start - w_start_pad) * self.scale
                    cut_w_end = cut_w_start + tile_size * self.scale

                    output_padded[:, :, out_h_start:out_h_end, out_w_start:out_w_end] = tile_output[:, :, cut_h_start:cut_h_end, cut_w_start:cut_w_end]

        # Trim the padding so the output is exactly the input resolution times the scale.
        return output_padded[:, :, :h * self.scale, :w * self.scale]
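    # Worked example of the tiling arithmetic above (illustrative numbers): a
    # 450x600 input with tile_size=200 yields a ceil(450/200) x ceil(600/200)
    # = 3x3 tile grid, so the input is padded to 600x600. Each 200x200 tile is
    # run with up to tile_pad=20 px of surrounding context, that context is
    # cropped off the model output, and the canvas is finally trimmed back to
    # (450*scale) x (600*scale).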
    def process(self, img):
        """Main inference function: BGR uint8 image in, BGR uint8 image out."""
        # Pre-processing
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_tensor = torch.from_numpy(img_rgb.transpose(2, 0, 1)).float() / 255.0
        img_tensor = img_tensor.unsqueeze(0).to(self.device)

        # Inference
        output_tensor = self._tile_inference(img_tensor)

        # Post-processing
        return tensor2img(output_tensor, rgb2bgr=True)
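
For reference, a minimal driver script might look like the sketch below. It is not part of the upload: the file names input.png and output.png are placeholders, and it assumes the checkpoint for the chosen model name has been placed in the checkpoints folder as described above.

# run_example.py (sketch)
import cv2
from inference_runner import MaIR_Upsampler

upsampler = MaIR_Upsampler('MaIR-SRx4')          # selects CUDA automatically if available
img = cv2.imread('input.png', cv2.IMREAD_COLOR)  # BGR uint8, as process() expects
assert img is not None, 'input.png not found'
result = upsampler.process(img)                  # BGR uint8 at 4x the input resolution
cv2.imwrite('output.png', result)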