Spaces:
Running
Running
| import datetime | |
| import hashlib | |
| import numpy as np | |
| import os | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Any, Dict | |
| import cv2 | |
| import gradio as gr | |
| from joblib import Parallel, delayed | |
| from numpy.typing import NDArray | |
| from PIL import Image | |
| def _run_in_subprocess(command: str, wd: str) -> Any: | |
| p = subprocess.Popen(command, shell=True, cwd=wd) | |
| (output, err) = p.communicate() | |
| p_status = p.wait() | |
| print("Status of subprocess: ", p_status) | |
| return p_status | |
# Directory (relative to the process cwd) in which main_test_swinir.py is run.
SWIN_IR_WD = "KAIR"
# Directory holding the SwinIR checkpoint files.
SWINIR_CKPT_DIR: Path = Path("KAIR/model_zoo/")

# Per-model settings, keyed by the model name offered in the UI.
# Checkpoint path for each model (relative to the process cwd).
MODEL_NAME_TO_PATH: Dict[str, Path] = {
    "LambdaSwinIR_v0.1": SWINIR_CKPT_DIR / "805000_G.pth",
}
# Value passed to main_test_swinir.py as --training_patch_size.
SWINIR_NAME_TO_PATCH_SIZE: Dict[str, int] = {
    "LambdaSwinIR_v0.1": 96,
}
# Upscaling factor passed as --scale.
SWINIR_NAME_TO_SCALE: Dict[str, int] = {
    "LambdaSwinIR_v0.1": 2,
}
# Whether --large_model is passed for this model.
SWINIR_NAME_TO_LARGE_MODEL: Dict[str, bool] = {
    "LambdaSwinIR_v0.1": False,
}
def _run_swin_ir(
    image: NDArray,
    model_path: Path,
    patch_size: int,
    scale: int,
    is_large_model: bool,
) -> NDArray:
    """Super-resolve ``image`` by running KAIR's ``main_test_swinir.py``.

    Steps:
        1. Write the image to ``<cwd>/sr_interactive_tmp/gradio_img.png``.
        2. Run the KAIR test script on that folder as a ``real_sr`` task.
        3. Read ``gradio_img_SwinIR.png`` back from KAIR's results directory.
        4. Save a uniquely named copy under ``./sr_interactive_tmp_output``.

    Args:
        image: Input image in a form accepted by ``PIL.Image.fromarray``.
            Presumably an HxWx3 uint8 RGB array from Gradio; TODO confirm.
        model_path: Checkpoint path. It is appended to ``os.getcwd()``, so it
            is expected to be relative to the current working directory.
        patch_size: Value passed as ``--training_patch_size``.
        scale: Upscaling factor passed as ``--scale``.
        is_large_model: Whether to pass ``--large_model``.

    Returns:
        The super-resolved image as a NumPy array.

    Raises:
        RuntimeError: If the SwinIR subprocess exits with a non-zero status.
    """
    print("model_path: ", str(model_path))

    # Unique id for the saved copy: hash of the model path and current UTC time.
    now_time = datetime.datetime.now(datetime.timezone.utc)
    m = hashlib.sha256()
    m.update(bytes(str(model_path), encoding='utf-8') +
             bytes(now_time.strftime("%Y-%m-%d %H:%M:%S.%f"), encoding='utf-8'))
    random_id = m.hexdigest()[0:20]

    cwd = os.getcwd()
    # NOTE(review): the input filename is fixed and the whole folder is handed
    # to the script, so concurrent requests would race. This presumably relies
    # on the Gradio queue running one job at a time; confirm.
    input_root = Path(cwd) / "sr_interactive_tmp"
    input_root.mkdir(parents=True, exist_ok=True)
    Image.fromarray(image).save(input_root / "gradio_img.png")

    command = (
        f"python main_test_swinir.py --scale {scale} "
        f"--folder_lq {input_root} --task real_sr "
        f"--model_path {cwd}/{model_path} --training_patch_size {patch_size}"
    )
    if is_large_model:
        command += " --large_model"
    print("COMMAND: ", command)
    status = _run_in_subprocess(command, wd=cwd + "/" + SWIN_IR_WD)
    print("STATUS: ", status)
    if status != 0:
        # Without this check a failed run would either crash on a missing file
        # or silently return the result left over from a previous request.
        raise RuntimeError(f"SwinIR subprocess failed with exit status {status}")

    # main_test_swinir.py writes real_sr results to results/swinir_real_sr_x<scale>
    # and appends "_large" only when --large_model is given. Deriving the
    # directory from both arguments also works for any scale, not just 2 or 4.
    suffix = "_large" if is_large_model else ""
    results_dir = Path(cwd) / SWIN_IR_WD / "results" / f"swinir_real_sr_x{scale}{suffix}"

    output_root = Path("./sr_interactive_tmp_output")
    output_root.mkdir(parents=True, exist_ok=True)
    saved_name = "SwinIR_" + random_id + ".png"
    # Context manager closes the file handle that Image.open keeps open.
    with Image.open(results_dir / "gradio_img_SwinIR.png") as output_img:
        output_img.save(output_root / saved_name)
        print("SAVING: " + saved_name)
        result = np.array(output_img)
    return result
def _bilinear_upsample(
    image: NDArray,
    scale: int = 2,
    *,
    interpolation: int = cv2.INTER_LANCZOS4,
) -> NDArray:
    """Upscale ``image`` by an integer factor using classical interpolation.

    Despite the function's name, the default interpolation is Lanczos
    (``cv2.INTER_LANCZOS4``), not bilinear. Pass
    ``interpolation=cv2.INTER_LINEAR`` for true bilinear resampling.

    Args:
        image: Image array indexed as (height, width, ...), as ``cv2.resize``
            expects.
        scale: Positive integer upscaling factor (default 2, the original
            behavior).
        interpolation: OpenCV interpolation flag.

    Returns:
        The resized image.

    Raises:
        ValueError: If ``scale`` is smaller than 1.
    """
    if scale < 1:
        raise ValueError(f"scale must be a positive integer, got {scale}")
    height, width = image.shape[:2]
    # cv2.resize takes dsize as (width, height), the reverse of array order.
    return cv2.resize(
        image,
        dsize=(width * scale, height * scale),
        interpolation=interpolation,
    )
def _decide_sr_algo(model_name: str, image: NDArray) -> NDArray:
    """Run the super-resolution model registered under ``model_name`` on ``image``.

    All per-model settings come from the module-level tables keyed by model
    name: ``MODEL_NAME_TO_PATH``, ``SWINIR_NAME_TO_PATCH_SIZE``,
    ``SWINIR_NAME_TO_SCALE`` and ``SWINIR_NAME_TO_LARGE_MODEL``. Every
    registered model is run through SwinIR.

    Args:
        model_name: Name of a registered model.
        image: Input image passed unchanged to ``_run_swin_ir``.

    Returns:
        The super-resolved image returned by ``_run_swin_ir``.

    Raises:
        KeyError: If ``model_name`` is missing from any of the tables. All
            lookups happen before SwinIR is started.
    """
    return _run_swin_ir(
        image,
        model_path=MODEL_NAME_TO_PATH[model_name],
        patch_size=SWINIR_NAME_TO_PATCH_SIZE[model_name],
        scale=SWINIR_NAME_TO_SCALE[model_name],
        is_large_model=SWINIR_NAME_TO_LARGE_MODEL[model_name],
    )
def _super_resolve(model_name: str, input_img):
    """Super-resolve ``input_img`` with the single model named ``model_name``.

    A thin layer between the Gradio handler and the per-model dispatch in
    ``_decide_sr_algo``. Exactly one model is run per request.

    Args:
        model_name: Name of a registered model.
        input_img: Input image as received from the UI.

    Returns:
        The super-resolved image.
    """
    return _decide_sr_algo(model_name, input_img)
| def _gradio_handler(sr_option: str, input_img: NDArray): | |
| return _super_resolve(sr_option, input_img) | |
# --- Gradio UI ---------------------------------------------------------------
# Close any Gradio interfaces still running in this interpreter before
# building a new one.
gr.close_all()

# Models selectable in the UI. Each entry must be a key of the per-model
# tables that _decide_sr_algo indexes.
SR_OPTIONS = ["LambdaSwinIR_v0.1"]

# Example rows shown under the interface: [model choice, sample image path].
examples = [
    [SR_OPTIONS[0], example_path]
    for example_path in (
        "examples/oldphoto6.png",
        "examples/Lincoln.png",
        "examples/OST_009.png",
        "examples/00003.png",
        "examples/00000067_cropped.png",
    )
]

ui = gr.Interface(
    fn=_gradio_handler,
    inputs=[gr.Radio(SR_OPTIONS), gr.Image(image_mode="RGB")],
    outputs=["image"],
    examples=examples,
    cache_examples=True,
    live=False,
)
ui.launch(enable_queue=True)