"""Batch image upscaling script.

Loads a model from a weights file, runs it over every regular file found in
an input directory and writes each result under the same file name into an
output directory.
"""

import os
import argparse

import torch
from resselt import load_from_file
from pepeline import read, save, ImgColor, ImgFormat


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the script.

    Returns:
        Namespace with ``input_dir``, ``output_dir``, ``weights`` (all
        required strings) and ``device`` (string or ``None`` when omitted).
    """
    parser = argparse.ArgumentParser(
        description="Batch image upscaling script"
    )
    parser.add_argument("--input_dir", type=str, required=True, help="Path to input images")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save results")
    parser.add_argument("--weights", type=str, required=True, help="Path to model weights")
    parser.add_argument("--device", type=str, default=None, help="cuda or cpu")
    return parser.parse_args()


def load_model(weights_path: str, device: torch.device):
    """Load a model from ``weights_path``, move it to ``device``, set eval mode.

    Args:
        weights_path: Path handed to ``resselt.load_from_file``.
        device: Target device for the model.

    Returns:
        The loaded model after ``.to(...)`` and ``.eval()``.
    """
    model = load_from_file(weights_path)
    model = model.to(
        device,
        memory_format=torch.preserve_format,
        non_blocking=True,
    ).eval()
    return model


def process_image(model, img_path: str, device: torch.device):
    """Run ``model`` on a single image file and return the result as an array.

    Args:
        model: Callable model; invoked with a tensor that has a leading
            batch dimension of size 1.
        img_path: Path of the image to read.
        device: Device the input tensor is moved to; its ``type`` also
            selects the autocast backend.

    Returns:
        A NumPy array: the model output with axis 1 moved to the last
        position and the batch dimension removed.
    """
    # Axis 2 is moved to the front (presumably HWC -> CHW; confirm against
    # pepeline's `read` output layout).
    img = read(img_path, ImgColor.RGB, ImgFormat.F32).transpose(2, 0, 1)
    img = (
        torch.tensor(img)
        .to(
            device,
            memory_format=torch.preserve_format,
            non_blocking=True,
        )
        .unsqueeze(0)  # add the batch dimension
    )

    # NOTE(review): with float16 autocast the output dtype may be float16;
    # confirm that `save` accepts such arrays.
    with torch.autocast(device.type, torch.float16), torch.inference_mode():
        output = model(img)

    # Move axis 1 to the end, bring the data to host memory, drop the batch.
    output = output.permute(0, 2, 3, 1).detach().cpu().numpy()[0]
    return output


def main():
    """Entry point: upscale every regular file in the input directory."""
    args = parse_args()

    # An explicit --device wins; otherwise prefer CUDA when it is available.
    device = torch.device(
        args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")
    )

    os.makedirs(args.output_dir, exist_ok=True)

    model = load_model(args.weights, device)

    # Keep only regular files: a subdirectory (including an output directory
    # nested inside the input directory) would otherwise be handed to `read`
    # and abort the whole batch. Sorting makes the processing order
    # deterministic, since os.listdir returns entries in arbitrary order.
    img_list = sorted(
        name
        for name in os.listdir(args.input_dir)
        if os.path.isfile(os.path.join(args.input_dir, name))
    )
    total = len(img_list)

    for index, img_name in enumerate(img_list, start=1):
        print(
            f"\rProcessing {index}/{total} | {img_name}",
            end="",
            flush=True,
        )

        img_path = os.path.join(args.input_dir, img_name)
        result = process_image(model, img_path, device)

        # `result` is a permuted view; copy() materializes a fresh
        # C-contiguous array before it is handed to `save`.
        save(result.copy(), os.path.join(args.output_dir, img_name))

    print("\nDone.")


if __name__ == "__main__":
    main()