import copy
import os
import shutil
import urllib.request

import cv2
import numpy as np

try:
    import onnxruntime
except Exception:
    onnxruntime = None

# Download location of the ONNX sky-segmentation model; fetched by
# compute_sky_mask when the file is missing at the requested model_path.
SKYSEG_URL = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"
# Pixels whose normalized model output is below this value are set to 255 in
# the mask produced by segment_sky; all other pixels are set to 0.
SKYSEG_THRESHOLD = 0.5


def run_skyseg(session, input_size, image):
    """Run the sky-segmentation ONNX model on one BGR image.

    Args:
        session: an ``onnxruntime.InferenceSession`` (anything exposing
            ``get_inputs()`` and ``run()``) for the skyseg model.
        input_size: ``(width, height)`` the image is resized to before
            inference; it is passed to ``cv2.resize`` as ``dsize``.
        image: 3-channel image in OpenCV BGR order, as returned by
            ``cv2.imread``.

    Returns:
        ``output[0, 0]`` of the model's first output (first batch item, first
        channel). Its value range is model-defined; see
        ``_normalize_skyseg_output``.
    """
    # cv2.resize returns a new array, so the caller's image is never modified
    # and no defensive copy is needed.
    resized = cv2.resize(image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)
    # ImageNet per-channel mean/std normalization (RGB order).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    # HWC -> NCHW. After the transpose the spatial axes are (height, width);
    # add the batch axis instead of reshaping to (width, height), which would
    # scramble pixels whenever input_size is not square. Contiguous float32
    # is what the ONNX session receives.
    x = np.ascontiguousarray(x.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)
    input_name = session.get_inputs()[0].name
    result_map = session.run(None, {input_name: x})[0]
    # NOTE(review): assumes the first output is (batch, channel, H, W).
    return result_map[0, 0]


def _normalize_skyseg_output(result_map):
    """Map a raw skyseg output to float32 values in [0, 1].

    If every finite value is >= 0 and the largest finite value exceeds 1.5,
    the map is treated as 0-255 intensities and divided by 255. Non-finite
    entries are then replaced (NaN and -inf -> 0.0, +inf -> 1.0) and the
    result is clipped to [0, 1].

    An empty input is returned unchanged (as float32). An input with no
    finite values yields all zeros.
    """
    result_map = np.asarray(result_map, dtype=np.float32)
    if result_map.size == 0:
        return result_map
    finite = np.isfinite(result_map)
    if not np.any(finite):
        return np.zeros_like(result_map, dtype=np.float32)
    finite_values = result_map[finite]
    # Decide the scale from finite values only and rescale *before* replacing
    # infinities, so +inf still maps to 1.0 for 0-255 style outputs (replacing
    # first would turn it into 1/255 after the division).
    if float(finite_values.min()) >= 0.0 and float(finite_values.max()) > 1.5:
        result_map = result_map / 255.0
    result_map = np.nan_to_num(result_map, nan=0.0, posinf=1.0, neginf=0.0)
    return np.clip(result_map, 0.0, 1.0)


def sky_mask_filename(image_path):
    """Return the cache file name for the mask of ``image_path``.

    The name of the image's parent directory is prefixed as
    ``"<parent>__<name>"``, so equally named images in differently named
    folders map to different files. With no parent directory, the bare file
    name is returned. The image's extension is kept, so it also selects the
    format ``cv2.imwrite`` uses for the mask.
    """
    parent = os.path.basename(os.path.dirname(image_path))
    name = os.path.basename(image_path)
    if parent:
        return f"{parent}__{name}"
    return name


def segment_sky(image_path, session, mask_filename=None):
    """Compute the sky-segmentation mask for one image file.

    Args:
        image_path: path to the image, read with ``cv2.imread`` (BGR).
        session: skyseg inference session, forwarded to ``run_skyseg``.
        mask_filename: optional output path. When given, the mask is written
            there with ``cv2.imwrite``, and its parent directory (if the path
            has one) is created first.

    Returns:
        A ``uint8`` array with the image's height and width. Pixels whose
        normalized model output is below ``SKYSEG_THRESHOLD`` are 255; all
        others are 0. If the output is a sky score, 255 therefore marks
        non-sky pixels (NOTE(review): confirm against the model). Returns
        ``None`` if the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        return None
    result_map = run_skyseg(session, [320, 320], image)
    # Back to the original resolution; cv2.resize takes (width, height).
    result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
    result_map_original = _normalize_skyseg_output(result_map_original)
    output_mask = np.zeros(result_map_original.shape, dtype=np.uint8)
    output_mask[result_map_original < SKYSEG_THRESHOLD] = 255
    if mask_filename is not None:
        mask_dir = os.path.dirname(mask_filename)
        # os.makedirs("") raises, so only create a directory when there is one.
        if mask_dir:
            os.makedirs(mask_dir, exist_ok=True)
        # NOTE(review): for lossy extensions such as .jpg, the stored mask is
        # JPEG-compressed and may read back with values other than 0/255.
        if not cv2.imwrite(mask_filename, output_mask):
            print(
                f"[longstream] failed to write sky mask to {mask_filename}",
                flush=True,
            )
    return output_mask


def _download_skyseg_model(model_path):
    """Download the skyseg model from ``SKYSEG_URL`` to ``model_path``.

    The data is streamed to ``<model_path>.part`` and moved into place with
    ``os.replace`` only after the transfer completes. An interrupted download
    therefore never leaves a truncated file at ``model_path``, which later
    runs would accept via ``os.path.exists``.

    Download errors are reported and the partial file is removed. Returns
    True on success, False on failure.
    """
    os.makedirs(os.path.dirname(os.path.abspath(model_path)), exist_ok=True)
    tmp_path = f"{model_path}.part"
    try:
        print(f"[longstream] downloading skyseg.onnx to {model_path}", flush=True)
        with urllib.request.urlopen(SKYSEG_URL) as src, open(tmp_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
        os.replace(tmp_path, model_path)
    except Exception as exc:
        print(f"[longstream] failed to download skyseg.onnx: {exc}", flush=True)
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False
    return True


def compute_sky_mask(image_paths, model_path: str, target_dir: str = None):
    """Compute (or load cached) sky-segmentation masks for a list of images.

    Args:
        image_paths: iterable of image file paths.
        model_path: local path of ``skyseg.onnx``; downloaded from
            ``SKYSEG_URL`` when missing.
        target_dir: optional cache directory. When given, each mask lives at
            ``os.path.join(target_dir, sky_mask_filename(image_path))``.
            A readable file already there is reused; otherwise the mask is
            computed and written there.

    Returns:
        A list with one entry per image path: the ``uint8`` mask from
        ``segment_sky`` (or read back in grayscale from the cache), or
        ``None`` for an image that cannot be read. Returns ``None`` instead
        of a list when onnxruntime is not importable or the model file cannot
        be obtained.
    """
    if onnxruntime is None:
        return None
    if not os.path.exists(model_path) and not _download_skyseg_model(model_path):
        return None
    session = onnxruntime.InferenceSession(model_path)

    masks = []
    for image_path in image_paths:
        mask_filepath = None
        sky_mask = None
        if target_dir is not None:
            mask_filepath = os.path.join(target_dir, sky_mask_filename(image_path))
            if os.path.exists(mask_filepath):
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        # No cache, a cache miss, or an undecodable cached file (e.g. left
        # truncated by an interrupted run): compute the mask (and rewrite it).
        if sky_mask is None:
            sky_mask = segment_sky(image_path, session, mask_filepath)
        masks.append(sky_mask)
    return masks