Spaces:
Sleeping
Sleeping
| import datetime | |
| import math | |
| import pickle | |
| import time | |
| from pathlib import Path | |
| import cv2 | |
| import httpx | |
| import numpy as np | |
| from func_timeout import func_set_timeout | |
| from func_timeout.exceptions import FunctionTimedOut | |
| from logger import logger | |
| import common | |
| from config import settings | |
| from onnx_infer import OnnxSRInfer | |
@func_set_timeout(common.PROGRESS_TIMEOUT, allowOverride=True)
def _process_image(
    model: common.ModelInfo = common.models[common.MODEL_NAME_DEFAULT],
    tile_size: int = 64,  # tile size passed to the SR pipeline
    scale: int = 4,  # requested overall upscale factor
    skip_alpha: bool = False,  # upscale alpha by interpolation instead of the model
    resize_to: str = None,  # two formats: 1. "1920x1080" (width is used) 2. "1/2" (ratio)
    input_image: Path = None,
    output_path: Path | str = settings.get("output_dir", "output"),
    gpuid: int = 0,
    clean: bool = True,
) -> Path | None:
    """Super-resolve one image with an ONNX model and write the result as PNG.

    Wrapped by ``func_set_timeout`` with ``allowOverride=True``: the default
    limit is ``common.PROGRESS_TIMEOUT`` seconds and callers may pass
    ``forceTimeout=<seconds>`` to override it. On expiry ``FunctionTimedOut``
    (a ``BaseException`` subclass) propagates to the caller.

    Args:
        model: Model description; its ``path``, ``scale`` and ``name`` are used.
        tile_size: Tile size forwarded to ``universal_process_pipeline``.
        scale: Requested overall scale. If it exceeds ``model.scale`` the model
            is applied repeatedly and the result is resized to exactly
            ``scale`` x the input when the repeated passes overshoot. If it is
            smaller, the model output is resized down to ``scale`` x the input.
        skip_alpha: When true, ``alpha_upsampler`` is set to "interpolation".
        resize_to: Optional final size. "WxH" sets the width to W and derives
            the height from the input aspect ratio (H is ignored). "a/b"
            scales the ORIGINAL input dimensions by a/b. This overrides
            ``scale``-based sizing.
        input_image: Path of the encoded image to read.
        output_path: Directory for the output; created if missing.
        gpuid: Device id for the execution provider; negative disables the
            provider options.
        clean: Delete ``input_image`` after processing (success or failure).

    Returns:
        Path of the written ``<stem>_<model.name>.png``, or None on any error
        (errors are logged, not raised).
    """
    logger.info(f"processing image: {input_image}")
    start_time = datetime.datetime.now()
    if input_image is not None:
        # Callers may hand over a plain string; .stem/.exists() below need a Path.
        input_image = Path(input_image)
    try:
        provider_options = None
        if int(gpuid) >= 0:
            provider_options = [{"device_id": int(gpuid)}]
        sr_instance = OnnxSRInfer(
            model.path,
            model.scale,
            model.name,
            providers=[settings.get("provider", "CPUExecutionProvider")],
            provider_options=provider_options,
        )
        if skip_alpha:
            logger.debug("Skip Alpha Channel")
            sr_instance.alpha_upsampler = "interpolation"
        logger.debug(f"decoding image: {input_image}")
        # np.fromfile + imdecode (instead of cv2.imread) also handles non-ASCII paths.
        img = cv2.imdecode(
            np.fromfile(input_image, dtype=np.uint8), cv2.IMREAD_UNCHANGED
        )
        if img is None:
            raise ValueError(f"cannot decode image: {input_image}")
        # Grayscale images decode to a 2-D array, so only take the first two dims.
        h, w = img.shape[:2]
        sr_img = sr_instance.universal_process_pipeline(img, tile_size=tile_size)
        scale = int(scale)
        target_h = None
        target_w = None
        if scale > model.scale and model.scale > 1:
            logger.debug("re process")
            # Count model passes with integer arithmetic: math.log(scale, base)
            # is inexact (e.g. math.log(125, 5) > 3) and ceil() of it could add
            # a needless, very expensive extra pass.
            total_times = 1
            reached = model.scale
            while reached < scale:
                reached *= model.scale
                total_times += 1
            if reached != scale:
                # The passes overshoot the request: resize back to exactly scale x.
                target_h = h * scale
                target_w = w * scale
            for _ in range(total_times - 1):
                sr_img = sr_instance.universal_process_pipeline(
                    sr_img, tile_size=tile_size
                )
        elif scale < model.scale:
            logger.debug("down scale")
            target_h = h * scale
            target_w = w * scale
        if resize_to:
            logger.debug(f"resize to {resize_to}")
            if "x" in resize_to:
                # Only the width is honoured; height keeps the input aspect ratio.
                param_w = int(resize_to.split("x")[0])
                target_w = param_w
                target_h = int(h * param_w / w)
            elif "/" in resize_to:
                ratio = int(resize_to.split("/")[0]) / int(resize_to.split("/")[1])
                target_w = int(w * ratio)
                target_h = int(h * ratio)
        if target_w:
            logger.debug(f"resize to {target_w}x{target_h}")
            img_out = cv2.resize(sr_img, (target_w, target_h))
        else:
            img_out = sr_img
        # save
        out_dir = Path(output_path)
        final_output_path = out_dir / f"{input_image.stem}_{model.name}.png"
        out_dir.mkdir(parents=True, exist_ok=True)
        ok, encoded = cv2.imencode(".png", img_out)
        if not ok:
            raise ValueError(f"failed to encode PNG for {input_image}")
        # ndarray.tofile (instead of cv2.imwrite) also handles non-ASCII paths.
        encoded.tofile(final_output_path)
        return final_output_path
    except Exception as e:
        logger.error(f"process image error: {e}")
        return None
    finally:
        logger.info(
            f"Time taken: {(datetime.datetime.now() - start_time).seconds} seconds to process {input_image}"
        )
        if clean and input_image is not None and input_image.exists():
            input_image.unlink()
def _purge_stale_outputs(output_dir: Path, max_age: float = 86400) -> None:
    """Best-effort removal of regular files in ``output_dir`` older than ``max_age`` seconds.

    A missing directory is a no-op and subdirectories are skipped (unlinking a
    directory raises), so cleanup can never take down the listener loop.
    """
    if not output_dir.is_dir():
        return
    now = datetime.datetime.now().timestamp()
    for file in output_dir.iterdir():
        try:
            if file.is_file() and now - file.stat().st_mtime > max_age:
                file.unlink(missing_ok=True)
        except OSError as e:
            logger.warning(f"cannot remove stale output {file}: {e}")


def listen_queue(
    stream_name: str = common.BASE_STREAM_NAME,
    default_timeout: int = common.PROGRESS_TIMEOUT,
):
    """Consume single-machine SR tasks from a Redis stream forever.

    For each stream entry the pickled ``data`` payload is decoded and passed to
    ``_process_image``. The task status is published under
    ``common.RESULT_KEY_PREFIX + <message id>`` as a pickled dict. The status
    is "processing" first, then either "success" (with "path" and "size") or
    "failed". Every status key expires after 86400 seconds. The entry is then
    removed from the stream and output files older than a day are purged.

    Any per-task error (bad payload, unknown model, timeout, ...) marks that
    task as failed instead of terminating the loop.

    Args:
        stream_name: Redis stream to read.
        default_timeout: Processing timeout (seconds) used when the payload
            has no "timeout" key.
    """
    logger.info(f"Listening to stream: {stream_name}")
    last_id = "0"
    while True:
        messages = common.redis_client.xread({stream_name: last_id}, count=1, block=0)
        if not messages:
            continue
        message_id = messages[0][1][0][0]
        last_id = message_id
        message = messages[0][1][0][1]
        task_id = message_id.decode("utf-8")
        result_key = f"{common.RESULT_KEY_PREFIX}{task_id}"
        logger.info(f"Processing task: {task_id}")
        common.redis_client.set(
            result_key,
            pickle.dumps({"status": "processing"}),
            ex=86400,
        )
        processed_path: Path | None = None
        try:
            # NOTE(review): pickle.loads runs arbitrary code for crafted payloads;
            # this is only safe if nothing untrusted can write to this stream.
            data: dict[str, Path | int | bool | str | None] = pickle.loads(
                message[b"data"]
            )
            model_name = data.get("model", common.MODEL_NAME_DEFAULT)
            processed_path = _process_image(
                model=common.models[model_name],
                input_image=data.get("input_image"),
                tile_size=data.get("tile_size", 64),
                scale=data.get("scale", 4),
                skip_alpha=data.get("skip_alpha", False),
                resize_to=data.get("resize_to", None),
                forceTimeout=data.get("timeout", default_timeout),
            )
            result = (
                {
                    "status": "success",
                    "path": processed_path.as_posix(),
                    "size": processed_path.stat().st_size,
                }
                if processed_path
                else {"status": "failed"}
            )
        except FunctionTimedOut as e:
            # FunctionTimedOut derives from BaseException, so it needs its own clause.
            logger.warning(e)
            processed_path = None
            result = {"status": "failed"}
        except Exception as e:
            # Unknown model, malformed payload, ...: fail this task, keep listening.
            logger.error(f"{e.__class__.__name__}: {e}")
            processed_path = None
            result = {"status": "failed"}
        common.redis_client.set(result_key, pickle.dumps(result), ex=86400)
        if result["status"] == "success":
            logger.success(f"Processed image: {processed_path}")
        common.redis_client.xdel(stream_name, message_id)
        _purge_stale_outputs(Path(settings.get("output_dir", "output")))
def _poll_distributed_worker(
    worker_key: bytes, worker_data: dict, input_image: Path
) -> "common.TileInfo | None":
    """Check one remote worker's tile task and download the tile once it is done.

    Returns the downloaded tile as ``common.TileInfo`` (path under
    ``<output_dir>/<stem>/``), or None while the worker is still busy.
    Raises ``Exception`` when the worker is offline, an HTTP call fails, or
    the worker reports "failed".
    """
    worker_name = worker_key.decode("utf-8")
    logger.debug(f"Checking worker: {worker_name}")
    worker = common.redis_client.get(worker_key)
    if not worker:
        raise Exception(f"Worker {worker_name} offline")
    # The registry value is stored as "<url>|<token>".
    worker_url, token = worker.decode("utf-8").split("|")
    worker_task_id = worker_data["task_id"]
    headers = {"X-Token": token}
    response = httpx.get(f"{worker_url}/result/{worker_task_id}", headers=headers)
    if response.status_code != 200:
        raise Exception(f"Worker {worker_name} get task status failed")
    result = response.json()["result"]
    if result["status"] == "failed":
        raise Exception(f"Worker {worker_name} processing failed")
    if result["status"] != "success":
        return None
    logger.info(f"Worker {worker_name} processed")
    response = httpx.get(
        f"{worker_url}/result/{worker_task_id}/download", headers=headers
    )
    if response.status_code != 200:
        raise Exception(f"Worker {worker_name} download failed")
    tile_info: common.TileInfo = worker_data["tile_info"]
    file_path = (
        Path(settings.get("output_dir", "output"))
        / f"{input_image.stem}"
        / f"{input_image.stem}_scaled_{tile_info.y}_{tile_info.x}.png"
    )
    # The per-image tile folder is not guaranteed to exist on this machine.
    file_path.parent.mkdir(parents=True, exist_ok=True)
    file_path.write_bytes(response.content)
    logger.debug(f"Downloaded tile: {file_path}")
    return common.TileInfo(tile_info.x, tile_info.y, file_path)


def listen_distributed_queue(stream_name: str = common.DISTRIBUTED_STREAM_NAME):
    """Consume distributed SR tasks from a Redis stream forever.

    Each payload carries "worker_response" (worker registry key -> dict with
    "task_id" and "tile_info"), "input_image" and "scale". Workers are polled
    every ``worker_check_interval`` seconds until all tiles are downloaded.
    The tiles are then merged into ``<output_dir>/<stem>/<stem>_scaled_x<scale>.png``.

    The status key ``common.RESULT_KEY_PREFIX + <task id>`` receives a pickled
    status dict (expiring after 86400 seconds): "processing" first, then
    "success" (with "path" and "size") or "failed". Any failed or offline
    worker fails the whole task. The entry is removed from the stream
    afterwards, like ``listen_queue`` does, so restarts do not replay old tasks.
    """
    logger.info(f"Listening to distributed stream: {stream_name}")
    last_id = "0"
    while True:
        messages = common.redis_client.xread({stream_name: last_id}, count=1, block=0)
        if not messages:
            continue
        task_id = messages[0][1][0][0]
        last_id = task_id
        message = messages[0][1][0][1]
        result_key = f"{common.RESULT_KEY_PREFIX}{task_id.decode('utf-8')}"
        logger.info(f"Processing task: {task_id.decode('utf-8')}")
        time_start = datetime.datetime.now()
        common.redis_client.set(
            result_key,
            pickle.dumps({"status": "processing"}),
            ex=86400,
        )
        try:
            # NOTE(review): pickle.loads runs arbitrary code for crafted payloads;
            # this is only safe if nothing untrusted can write to this stream.
            data: dict = pickle.loads(message[b"data"])
            worker_response: dict = data.get("worker_response")
            input_image: Path = data.get("input_image")
            scale: int = data.get("scale", 4)
            original_w, original_h = common.get_image_size(input_image)
            scaled_tiles: list[common.TileInfo] = []
            while True:
                # Iterate over a snapshot so finished workers can be dropped in place.
                for worker_key, worker_data in list(worker_response.items()):
                    tile = _poll_distributed_worker(worker_key, worker_data, input_image)
                    if tile is not None:
                        scaled_tiles.append(tile)
                        worker_response.pop(worker_key, None)
                if not worker_response:
                    break
                time.sleep(settings.get("worker_check_interval", 5))
            logger.info(f"All workers processed, start merge {len(scaled_tiles)} tiles")
            output_path = (
                Path(settings.get("output_dir", "output"))
                / f"{input_image.stem}"
                / f"{input_image.stem}_scaled_x{scale}.png"
            )
            common.merge_sr_tiles(
                scaled_tiles,
                output_path,
                (original_w, original_h),
                scale,
            )
            logger.success(
                f"Processed image: {output_path}, time taken: {(datetime.datetime.now() - time_start).seconds} seconds"
            )
            common.redis_client.set(
                result_key,
                pickle.dumps(
                    {
                        "status": "success",
                        "path": output_path.as_posix(),
                        "size": output_path.stat().st_size,
                    }
                ),
                ex=86400,
            )
        except Exception as e:
            logger.error(f"{e.__class__.__name__}: {e}")
            common.redis_client.set(
                result_key,
                pickle.dumps({"status": "failed"}),
                ex=86400,
            )
        common.redis_client.xdel(stream_name, task_id)