P-rateek committed on
Commit
5df9861
·
verified ·
1 Parent(s): 688044a

Upload inference_runner.py

Browse files
Files changed (1) hide show
  1. inference_runner.py +125 -0
inference_runner.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference_runner.py
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from tqdm import tqdm
9
+
10
+ # --- Import model architecture from the repository ---
11
+ from basicsr.archs.mair_arch import MaIR
12
+ from basicsr.utils.img_util import tensor2img
13
+
14
class MaIR_Upsampler:
    """
    Self-contained inference wrapper around the MaIR restoration model.

    Handles configuration lookup, checkpoint loading, pre-/post-processing,
    and tiled inference so that large images can be processed without
    exhausting GPU memory.
    """

    def __init__(self, model_name, device=None):
        """
        Args:
            model_name (str): Key into the supported model configurations
                (see `_get_model_configs`).
            device (torch.device | str | None): Target device; when None,
                CUDA is used if available, else CPU.

        Raises:
            ValueError: If `model_name` is not a recognized configuration.
            FileNotFoundError: If the checkpoint file is missing (raised by
                `_load_model`).
        """
        self.model_name = model_name

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        print(f"Using device: {self.device} for model {self.model_name}")

        self.MODEL_CONFIGS = self._get_model_configs()

        if model_name not in self.MODEL_CONFIGS:
            raise ValueError(f"Model '{model_name}' not recognized. Available: {list(self.MODEL_CONFIGS.keys())}")

        self.model, self.scale = self._load_model()
        self.model.eval()
        self.model.to(self.device)

    def _get_model_configs(self):
        """Returns a dictionary of all supported model configurations."""
        # Shared architecture hyper-parameters for the super-resolution variants.
        mair_sr_base_params = {
            'img_size': 64, 'patch_size': 1, 'in_chans': 3, 'embed_dim': 180,
            'depths': (6, 6, 6, 6, 6, 6), 'drop_rate': 0., 'd_state': 16,
            'ssm_ratio': 2.0, 'mlp_ratio': 2.5, 'drop_path_rate': 0.1,
            'norm_layer': nn.LayerNorm, 'patch_norm': True, 'use_checkpoint': False,
            'img_range': 1., 'upsampler': 'pixelshuffle', 'resi_connection': '1conv',
            'dynamic_ids': True, 'scan_len': 4,
        }
        # Denoising variant: same backbone, but no upsampling head (scale 1).
        mair_cdn_base_params = mair_sr_base_params.copy()
        mair_cdn_base_params.update({'upscale': 1, 'upsampler': ''})

        return {
            'MaIR-SRx4': {'task': 'SR', 'scale': 4, 'filename': 'MaIR_SR_x4.pth', 'params': {**mair_sr_base_params, 'upscale': 4}},
            'MaIR-SRx2': {'task': 'SR', 'scale': 2, 'filename': 'MaIR_SR_x2.pth', 'params': {**mair_sr_base_params, 'upscale': 2}},
            'MaIR-CDN-s50': {'task': 'DN', 'scale': 1, 'filename': 'MaIR_CDN_s50.pth', 'params': mair_cdn_base_params},
        }

    def _load_model(self):
        """Loads the pretrained model weights from the local 'checkpoints' folder.

        Returns:
            tuple[nn.Module, int]: The instantiated model and its upscale factor.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        config = self.MODEL_CONFIGS[self.model_name]
        params = config['params']
        filename = config['filename']
        scale = config['scale']

        model_path = os.path.join('checkpoints', filename)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Checkpoint not found: {model_path}. Ensure it's in a 'checkpoints' folder.")

        model = MaIR(**params)
        # NOTE(review): torch.load without weights_only=True will unpickle
        # arbitrary objects — acceptable for trusted local checkpoints only.
        load_net = torch.load(model_path, map_location=self.device)
        # BasicSR checkpoints usually wrap weights under 'params_ema' (EMA
        # weights, preferred) or 'params'; a bare state dict is used as-is
        # (the original raised KeyError on bare state dicts).
        if 'params_ema' in load_net:
            load_net = load_net['params_ema']
        elif 'params' in load_net:
            load_net = load_net['params']

        # Strip a DataParallel 'module.' prefix if the checkpoint has one.
        for k, v in list(load_net.items()):
            if k.startswith('module.'):
                load_net[k[7:]] = v
                del load_net[k]

        model.load_state_dict(load_net, strict=True)
        print(f"Model {self.model_name} loaded successfully from {model_path}.")
        return model, scale

    def _tile_inference(self, img_tensor):
        """Performs inference using a tiling strategy to handle large images.

        Each 200x200 tile is processed with a 20-px overlap border so seams
        are cut away from the model output before stitching.

        Args:
            img_tensor (torch.Tensor): (1, 3, H, W) float tensor in [0, 1].

        Returns:
            torch.Tensor: (1, 3, H*scale, W*scale) restored image tensor.
        """
        b, c, h, w = img_tensor.size()
        tile_size, tile_pad = 200, 20
        num_tiles_h = int(np.ceil(h / tile_size))
        num_tiles_w = int(np.ceil(w / tile_size))
        pad_h, pad_w = num_tiles_h * tile_size - h, num_tiles_w * tile_size - w
        # BUG FIX: 'reflect' padding requires pad < input size, so the original
        # crashed on images smaller than tile_size; fall back to 'replicate'.
        pad_mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
        img_padded = F.pad(img_tensor, (0, pad_w, 0, pad_h), pad_mode)
        # Allocate the output buffer directly at the upscaled size (the
        # original ran F.interpolate over a zero tensor just to size it).
        output_padded = img_padded.new_zeros(
            (b, c, img_padded.shape[2] * self.scale, img_padded.shape[3] * self.scale))

        with torch.no_grad():
            for i in range(num_tiles_h):
                for j in range(num_tiles_w):
                    # Tile extent in the padded input, plus overlap borders
                    # clamped to the image bounds.
                    h_start, h_end = i * tile_size, (i + 1) * tile_size
                    w_start, w_end = j * tile_size, (j + 1) * tile_size
                    h_start_pad, h_end_pad = max(0, h_start - tile_pad), min(img_padded.shape[2], h_end + tile_pad)
                    w_start_pad, w_end_pad = max(0, w_start - tile_pad), min(img_padded.shape[3], w_end + tile_pad)

                    tile_input = img_padded[:, :, h_start_pad:h_end_pad, w_start_pad:w_end_pad]
                    tile_output = self.model(tile_input)

                    # Destination slot in the output, and the matching crop of
                    # the tile output with the overlap border removed.
                    out_h_start, out_h_end = h_start * self.scale, h_end * self.scale
                    out_w_start, out_w_end = w_start * self.scale, w_end * self.scale
                    cut_h_start = (h_start - h_start_pad) * self.scale
                    cut_h_end = cut_h_start + tile_size * self.scale
                    cut_w_start = (w_start - w_start_pad) * self.scale
                    cut_w_end = cut_w_start + tile_size * self.scale

                    output_padded[:, :, out_h_start:out_h_end, out_w_start:out_w_end] = tile_output[:, :, cut_h_start:cut_h_end, cut_w_start:cut_w_end]

        # Drop the alignment padding to recover the true output size.
        return output_padded[:, :, :h * self.scale, :w * self.scale]

    def process(self, img):
        """Main inference function.

        Args:
            img (np.ndarray): BGR image as produced by cv2.imread
                (assumed uint8 HxWx3 — TODO confirm with callers).

        Returns:
            np.ndarray: Restored BGR uint8 image (upscaled for SR models).
        """
        # Pre-processing: BGR -> RGB, HWC -> CHW, [0, 255] -> [0, 1].
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_tensor = torch.from_numpy(img_rgb.transpose(2, 0, 1)).float() / 255.0
        img_tensor = img_tensor.unsqueeze(0).to(self.device)

        # Inference
        output_tensor = self._tile_inference(img_tensor)

        # Post-processing: back to a BGR uint8 numpy image.
        return tensor2img(output_tensor, rgb2bgr=True)