# File size: 10,798 Bytes
# 75f0bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
# batch_select_nearest_touching_sidewalk.py
import argparse
import os
import glob
import csv
import pickle
import numpy as np
import cv2

def load_mask_stack(pkl_path):
    """
    Read a pickled stack of object masks and return it as a boolean array.

    The unpickled payload is converted with np.asarray and must be 3-D,
    i.e. (N,H,W); any other rank raises ValueError.

    NOTE: pickle.load can execute arbitrary code, so pkl_path is assumed
    to point at a trusted file.
    """
    with open(pkl_path, "rb") as fh:
        stack = np.asarray(pickle.load(fh))
    if stack.ndim == 3:
        return stack.astype(bool)
    raise ValueError(f"Expected (N,H,W), got {stack.shape} from {pkl_path}")

def load_sidewalk_mask(path):
    """
    Load a sidewalk mask from disk and return it as a 2-D boolean array.

    Supports (extension is matched case-insensitively):
      - .pkl: pickle of (H,W) or (N,H,W)
      - .npy: numpy array (H,W) or (N,H,W)
      - .png/.jpg/.jpeg/.bmp/.tif/.tiff: image mask (nonzero => True)
    For (N,H,W), returns union over N.
    For multi-channel images a pixel is True when any of the first three
    (colour) channels is non-zero; a fourth (alpha) channel is ignored.

    Raises:
      FileNotFoundError: the image file could not be read by cv2.imread.
      ValueError: unsupported extension, or a .pkl/.npy array that is
        neither 2-D nor 3-D.
    """
    ext = os.path.splitext(path)[1].lower()

    if ext in (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"):
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f"Failed to read sidewalk image: {path}")
        if img.ndim == 3:
            # IMREAD_UNCHANGED can yield 4-channel (BGRA) data, which a
            # BGR->GRAY conversion rejects.  Testing the colour channels
            # directly avoids that and keeps "nonzero => True" exact.
            return np.any(img[:, :, :3] > 0, axis=2)
        return img > 0

    if ext == ".pkl":
        # NOTE: pickle.load can execute arbitrary code; the file is assumed trusted.
        with open(path, "rb") as f:
            arr = np.asarray(pickle.load(f))
    elif ext == ".npy":
        arr = np.load(path)
    else:
        raise ValueError(f"Unsupported sidewalk mask extension: {ext} ({path})")

    if arr.ndim == 3:
        # Stack of masks: a pixel is sidewalk if any layer marks it.
        return np.any(arr.astype(bool), axis=0)
    if arr.ndim == 2:
        return arr.astype(bool)
    raise ValueError(f"Unexpected sidewalk mask shape {arr.shape} from {path}")

def robust_depth(depth, mask, q=10.0):
    """
    Return the q-th percentile of the depth values selected by `mask`.

    Only finite, strictly positive depth values are considered. When no
    such value lies inside the mask the result is np.inf.
    """
    usable = np.isfinite(depth) & (depth > 0)
    samples = depth[mask & usable]
    if samples.size:
        return float(np.percentile(samples, q))
    return np.inf

def touches_sidewalk(obj_mask, sidewalk_mask, margin_px=8, min_contact_px=30, use_boundary=False):
    """
    Decide whether an object mask is in contact with the sidewalk mask.

    The sidewalk is first dilated with an elliptical kernel of size
    (2*margin_px+1) so that near-misses still count as contact.

    - use_boundary=False: count object pixels inside the dilated sidewalk
      and require at least min_contact_px of them.
    - use_boundary=True: count only the object's boundary pixels (mask
      minus its 3x3 erosion) inside the dilated sidewalk and require at
      least max(5, min_contact_px // 3) of them.
    """
    side = 2 * margin_px + 1
    grow_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
    sidewalk_img = sidewalk_mask.astype(np.uint8) * 255
    near_sidewalk = cv2.dilate(sidewalk_img, grow_kernel, iterations=1) > 0

    if use_boundary:
        obj_img = obj_mask.astype(np.uint8) * 255
        edge_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        shrunk = cv2.erode(obj_img, edge_kernel, iterations=1)
        # Boundary = pixels that belong to the object but vanish under erosion.
        outline = (obj_img > 0) & (shrunk == 0)
        required = max(5, min_contact_px // 3)
        return int((outline & near_sidewalk).sum()) >= required

    overlap = obj_mask & near_sidewalk
    return int(overlap.sum()) >= min_contact_px

def find_matching_file(stem, folder, exts, allow_contains=False):
    """
    Locate the file in `folder` that corresponds to `stem`.

    Tries:
      1) exact: folder/stem + ext, for each ext in the given order
      2) if allow_contains: glob folder/*stem*ext for each ext in order and
         return the alphabetically first hit

    `stem` and `folder` are matched literally: glob metacharacters in them
    ("[", "]", "*", "?") are escaped before the pattern is built, so a stem
    such as "img[1]" finds "x_img[1]_y.npy" instead of being read as a
    character class.

    Returns the matching path, or None when nothing matches.
    """
    for ext in exts:
        candidate = os.path.join(folder, stem + ext)
        if os.path.exists(candidate):
            return candidate

    if allow_contains:
        safe_folder = glob.escape(folder)
        safe_stem = glob.escape(stem)
        for ext in exts:
            pattern = os.path.join(safe_folder, f"*{safe_stem}*{ext}")
            hits = sorted(glob.glob(pattern))
            if hits:
                return hits[0]
    return None

def overlay_mask(rgb_bgr, mask_bool, alpha=0.4):
    """
    Return a copy of a BGR image with the masked pixels tinted red.

    Args:
      rgb_bgr: uint8 image in BGR channel order; it is not modified.
      mask_bool: boolean mask used to index the image's pixels.
      alpha: weight of the red tint, clamped to [0, 1]. 0 leaves the image
        untouched, 1 paints the masked pixels pure red.

    Returns:
      A new uint8 image where masked pixels are
      (1 - alpha) * original + alpha * red, truncated to uint8.
    """
    # Clamp so the blend can never leave the uint8 range and wrap around.
    alpha = float(np.clip(alpha, 0.0, 1.0))
    overlay = rgb_bgr.copy()
    red = np.array([0, 0, 255], dtype=np.uint8)  # BGR
    blended = (1.0 - alpha) * overlay[mask_bool] + alpha * red
    overlay[mask_bool] = blended.astype(np.uint8)
    return overlay

def process_one(rgb_path, depth_path, masks_path, sidewalk_path, out_dir, args):
    """
    Pick the nearest object that touches the sidewalk for one image.

    Loads the RGB image, the depth map (.npy), the object mask stack and the
    sidewalk mask, keeps the masks that pass touches_sidewalk(), scores each
    with robust_depth() and selects the one with the smallest score.

    Args:
      rgb_path, depth_path, masks_path, sidewalk_path: input files.
      out_dir: folder for this image's outputs (created if needed).
      args: namespace providing margin_px, min_contact_px, use_boundary,
        quantile and overlay_alpha.

    Returns a dict with a "status" key:
      - "fail": unreadable RGB, a shape mismatch between inputs, or an
        output image that could not be written ("reason" explains which).
      - "no_match": no usable object; "reason" is "no_touching_object ..."
        when nothing touches the sidewalk and "no_valid_depth ..." when
        touching objects exist but none has a finite positive depth.
        A no_match.txt note is written to out_dir.
      - "ok": nearest_mask.png, nearest_overlay.png and nearest_meta.txt
        were written; the dict also carries best_i, depth_score,
        total_masks, kept_after_touch, mask_png and overlay_png.

    Exceptions from the loaders (e.g. ValueError for a bad mask shape)
    propagate to the caller.
    """
    rgb = cv2.imread(rgb_path)
    if rgb is None:
        return {"status": "fail", "reason": "rgb_read_failed"}

    depth = np.load(depth_path)
    masks = load_mask_stack(masks_path)
    sidewalk = load_sidewalk_mask(sidewalk_path)

    # All four inputs must share the same (H, W) grid.
    if depth.shape != masks.shape[1:]:
        return {"status": "fail", "reason": f"shape_mismatch_depth_vs_masks depth={depth.shape} masksHW={masks.shape[1:]}"}
    if rgb.shape[:2] != depth.shape:
        return {"status": "fail", "reason": f"shape_mismatch_rgb_vs_depth rgbHW={rgb.shape[:2]} depth={depth.shape}"}
    if sidewalk.shape != depth.shape:
        return {"status": "fail", "reason": f"shape_mismatch_sidewalk_vs_depth sidewalk={sidewalk.shape} depth={depth.shape}"}

    total = masks.shape[0]
    best_i, best_score = None, np.inf
    kept = 0

    for i, m in enumerate(masks):
        if not touches_sidewalk(
            m, sidewalk,
            margin_px=args.margin_px,
            min_contact_px=args.min_contact_px,
            use_boundary=args.use_boundary
        ):
            continue
        kept += 1
        # Smaller depth score = closer; masks without valid depth score
        # np.inf and therefore can never win the strict comparison.
        score = robust_depth(depth, m, q=args.quantile)
        if score < best_score:
            best_score = score
            best_i = i

    os.makedirs(out_dir, exist_ok=True)

    if best_i is None:
        # Distinguish "nothing touches" from "touching objects exist but
        # none had usable depth", so the note does not contradict `kept`.
        if kept:
            note = "Objects touch the sidewalk but none has valid depth."
            reason = f"no_valid_depth kept={kept}/{total}"
        else:
            note = "No object touching sidewalk found."
            reason = f"no_touching_object kept={kept}/{total}"
        # still save a quick note file for debugging
        with open(os.path.join(out_dir, "no_match.txt"), "w") as f:
            f.write(f"{note} total_masks={total} kept_after_touch={kept}\n")
        return {"status": "no_match", "reason": reason}

    nearest_mask = masks[best_i]
    mask_png = os.path.join(out_dir, "nearest_mask.png")
    overlay_png = os.path.join(out_dir, "nearest_overlay.png")

    # cv2.imwrite signals failure through its return value; do not report
    # "ok" for images that never reached the disk.
    wrote_mask = cv2.imwrite(mask_png, nearest_mask.astype(np.uint8) * 255)
    wrote_overlay = cv2.imwrite(overlay_png, overlay_mask(rgb, nearest_mask, alpha=args.overlay_alpha))
    if not (wrote_mask and wrote_overlay):
        return {"status": "fail", "reason": f"image_write_failed out_dir={out_dir}"}

    # optional: save index + score
    with open(os.path.join(out_dir, "nearest_meta.txt"), "w") as f:
        f.write(f"best_i={best_i}\n")
        f.write(f"depth_score_p{args.quantile:g}={best_score}\n")
        f.write(f"total_masks={total}\n")
        f.write(f"kept_after_touch={kept}\n")

    return {
        "status": "ok",
        "best_i": int(best_i),
        "depth_score": float(best_score),
        "total_masks": int(total),
        "kept_after_touch": int(kept),
        "mask_png": mask_png,
        "overlay_png": overlay_png,
    }

def _summary_row(fname, stem, status, reason, depth_path, masks_path, sidewalk_path, res=None):
    """Build one summary.csv row; missing paths and result fields become ""."""
    res = res or {}
    return {
        "image": fname,
        "stem": stem,
        "status": status,
        "reason": reason,
        "depth_path": depth_path or "",
        "masks_path": masks_path or "",
        "sidewalk_path": sidewalk_path or "",
        "best_i": res.get("best_i", ""),
        "depth_score": res.get("depth_score", ""),
        "total_masks": res.get("total_masks", ""),
        "kept_after_touch": res.get("kept_after_touch", ""),
        "overlay_png": res.get("overlay_png", ""),
    }


def main():
    """
    Command-line entry point: run process_one() over every RGB image.

    For each image in --rgb_dir the matching depth (.npy), object masks
    (.pkl) and sidewalk mask (.pkl, .npy or .png, tried in that order) are
    looked up by file stem. Images with missing inputs are recorded as
    "skip_missing_inputs"; an exception while processing one image is
    recorded as "fail_exception" and does not stop the batch. A summary.csv
    with one row per image is written to --out_dir.

    Raises:
      FileNotFoundError: no RGB image with the requested extensions exists
        in --rgb_dir.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--rgb_dir", default="/scratch/ds5725/OOPS/images_resized",
                    help="folder with RGB images (png/jpg)")
    ap.add_argument("--depth_dir", default="/scratch/ds5725/OOPS/depthpro_out",
                    help="folder with depth .npy")
    ap.add_argument("--masks_dir", default="/scratch/ds5725/sam3/object_union_batch",
                    help="folder with object masks .pkl")
    ap.add_argument("--sidewalk_dir", default="/scratch/ds5725/sam3/batch_surface",
                    help="folder with sidewalk masks (.pkl/.npy/.png)")

    ap.add_argument("--out_dir", default="./nearest_out",
                    help="output root folder")
    ap.add_argument("--rgb_exts", nargs="+", default=[".png", ".jpg", ".jpeg"],
                    help="RGB extensions to scan")
    ap.add_argument("--quantile", type=float, default=10.0)
    ap.add_argument("--margin_px", type=int, default=8)
    ap.add_argument("--min_contact_px", type=int, default=30)
    ap.add_argument("--use_boundary", action="store_true",
                    help="use boundary-touch instead of mask-touch")
    ap.add_argument("--overlay_alpha", type=float, default=0.4)

    # matching behavior
    ap.add_argument("--allow_contains_match", action="store_true",
                    help="if exact stem.ext not found, try *stem*.ext glob in depth/masks/sidewalk dirs")

    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # gather RGBs (a set removes duplicates when extensions overlap)
    found = set()
    for ext in args.rgb_exts:
        found.update(glob.glob(os.path.join(args.rgb_dir, f"*{ext}")))
    rgb_paths = sorted(found)

    if not rgb_paths:
        raise FileNotFoundError(f"No RGB images found in {args.rgb_dir} with exts {args.rgb_exts}")

    summary_csv = os.path.join(args.out_dir, "summary.csv")
    rows = []
    contains = args.allow_contains_match

    for rgb_path in rgb_paths:
        fname = os.path.basename(rgb_path)
        stem = os.path.splitext(fname)[0]

        depth_path = find_matching_file(stem, args.depth_dir, exts=[".npy"], allow_contains=contains)
        masks_path = find_matching_file(stem, args.masks_dir, exts=[".pkl"], allow_contains=contains)

        # sidewalk could be pkl/npy/png; try in that order
        sidewalk_path = find_matching_file(stem, args.sidewalk_dir, exts=[".pkl", ".npy", ".png"],
                                           allow_contains=contains)

        out_subdir = os.path.join(args.out_dir, stem)

        labelled = (("depth", depth_path), ("masks", masks_path), ("sidewalk", sidewalk_path))
        missing = [label for label, found_path in labelled if found_path is None]

        if missing:
            rows.append(_summary_row(
                fname, stem, "skip_missing_inputs", "missing_" + ",".join(missing),
                depth_path, masks_path, sidewalk_path,
            ))
            continue

        try:
            res = process_one(rgb_path, depth_path, masks_path, sidewalk_path, out_subdir, args)
        except Exception as e:
            # Deliberately broad: one bad image must not abort the batch;
            # the error is preserved in the summary instead.
            rows.append(_summary_row(
                fname, stem, "fail_exception", repr(e),
                depth_path, masks_path, sidewalk_path,
            ))
        else:
            rows.append(_summary_row(
                fname, stem, res.get("status", ""), res.get("reason", ""),
                depth_path, masks_path, sidewalk_path, res,
            ))

    # write CSV
    fieldnames = [
        "image","stem","status","reason",
        "depth_path","masks_path","sidewalk_path",
        "best_i","depth_score","total_masks","kept_after_touch","overlay_png"
    ]
    with open(summary_csv, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)

    print(f"Done. Wrote summary: {summary_csv}")
    print(f"Outputs per image are in: {args.out_dir}/<stem>/nearest_mask.png and nearest_overlay.png")

# Run the batch job only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()