# LongStream/longstream/utils/sky_mask.py
# Cc · init · e340a84
import os
import copy
import cv2
import numpy as np
import shutil
import urllib.request
try:
import onnxruntime
except Exception:
onnxruntime = None
# Download location of the sky-segmentation ONNX model; compute_sky_mask
# fetches it when the requested model file does not exist yet.
SKYSEG_URL = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"
# Cut-off applied to the normalized model score: segment_sky sets mask pixels
# scoring below this value to 255 and leaves all other pixels at 0.
SKYSEG_THRESHOLD = 0.5
def run_skyseg(session, input_size, image):
    """Run the sky-segmentation session on a single image.

    Args:
        session: Inference session exposing ``get_inputs()`` and ``run()``.
        input_size: ``(width, height)`` the image is resized to before
            inference; it is handed to ``cv2.resize`` as ``dsize``.
        image: Image array that is treated as BGR (it is converted with
            ``cv2.COLOR_BGR2RGB``). The array is not modified.

    Returns:
        The ``[0, 0]`` slice of the first output returned by the session.
    """
    width, height = input_size[0], input_size[1]
    # cv2.resize allocates a new array, so the caller's image stays untouched
    # without a defensive deep copy.
    resized = cv2.resize(image, dsize=(width, height))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    x = np.array(rgb, dtype=np.float32)
    # Per-channel normalization constants, applied in RGB order.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    # HWC -> CHW. The resized array is (height, width, 3), so after the
    # transpose it is (3, height, width).
    x = x.transpose(2, 0, 1)
    # Add the batch axis instead of reshaping to (3, width, height): reshaping
    # with swapped dimensions scrambles pixels whenever width != height.
    x = np.ascontiguousarray(x[np.newaxis, ...], dtype=np.float32)
    input_name = session.get_inputs()[0].name
    result_map = session.run(None, {input_name: x})[0]
    return result_map[0, 0]
def _normalize_skyseg_output(result_map):
result_map = np.asarray(result_map, dtype=np.float32)
if result_map.size == 0:
return result_map
finite = np.isfinite(result_map)
if not np.any(finite):
return np.zeros_like(result_map, dtype=np.float32)
result_map = np.nan_to_num(result_map, nan=0.0, posinf=1.0, neginf=0.0)
max_value = float(result_map.max())
min_value = float(result_map.min())
if min_value >= 0.0 and max_value > 1.5:
result_map = result_map / 255.0
return np.clip(result_map, 0.0, 1.0)
def sky_mask_filename(image_path):
    """Build the mask file name for ``image_path``.

    The name is ``"<parent>__<file>"`` where ``<parent>`` is the last
    component of the image's directory and ``<file>`` is the image's base
    name. When the directory has no last component (bare file name, or a
    file directly under the root), only ``<file>`` is returned.
    """
    directory, name = os.path.split(image_path)
    parent = os.path.basename(directory)
    return f"{parent}__{name}" if parent else name
def segment_sky(image_path, session, mask_filename=None):
    """Compute a binary mask for one image and optionally save it.

    Args:
        image_path: Path of the image, read with ``cv2.imread``.
        session: Inference session forwarded to ``run_skyseg``.
        mask_filename: Optional path the mask is written to. Its directory
            is created when missing; a bare file name (no directory part)
            is written relative to the current working directory.

    Returns:
        A ``uint8`` array with the image's height and width, holding 255
        where the normalized score is below ``SKYSEG_THRESHOLD`` and 0
        elsewhere, or ``None`` when the image cannot be read.
        NOTE(review): presumably higher scores mean "sky", so 255 marks
        non-sky pixels — confirm against the model.
    """
    image = cv2.imread(image_path)
    if image is None:
        return None
    height, width = image.shape[:2]
    # The model runs at a fixed 320x320 resolution; scores are scaled back
    # to the image size before thresholding.
    scores = run_skyseg(session, [320, 320], image)
    scores = cv2.resize(scores, (width, height))
    scores = _normalize_skyseg_output(scores)
    output_mask = np.zeros(scores.shape, dtype=np.uint8)
    output_mask[scores < SKYSEG_THRESHOLD] = 255
    if mask_filename is not None:
        mask_dir = os.path.dirname(mask_filename)
        # os.makedirs("") raises FileNotFoundError, so only create the
        # directory when the path actually has a directory component.
        if mask_dir:
            os.makedirs(mask_dir, exist_ok=True)
        # cv2.imwrite reports failure through its return value; surface it
        # instead of silently dropping the mask on disk.
        if not cv2.imwrite(mask_filename, output_mask):
            print(
                f"[longstream] failed to write sky mask: {mask_filename}",
                flush=True,
            )
    return output_mask
def _download_skyseg_model(model_path):
    """Download SKYSEG_URL to ``model_path``; return True on success."""
    os.makedirs(os.path.dirname(os.path.abspath(model_path)), exist_ok=True)
    # Stream into a temporary sibling file and move it into place only once
    # the transfer finished, so an interrupted download never leaves a
    # truncated file at ``model_path`` that later runs would take for valid.
    partial_path = f"{model_path}.{os.getpid()}.part"
    try:
        print(f"[longstream] downloading skyseg.onnx to {model_path}", flush=True)
        with urllib.request.urlopen(SKYSEG_URL) as src, open(
            partial_path, "wb"
        ) as dst:
            shutil.copyfileobj(src, dst)
        os.replace(partial_path, model_path)
    except Exception as exc:
        print(f"[longstream] failed to download skyseg.onnx: {exc}", flush=True)
        try:
            os.remove(partial_path)
        except OSError:
            # Best-effort cleanup: the partial file may never have been created.
            pass
        return False
    return True


def compute_sky_mask(image_paths, model_path: str, target_dir: str = None):
    """Compute (or load cached) masks for a sequence of images.

    Args:
        image_paths: Iterable of image paths.
        model_path: Location of the ONNX model file. When the file does not
            exist it is downloaded from ``SKYSEG_URL`` first.
        target_dir: Optional cache directory. When given, each mask is
            stored under the name produced by ``sky_mask_filename`` and an
            existing, readable mask file is loaded instead of recomputed.

    Returns:
        A list with one entry per image path (a mask array, or ``None``
        when the image could not be read), or ``None`` when onnxruntime is
        unavailable or the model could not be obtained.
    """
    if onnxruntime is None:
        return None
    if not os.path.exists(model_path) and not _download_skyseg_model(model_path):
        return None
    session = onnxruntime.InferenceSession(model_path)
    masks = []
    for image_path in image_paths:
        mask_filepath = None
        sky_mask = None
        if target_dir is not None:
            mask_filepath = os.path.join(target_dir, sky_mask_filename(image_path))
            if os.path.exists(mask_filepath):
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        if sky_mask is None:
            # No cache entry, or the cached file could not be decoded:
            # compute the mask (and rewrite the cache when one is in use).
            sky_mask = segment_sky(image_path, session, mask_filepath)
        masks.append(sky_mask)
    return masks