# figsr / inference.py
# Author: umzi — "Upload folder using huggingface_hub" (commit 4f763cc)
import os
import argparse
import torch
from resselt import load_from_file
from pepeline import read, save, ImgColor, ImgFormat
def parse_args():
    """Read the batch-upscaler options from the command line.

    Returns:
        argparse.Namespace with string attributes ``input_dir``,
        ``output_dir`` and ``weights`` (all required), plus ``device``,
        which is a string when given and ``None`` otherwise.
    """
    cli = argparse.ArgumentParser(description="Batch image upscaling script")

    # (flag, extra keyword arguments) — registered in this exact order so the
    # generated --help text lists options the same way.
    option_specs = [
        ("--input_dir", {"required": True, "help": "Path to input images"}),
        ("--output_dir", {"required": True, "help": "Path to save results"}),
        ("--weights", {"required": True, "help": "Path to model weights"}),
        ("--device", {"default": None, "help": "cuda or cpu"}),
    ]
    for flag, extra in option_specs:
        cli.add_argument(flag, type=str, **extra)

    return cli.parse_args()
def load_model(weights_path: str, device: torch.device):
    """Load a network from ``weights_path`` and prepare it for inference.

    The weights file is loaded with resselt's ``load_from_file``; the result
    is moved to ``device`` (keeping its memory format) and switched to eval
    mode before being returned.

    Args:
        weights_path: Path to the model weights file.
        device: Target device for the model parameters.

    Returns:
        The loaded model in eval mode on ``device`` (presumably a
        ``torch.nn.Module``-like object — TODO confirm against resselt).
    """
    network = load_from_file(weights_path)
    placed = network.to(
        device,
        non_blocking=True,
        memory_format=torch.preserve_format,
    )
    return placed.eval()
def process_image(model, img_path: str, device: torch.device):
    """Run ``model`` on a single image file and return the upscaled result.

    The file is read with pepeline as RGB float32 in HWC layout, transposed
    to CHW, given a leading batch dimension and moved to ``device``. The
    model is then called under ``torch.inference_mode``.

    Args:
        model: Network returned by ``load_model``; called with a 1xCxHxW
            tensor.
        img_path: Path of the image to read.
        device: Device the model lives on.

    Returns:
        Contiguous float32 numpy array in HWC layout holding the first (and
        only) batch item of the model output.
    """
    import contextlib

    img = read(img_path, ImgColor.RGB, ImgFormat.F32).transpose(2, 0, 1)
    batch = (
        torch.tensor(img)
        .to(device, memory_format=torch.preserve_format, non_blocking=True)
        .unsqueeze(0)
    )

    # float16 autocast is meant for CUDA. On CPU, older PyTorch rejects the
    # float16 dtype (disabling autocast with a warning), and other device
    # types may not be accepted by torch.autocast at all, so run plain
    # float32 there.
    if device.type == "cuda":
        precision = torch.autocast("cuda", dtype=torch.float16)
    else:
        precision = contextlib.nullcontext()

    with precision, torch.inference_mode():
        output = model(batch)

    # Under autocast the output is typically float16; cast back to float32
    # so the result matches the F32 format the image was read in.
    return output[0].permute(1, 2, 0).contiguous().float().cpu().numpy()
def main():
    """Upscale every regular file in ``--input_dir`` into ``--output_dir``.

    Files are processed in sorted name order and subdirectories are skipped.
    Each result is saved under the same file name as its input. A one-line
    progress indicator is rewritten in place on stdout.
    """
    args = parse_args()
    # Honour an explicit --device; otherwise prefer CUDA when it is available.
    device = torch.device(
        args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")
    )
    os.makedirs(args.output_dir, exist_ok=True)
    model = load_model(args.weights, device)

    # os.listdir returns entries in arbitrary order and includes directories,
    # which the image reader cannot open; keep regular files in a stable order.
    img_list = sorted(
        name
        for name in os.listdir(args.input_dir)
        if os.path.isfile(os.path.join(args.input_dir, name))
    )
    total = len(img_list)
    last_width = 0
    for index, img_name in enumerate(img_list, start=1):
        line = f"\rProcessing {index}/{total} | {img_name}"
        # "\r" only moves the cursor back; pad to the previous line's width so
        # a shorter file name fully overwrites the old progress text.
        print(line.ljust(last_width), end="", flush=True)
        last_width = len(line)

        img_path = os.path.join(args.input_dir, img_name)
        result = process_image(model, img_path, device)
        # .copy() hands save() an owned, C-contiguous array.
        save(result.copy(), os.path.join(args.output_dir, img_name))
    print("\nDone.")


if __name__ == "__main__":
    main()