# NOTE(review): the three lines below were web-scrape residue (a model-hub
# commit header: author avatar caption, commit message, commit hash) and are
# not valid Python; kept here as comments so the file parses.
# LYL1015's picture
# test
# eea83e8
import torch
from torchvision.utils import save_image
import torchvision.transforms as tfs
from PIL import Image
import yaml
import os
from tqdm import tqdm
from .UDR_S2Former import Transformer
def load_s2former_model(model_path, device):
    """
    Load S2Former model weights.

    Args:
        model_path (str): Path to the weights file itself, or to a directory
            containing 'udrs2former_demo.pth'.
        device (torch.device): Device to load the model on.

    Returns:
        torch.nn.Module: Loaded S2Former model, set to eval mode.
    """
    # Accept either a direct file path or its containing directory.
    # The original unconditionally joined the filename onto model_path,
    # which broke callers (e.g. the __main__ example) that already pass
    # the full file path.
    if os.path.isfile(model_path):
        path = model_path
    else:
        path = os.path.join(model_path, 'udrs2former_demo.pth')
    # Model input size is fixed at 320x320, matching the tiling used by
    # s2former_predict.
    model = Transformer((320, 320)).to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()  # inference mode: disables dropout, freezes norm statistics
    return model
def s2former_predict(model, input_img, output_dir, device):
    """
    Run S2Former inference on a single image using overlapping tiles.

    Args:
        model (torch.nn.Module): Loaded S2Former model.
        input_img (str or PIL.Image.Image): Input image path or PIL Image.
        output_dir (str): Directory to save the output image.
        device (torch.device): Device to run inference on.

    Returns:
        str: Path of the saved output PNG file.
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)
    # Derive an output file stem. The original called os.path.basename on
    # input_img BEFORE the isinstance check, which raised a TypeError for
    # the documented PIL.Image input; handle both cases here instead.
    if isinstance(input_img, str):
        img_name, _ = os.path.splitext(os.path.basename(input_img))
        input_img = Image.open(input_img).convert('RGB')
    else:
        # PIL images loaded from disk carry a .filename; fall back otherwise.
        source_name = getattr(input_img, 'filename', '') or 'output'
        img_name, _ = os.path.splitext(os.path.basename(source_name))
    # Convert to a (1, C, H, W) float tensor in [0, 1].
    img_tensor = tfs.ToTensor()(input_img).unsqueeze(0).to(device)
    b, c, h, w = img_tensor.shape
    # Tiled inference parameters: 320x320 windows with 64-pixel overlap.
    tile = min(320, h, w)
    tile_overlap = 64
    sf = 1  # scale factor (1 => output resolution equals input)
    # Clamp stride to >= 1: for images with min side <= tile_overlap the
    # original stride went non-positive, producing malformed index ranges
    # and zero entries in W1 (division by zero below).
    stride = max(tile - tile_overlap, 1)
    # Always include the final window flush with the bottom/right edge.
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
    # E1 accumulates tile predictions; W1 counts how many tiles covered
    # each pixel, so E1 / W1 averages the overlaps.
    E1 = torch.zeros(b, c, h * sf, w * sf).type_as(img_tensor)
    W1 = torch.zeros_like(E1)
    # Tile-based inference
    with torch.no_grad():
        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = img_tensor[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
                out_patch1, _ = model(in_patch)
                # First element of the model's output list is used as the
                # prediction tile (presumably the full-resolution output —
                # verify against UDR_S2Former's forward()).
                out_patch1 = out_patch1[0]
                out_patch_mask1 = torch.ones_like(out_patch1)
                E1[..., h_idx * sf:(h_idx + tile) * sf, w_idx * sf:(w_idx + tile) * sf].add_(out_patch1)
                W1[..., h_idx * sf:(h_idx + tile) * sf, w_idx * sf:(w_idx + tile) * sf].add_(out_patch_mask1)
    # Average overlapping tile predictions.
    output = E1.div_(W1)
    # Save output image
    save_path = os.path.join(output_dir, f'{img_name}.png')
    save_image(output, save_path, normalize=False)
    return save_path
# Example usage
if __name__ == "__main__":
    # Select the first GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # load_s2former_model joins 'udrs2former_demo.pth' onto this path, so
    # pass the containing directory. The original passed the full file path,
    # which resolved to 'weights/udrs2former_demo.pth/udrs2former_demo.pth'.
    model_path = 'weights'
    model = load_s2former_model(model_path, device)
    # Run inference on a sample image; the result is written as a PNG
    # under output_dir.
    input_img = 'assests/1.jpg'  # NOTE(review): 'assests' spelling kept — confirm it matches the repo directory
    output_dir = 'newoutput'
    s2former_predict(model, input_img, output_dir, device)