# Provenance: uploaded by JunhaoYu in commit 96fb300 ("Upload 10 files", verified).
import torch
import hydra
from omegaconf import DictConfig, OmegaConf
from model.colorization import colorization
from model.super_resolution import super_res
def _seed_all_random_engines(seed: int) -> None:
    """Seed the random number generators used by this script.

    Seeds Python's ``random`` module, NumPy's global generator (only when
    NumPy is installed) and PyTorch via ``torch.manual_seed``.

    Args:
        seed: Seed value applied to every generator.
    """
    # Local import keeps this helper self-contained without touching the
    # file-level import block.
    import random

    random.seed(seed)
    try:
        import numpy as np
    except ImportError:
        # NumPy is optional here; the other generators are still seeded.
        pass
    else:
        np.random.seed(seed)
    torch.manual_seed(seed)


@hydra.main(config_path="cfgs/", config_name="demo")
def main(cfg: DictConfig) -> None:
    """Run the demo pipeline: colorization followed by super resolution.

    The configuration is loaded by Hydra from ``cfgs/demo``. This function
    reads ``cfg.seed``, ``cfg.image_path``, ``cfg.output_dir`` and
    ``cfg.ckpt_dir`` from it.

    Args:
        cfg: Hydra/OmegaConf configuration for the demo.
    """
    # Turn off struct mode on the config.
    OmegaConf.set_struct(cfg, False)

    # Print configuration
    print("Model Config:", OmegaConf.to_yaml(cfg))

    # Configure CUDA settings.
    # cuDNN is disabled outright. Benchmark autotuning is also switched off,
    # because it may select different algorithms between runs, which works
    # against the deterministic behaviour requested on the next line.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Set seed for reproducibility
    _seed_all_random_engines(cfg.seed)

    # Stage 1: colorization. The second returned path is not used here.
    color_img_path, _gray_image_path = colorization(
        image_path=cfg.image_path,
        output_dir=cfg.output_dir,
        ckpt_dir=cfg.ckpt_dir,
    )
    print(f"Complete colorization and saved to {color_img_path}.")

    # Stage 2: super resolution, fed with the colorized image from stage 1.
    super_image_path = super_res(
        input_path=color_img_path,
        output_dir=cfg.output_dir,
        ckpt_dir=cfg.ckpt_dir,
    )
    print(f"Complete super resolution and saved to {super_image_path}.")
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    # Called without arguments; the ``hydra.main`` decorator supplies ``cfg``.
    main()