Spaces:
Sleeping
Sleeping
File size: 3,592 Bytes
e340a84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 | import os
import copy
import cv2
import numpy as np
import shutil
import urllib.request
def _import_onnxruntime():
    """Return the ``onnxruntime`` module, or ``None`` if it cannot be imported.

    onnxruntime is treated as optional: any exception raised while importing
    it (not only ImportError) disables sky segmentation instead of breaking
    the import of this module.
    """
    try:
        import onnxruntime as ort
    except Exception:
        return None
    return ort


# Module handle, or None when onnxruntime is unavailable (compute_sky_mask
# checks for None and then returns None).
onnxruntime = _import_onnxruntime()

# Location of the pretrained sky-segmentation ONNX model, fetched on demand.
SKYSEG_URL = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"
# Cut-off applied to the normalized model score when building binary masks.
SKYSEG_THRESHOLD = 0.5
def run_skyseg(session, input_size, image):
    """Run the sky-segmentation ONNX model on a single image.

    Args:
        session: An onnxruntime ``InferenceSession`` (or any object exposing
            ``get_inputs()`` and ``run()``); the preprocessed batch is fed to
            its first input.
        input_size: ``(width, height)`` the image is resized to before
            inference (passed to ``cv2.resize`` as ``dsize``).
        image: Image array as returned by ``cv2.imread`` (BGR channel order).

    Returns:
        ``output[0, 0]`` of the model's first output, i.e. the first channel
        of the first batch item. Its resolution is whatever the model emits
        (presumably ``(height, width)`` -- confirm against the model).
    """
    width, height = int(input_size[0]), int(input_size[1])
    # cv2.resize allocates a new array, so the caller's image is never
    # modified and no defensive copy is needed.
    resized = cv2.resize(image, dsize=(width, height))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    x = np.asarray(rgb, dtype=np.float32)
    # ImageNet mean/std normalization of [0, 1] RGB values.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    # HWC -> NCHW. The resized image is (height, width, 3), so the batch is
    # (1, 3, height, width). Reshaping to (..., width, height) would silently
    # scramble pixels whenever width != height.
    x = np.ascontiguousarray(x.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)
    input_name = session.get_inputs()[0].name
    result_map = session.run(None, {input_name: x})[0]
    return result_map[0, 0]
def _normalize_skyseg_output(result_map):
    """Convert a raw model score map to float32 values in ``[0, 1]``.

    - An empty input is returned as an empty float32 array.
    - An input with no finite values yields all zeros.
    - If every finite value is non-negative and the largest exceeds 1.5, the
      map is treated as 0-255 scaled and divided by 255.
    - NaN and -inf become 0; +inf becomes the top of the detected scale, so
      it normalizes to 1.0.
    - The result is clipped to ``[0, 1]``.
    """
    result_map = np.asarray(result_map, dtype=np.float32)
    if result_map.size == 0:
        return result_map
    finite = np.isfinite(result_map)
    if not finite.any():
        return np.zeros_like(result_map, dtype=np.float32)
    # Decide the value range from finite entries only. Replacing +inf with
    # 1.0 *before* rescaling would turn +inf on a 0-255 map into ~0.004.
    finite_values = result_map[finite]
    if float(finite_values.min()) >= 0.0 and float(finite_values.max()) > 1.5:
        scale = 255.0
    else:
        scale = 1.0
    result_map = np.nan_to_num(result_map, nan=0.0, posinf=scale, neginf=0.0)
    if scale != 1.0:
        result_map = result_map / scale
    return np.clip(result_map, 0.0, 1.0)
def sky_mask_filename(image_path):
    """Build the cache file name for an image's sky mask.

    The image's immediate parent directory name is prefixed with a double
    underscore (``"scene/0001.jpg"`` -> ``"scene__0001.jpg"``). A path with no
    parent directory name keeps just its base name.
    """
    head, tail = os.path.split(image_path)
    folder = os.path.basename(head)
    return f"{folder}__{tail}" if folder else tail
def segment_sky(image_path, session, mask_filename=None):
    """Compute a binary mask from the sky-segmentation model for one image file.

    Args:
        image_path: Path of the image to segment (read with ``cv2.imread``).
        session: ONNX inference session forwarded to ``run_skyseg``.
        mask_filename: Optional output path. When given, the mask is also
            written there with ``cv2.imwrite``, and any missing parent
            directories are created first.

    Returns:
        A ``uint8`` array with the image's height and width, or ``None`` if
        the image could not be read. Each pixel is 255 where the normalized
        model score is below ``SKYSEG_THRESHOLD`` and 0 elsewhere. The 255
        pixels are presumably the non-sky ones -- confirm against the model.
    """
    image = cv2.imread(image_path)
    if image is None:
        return None
    result_map = run_skyseg(session, [320, 320], image)
    height, width = image.shape[:2]
    # Upsample the model output to the source resolution (dsize is (w, h)).
    scores = cv2.resize(result_map, (width, height))
    scores = _normalize_skyseg_output(scores)
    output_mask = np.zeros(scores.shape, dtype=np.uint8)
    output_mask[scores < SKYSEG_THRESHOLD] = 255
    if mask_filename is not None:
        mask_dir = os.path.dirname(mask_filename)
        # A bare file name has no directory part; os.makedirs("") would raise.
        if mask_dir:
            os.makedirs(mask_dir, exist_ok=True)
        if not cv2.imwrite(mask_filename, output_mask):
            print(f"[longstream] failed to write sky mask {mask_filename}", flush=True)
    return output_mask
def _ensure_skyseg_model(model_path, timeout=60.0):
    """Download the skyseg ONNX model to ``model_path`` unless it already exists.

    The download is written to ``<model_path>.part`` and moved into place only
    once it completes. An interrupted or failed download therefore never
    leaves a truncated model at ``model_path``, which a later call would
    otherwise try to load. ``timeout`` is forwarded to ``urlopen`` so a stalled
    connection cannot hang indefinitely.

    Returns:
        True if ``model_path`` exists afterwards, False otherwise.
    """
    if os.path.exists(model_path):
        return True
    os.makedirs(os.path.dirname(os.path.abspath(model_path)), exist_ok=True)
    tmp_path = f"{model_path}.part"
    try:
        print(f"[longstream] downloading skyseg.onnx to {model_path}", flush=True)
        with urllib.request.urlopen(SKYSEG_URL, timeout=timeout) as src, open(
            tmp_path, "wb"
        ) as dst:
            shutil.copyfileobj(src, dst)
        os.replace(tmp_path, model_path)
    except Exception as exc:
        print(f"[longstream] failed to download skyseg.onnx: {exc}", flush=True)
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best effort: the partial file may never have been created
        return False
    return os.path.exists(model_path)


def compute_sky_mask(image_paths, model_path: str, target_dir: str = None):
    """Compute (or load cached) sky-segmentation masks for a list of images.

    Args:
        image_paths: Iterable of image file paths.
        model_path: Where the skyseg ONNX model is stored. It is downloaded
            from ``SKYSEG_URL`` if missing.
        target_dir: Optional cache directory. Masks are stored there under
            ``sky_mask_filename(image_path)``. Existing, readable masks are
            reused; missing or unreadable ones are recomputed and written.

    Returns:
        ``None`` if onnxruntime is unavailable or the model could not be
        obtained. Otherwise a list with one entry per image path: a ``uint8``
        mask, or ``None`` for an image that could not be read.

    NOTE(review): the cache name keeps the image's extension, so masks for
    ``.jpg`` inputs are written as lossy JPEGs and may not stay strictly
    0/255 when re-read -- consider a lossless format.
    """
    if onnxruntime is None:
        return None
    if not _ensure_skyseg_model(model_path):
        return None
    session = onnxruntime.InferenceSession(model_path)
    masks = []
    for image_path in image_paths:
        if target_dir is None:
            masks.append(segment_sky(image_path, session, None))
            continue
        mask_filepath = os.path.join(target_dir, sky_mask_filename(image_path))
        sky_mask = None
        if os.path.exists(mask_filepath):
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        if sky_mask is None:
            # No cached mask, or the cached file is unreadable: recompute it,
            # which also rewrites the cache entry.
            sky_mask = segment_sky(image_path, session, mask_filepath)
        masks.append(sky_mask)
    return masks
|