File size: 3,592 Bytes
e340a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import copy
import cv2
import numpy as np
import shutil
import urllib.request

try:
    import onnxruntime
except Exception:
    onnxruntime = None

# Remote location the sky-segmentation ONNX model is downloaded from when the
# local model file is missing.
SKYSEG_URL = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"
# Cut-off applied to the normalized ([0, 1]) model output when binarizing it
# into a mask: pixels scoring below this value are set to 255.
SKYSEG_THRESHOLD = 0.5


def run_skyseg(session, input_size, image):
    """Run the sky-segmentation network on a single BGR image.

    Args:
        session: Inference session exposing ``get_inputs()`` and ``run()``
            (an ``onnxruntime.InferenceSession`` elsewhere in this module).
        input_size: Two-element sequence ``(width, height)`` the image is
            resized to before inference, in the order ``cv2.resize`` expects.
        image: Image array in BGR channel order (as produced by
            ``cv2.imread``). It is not modified.

    Returns:
        The ``[0, 0]`` slice of the model's first output, i.e. the first
        channel of the first batch element (assumes an NCHW-shaped output —
        TODO confirm against the model).
    """
    width, height = input_size[0], input_size[1]
    # cv2.resize allocates a fresh array, so no defensive copy of the
    # (potentially large) source image is needed.
    resized = cv2.resize(image, dsize=(width, height))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    x = np.array(rgb, dtype=np.float32)
    # Per-channel normalization constants (presumably the ImageNet
    # statistics the model was trained with — verify against the model card).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    # HWC -> CHW.
    x = x.transpose(2, 0, 1)
    # Add the batch axis without hard-coding the spatial dims: the resized
    # array is (height, width, 3), so reshaping to
    # (-1, 3, input_size[0], input_size[1]) would silently scramble pixels
    # for any non-square input_size.
    x = np.ascontiguousarray(x[np.newaxis], dtype=np.float32)
    input_name = session.get_inputs()[0].name
    result_map = session.run(None, {input_name: x})[0]
    return result_map[0, 0]


def _normalize_skyseg_output(result_map):
    """Coerce a raw score map into float32 values clipped to [0, 1].

    Handling, in order:
      * an empty input is returned as an (empty) float32 array;
      * an input with no finite entry at all yields an all-zero array;
      * NaN and -inf become 0.0, +inf becomes 1.0;
      * if the cleaned map is non-negative and its maximum exceeds 1.5 it is
        divided by 255 (treated as being on a 0..255 scale);
      * the result is clipped to [0, 1].
    """
    scores = np.asarray(result_map, dtype=np.float32)
    if scores.size == 0:
        return scores
    if not np.isfinite(scores).any():
        return np.zeros_like(scores, dtype=np.float32)

    scores = np.nan_to_num(scores, nan=0.0, posinf=1.0, neginf=0.0)
    lowest, highest = float(scores.min()), float(scores.max())
    on_byte_scale = lowest >= 0.0 and highest > 1.5
    if on_byte_scale:
        scores = scores / 255.0
    return np.clip(scores, 0.0, 1.0)


def sky_mask_filename(image_path):
    """Build the mask file name for ``image_path``.

    The name is ``"<parent>__<basename>"``, where ``<parent>`` is the name of
    the directory that directly contains the image. When the path has no
    such directory name (bare file name, or a file directly under the root),
    the plain basename is returned.
    """
    directory, name = os.path.split(image_path)
    parent = os.path.basename(directory)
    return f"{parent}__{name}" if parent else name


def segment_sky(image_path, session, mask_filename=None):
    """Segment one image and return a binary uint8 mask at the image's size.

    Args:
        image_path: Path of the image to read with ``cv2.imread``.
        session: Inference session forwarded to ``run_skyseg``.
        mask_filename: Optional path the mask is written to. Missing parent
            directories are created; a bare file name (no directory part)
            is written relative to the current working directory.

    Returns:
        ``None`` if the image cannot be read. Otherwise a uint8 array with
        the image's height and width, holding 255 where the normalized
        score is below ``SKYSEG_THRESHOLD`` and 0 elsewhere (presumably
        255 marks non-sky pixels — confirm against the model's output
        convention).
    """
    image = cv2.imread(image_path)
    if image is None:
        return None
    # The network runs at a fixed 320x320 resolution; its output is scaled
    # back to the source resolution (cv2.resize takes (width, height)).
    result_map = run_skyseg(session, [320, 320], image)
    result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
    result_map_original = _normalize_skyseg_output(result_map_original)
    output_mask = np.zeros(result_map_original.shape, dtype=np.uint8)
    output_mask[result_map_original < SKYSEG_THRESHOLD] = 255
    if mask_filename is not None:
        # os.makedirs("") raises FileNotFoundError, so only create the
        # directory when the path actually has a directory component.
        mask_dir = os.path.dirname(mask_filename)
        if mask_dir:
            os.makedirs(mask_dir, exist_ok=True)
        # cv2.imwrite reports failure through its return value; surface it
        # instead of silently losing the cached mask.
        if not cv2.imwrite(mask_filename, output_mask):
            print(
                f"[longstream] failed to write sky mask to {mask_filename}",
                flush=True,
            )
    return output_mask


def _download_skyseg_model(model_path):
    """Download the model from ``SKYSEG_URL`` to ``model_path``.

    The payload is streamed into a sibling ``.part`` file and moved into
    place only after the copy completes, so an interrupted or failed
    download never leaves a truncated file at ``model_path`` that a later
    run would mistake for a valid model.

    Returns:
        True if the model file was written, False if the download failed
        (the failure is printed, not raised).
    """
    os.makedirs(os.path.dirname(os.path.abspath(model_path)), exist_ok=True)
    partial_path = f"{model_path}.part"
    try:
        print(f"[longstream] downloading skyseg.onnx to {model_path}", flush=True)
        with urllib.request.urlopen(SKYSEG_URL) as src, open(
            partial_path, "wb"
        ) as dst:
            shutil.copyfileobj(src, dst)
        os.replace(partial_path, model_path)
    except Exception as exc:
        print(f"[longstream] failed to download skyseg.onnx: {exc}", flush=True)
        # Best-effort cleanup of the incomplete download.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        return False
    return True


def compute_sky_mask(image_paths, model_path: str, target_dir: str = None):
    """Compute (or load cached) sky masks for a sequence of images.

    Args:
        image_paths: Iterable of image file paths.
        model_path: Path of the ONNX model; it is downloaded from
            ``SKYSEG_URL`` first if the file does not exist.
        target_dir: Optional cache directory. When given, each mask is
            stored there under ``sky_mask_filename(image_path)`` and an
            existing, readable file is loaded instead of re-running the
            model.

    Returns:
        ``None`` if onnxruntime is unavailable or the model cannot be
        obtained. Otherwise a list with one entry per image path, in
        order: a uint8 mask array, or ``None`` for an image that could not
        be read.
    """
    if onnxruntime is None:
        return None
    if not os.path.exists(model_path) and not _download_skyseg_model(model_path):
        return None
    if not os.path.exists(model_path):
        return None
    session = onnxruntime.InferenceSession(model_path)
    masks = []
    for image_path in image_paths:
        mask_filepath = None
        sky_mask = None
        if target_dir is not None:
            mask_filepath = os.path.join(target_dir, sky_mask_filename(image_path))
            if os.path.exists(mask_filepath):
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        # No cache configured, no cached file, or the cached file could not
        # be decoded (cv2.imread returned None): run the model, which also
        # (re)writes the cache entry when mask_filepath is set.
        if sky_mask is None:
            sky_mask = segment_sky(image_path, session, mask_filepath)
        masks.append(sky_mask)
    return masks