import sys

import torch
import torchvision as TV
import torchvision.transforms.functional as VF

from imgen3flip import weights_path, Model, ImageBatch, OPTS


| | assert weights_path.exists(), "Model weights do not exist" |
| |
|
| | assert len(sys.argv) == 3, f"Usage: { |
| | sys.argv[0]} <input-filename> <output-filename>" |
| |
|
| | input_filename = sys.argv[1] |
| | output_filename = sys.argv[2] |
| |
|
| | assert input_filename != output_filename, f"Use different file names" |
| |
|
| | print("Loading the model") |
| | model = Model() |
| | model.load_state_dict(torch.load(weights_path)) |
| |
|
| | print(f"Loading 8x8 input image from {input_filename}") |
| | |
| | image = TV.io.read_image(input_filename)[:3] |
| | |
| | image = image / 255.0 |
| | assert image.shape[0] == 3, "RGB image expected" |
| | |
| | image = image.permute(1, 2, 0) |
| | |
| | |
| | image = image.view(1, 8, 8, 3) |
| |
|
| | |
| | |
| | dummy_target = torch.zeros(1, 64, 64, 3, **OPTS) |
| | dummy_loss = torch.tensor(-1, **OPTS) |
| | inference_batch = ImageBatch( |
| | im8=image.to(**OPTS), |
| | im64=dummy_target, |
| | loss=dummy_loss) |
| | result = model(inference_batch) |
| |
|
| | |
| | new_image = result.im64.detach().float().cpu() |
| | |
| | new_image = new_image[0] |
| | |
| | new_image = new_image.permute(2, 0, 1) |
| | assert new_image.shape == (3, 64, 64) |
| | img = VF.to_pil_image(new_image) |
| | |
| | print(f"Writing {img.height}x{img.width} image to {output_filename}") |
| | img.save(output_filename) |
| |
|