import argparse
import csv
import json
import math
import os
import re
from typing import Dict, List, Optional, Sequence, Tuple
|
|
|
|
# Attribute vocabularies searched for in question text. Matching is done on
# whole lower-case tokens, so plural spellings are listed explicitly; a
# trailing "s" is stripped from matches before comparing with scene objects.
COLORS = (
    "gray",
    "grey",
    "red",
    "blue",
    "green",
    "brown",
    "purple",
    "cyan",
    "yellow",
)
SHAPES = (
    "cube",
    "cubes",
    "sphere",
    "spheres",
    "cylinder",
    "cylinders",
)
MATERIALS = (
    "metal",
    "metals",
    "rubber",
    "rubbers",
)
SIZES = (
    "small",
    "large",
)
|
|
|
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for the pruning script.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing no-argument callers are unaffected;
            passing a list makes the parser usable from code and tests.

    Returns:
        The parsed namespace.

    Raises:
        SystemExit: on invalid arguments, including a
            ``--min_visibility_fraction`` outside ``[0, 1]``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Prune mapping CSV rows if any scene object is occluded. "
            "Default behavior is absolute-safety mode: drop row if occlusion metadata "
            "is missing or any object fails visibility threshold."
        )
    )
    parser.add_argument(
        "--run_dir",
        required=True,
        help="Path to a run directory containing the CSV and scenes/ folder.",
    )
    parser.add_argument(
        "--input_csv",
        default="image_mapping_with_questions_strict_cf.csv",
        help="Input CSV filename inside run_dir.",
    )
    parser.add_argument(
        "--output_csv",
        default="image_mapping_with_questions_pruned.csv",
        help="Output CSV filename inside run_dir.",
    )
    parser.add_argument(
        "--scenes_dir",
        default=None,
        help="Optional explicit scenes directory (defaults to <run_dir>/scenes).",
    )
    parser.add_argument(
        "--min_visible_pixels",
        type=int,
        default=50,
        help="Minimum visible pixels threshold for an object to be considered visible.",
    )
    parser.add_argument(
        "--min_visibility_fraction",
        type=float,
        default=None,
        help="Optional minimum visibility fraction threshold.",
    )
    parser.add_argument(
        "--question_conditioned",
        action="store_true",
        help=(
            "Only prune when a failing object matches attributes mentioned in question columns. "
            "Default (off) is safer: prune if any object fails."
        ),
    )
    parser.add_argument(
        "--keep_missing_visibility",
        action="store_true",
        help=(
            "Keep rows where no usable visibility metadata is found. "
            "Default (off) is safer: drop rows with missing visibility metadata."
        ),
    )
    args = parser.parse_args(argv)

    # A fraction above 1 would fail every object that reports a fraction, and a
    # negative one (or NaN, which fails this chained comparison) would never
    # fail any. Reject these up front instead of silently producing a
    # degenerate output.
    fraction = args.min_visibility_fraction
    if fraction is not None and not 0.0 <= fraction <= 1.0:
        parser.error("--min_visibility_fraction must be within [0, 1]")
    return args
|
|
|
|
def load_json(path: str) -> Optional[Dict]:
    """Load a JSON object (dict) from ``path``, best-effort.

    Returns:
        The parsed dict, or ``None`` after printing a ``[WARN]`` line when the
        file cannot be opened, is not valid UTF-8 JSON, or its top-level value
        is not a JSON object. Callers treat ``None`` as "unreadable scene", so
        a list/number/string document must not be passed through as a dict.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    # OSError: missing/unreadable file; ValueError covers JSONDecodeError and
    # UnicodeDecodeError; RecursionError: pathologically nested documents.
    except (OSError, ValueError, RecursionError) as e:
        print(f"[WARN] Failed to read JSON: {path} ({e})")
        return None
    if not isinstance(data, dict):
        print(
            f"[WARN] Failed to read JSON: {path} "
            f"(top-level value is {type(data).__name__}, expected object)"
        )
        return None
    return data
|
|
|
|
def looks_like_scene_image(name: str) -> bool:
    """Tell whether ``name`` is a ``scene_*.png`` file name (case-sensitive).

    Empty or ``None`` input is reported as not a scene image.
    """
    if not name:
        return False
    return name.startswith("scene_") and name.endswith(".png")
|
|
|
|
def scene_json_from_row(row: Dict[str, str], scenes_dir: str) -> Optional[str]:
    """Resolve the scene JSON path referenced by a CSV row.

    The known "original" columns are tried first, in order:
      * a value ending in ``.json`` is used by its basename;
      * a ``scene_*.png`` value is mapped to the ``.json`` with the same stem.
    If none of those match, every cell is scanned for a ``scene_*.png`` name
    containing ``"_original.png"``, which is mapped the same way.

    Only basenames are kept, so the returned path always lies directly inside
    ``scenes_dir``. Returns ``None`` when nothing matches; whether the file
    exists is left to the caller.
    """
    candidate_fields = (
        "original_scene_link",
        "original_scene",
        "original_image",
        "original_image_link",
    )

    for field in candidate_fields:
        val = row.get(field)
        # Missing cells come back as None; non-string cells cannot name a file.
        if not isinstance(val, str):
            continue
        val = val.strip()
        if not val:
            continue
        base = os.path.basename(val)
        if val.endswith(".json"):
            return os.path.join(scenes_dir, base)
        if looks_like_scene_image(base):
            # Swap only the extension; str.replace would also rewrite any
            # ".png" occurring earlier in the name.
            return os.path.join(scenes_dir, os.path.splitext(base)[0] + ".json")

    for val in row.values():
        # csv.DictReader stores surplus fields as a *list* under the key None;
        # calling .strip() on it would raise AttributeError.
        if not isinstance(val, str):
            continue
        base = os.path.basename(val.strip())
        if looks_like_scene_image(base) and "_original.png" in base:
            return os.path.join(scenes_dir, os.path.splitext(base)[0] + ".json")
    return None
|
|
|
|
def normalize_token(s: str) -> str:
    """Return ``s`` trimmed and lower-cased; falsy input (e.g. ``None``) yields ``""``."""
    if not s:
        return ""
    return s.strip().lower()
|
|
|
|
def _as_finite_float(value: object) -> Optional[float]:
    """Parse ``value`` as a finite float; return None if unparseable, NaN or infinite."""
    try:
        number = float(value)  # type: ignore[arg-type]
    except (TypeError, ValueError, OverflowError):
        return None
    return number if math.isfinite(number) else None


def object_fails_visibility(
    obj: Dict,
    min_visible_pixels: int,
    min_visibility_fraction: Optional[float],
) -> Tuple[bool, bool]:
    """Check one scene object against the visibility thresholds.

    The object fails when any of these holds, checked in order:
      * ``is_occluded`` is truthy;
      * ``visible`` is exactly ``False``;
      * ``visible_pixel_count`` is present and below ``min_visible_pixels``;
      * ``min_visibility_fraction`` is given and ``visibility_fraction`` is
        present and below it.
    A present count/fraction that cannot be read as a finite number (e.g.
    ``"n/a"`` or NaN) also fails: it cannot demonstrate that the object is
    visible, and the script's default mode is absolute safety.

    Returns:
        ``(fails_threshold, has_any_visibility_metadata)``. The second item is
        True when the object has any of the keys ``is_occluded``, ``visible``,
        ``visible_pixel_count`` or ``visibility_fraction``.
    """
    has_metadata = any(
        k in obj for k in ("is_occluded", "visible", "visible_pixel_count", "visibility_fraction")
    )

    if bool(obj.get("is_occluded", False)):
        return True, has_metadata
    if obj.get("visible", None) is False:
        return True, has_metadata

    visible_pixels = obj.get("visible_pixel_count", None)
    if visible_pixels is not None:
        pixels = _as_finite_float(visible_pixels)
        # For non-negative counts, comparing the float against the integer
        # threshold gives the same result as the earlier int() truncation.
        if pixels is None or pixels < int(min_visible_pixels):
            return True, has_metadata

    if min_visibility_fraction is not None:
        visibility_fraction = obj.get("visibility_fraction", None)
        if visibility_fraction is not None:
            fraction = _as_finite_float(visibility_fraction)
            if fraction is None or fraction < float(min_visibility_fraction):
                return True, has_metadata

    return False, has_metadata
|
|
|
|
def mentioned_attributes(question_text: str) -> Dict[str, set]:
    """Collect the attribute values that appear as whole words in ``question_text``.

    The text is lower-cased and split into alphabetic tokens. Each token found
    in COLORS / SHAPES / MATERIALS / SIZES is recorded under its category, with
    a trailing "s" stripped for color/shape/material so plurals map to the
    singular value.

    Returns:
        ``{"color": set, "shape": set, "material": set, "size": set}``; a
        category with no mention has an empty set.
    """
    tok_set = set(re.findall(r"[a-z]+", normalize_token(question_text)))

    colors = {c.rstrip("s") for c in COLORS if c in tok_set}
    # "gray" and "grey" are both listed in COLORS as spellings of one color.
    # Register both so an object labelled with either spelling matches a
    # question written with the other.
    if colors & {"gray", "grey"}:
        colors |= {"gray", "grey"}

    # NOTE(review): only the literal vocabulary above is recognised. Synonyms a
    # question generator might use (e.g. "ball", "big", "shiny") are not —
    # confirm against the question templates.
    return {
        "color": colors,
        "shape": {s.rstrip("s") for s in SHAPES if s in tok_set},
        "material": {m.rstrip("s") for m in MATERIALS if m in tok_set},
        "size": {z for z in SIZES if z in tok_set},
    }
|
|
|
|
def row_question_text(row: Dict[str, str]) -> str:
    """Join the non-blank values of every column whose name contains "question".

    Column names are matched case-insensitively (a ``None`` key never matches).
    Each value is trimmed, then lower-cased, and the values are joined with
    single spaces in column order.
    """
    trimmed = (
        (value or "").strip()
        for key, value in row.items()
        if "question" in (key or "").lower()
    )
    return " ".join(text.lower() for text in trimmed if text)
|
|
|
|
def obj_matches_mentioned_attrs(obj: Dict, mentioned: Dict[str, set]) -> bool:
    """Report whether ``obj`` shares at least one attribute value with ``mentioned``.

    The object's color, shape, material and size are trimmed and lower-cased.
    Shape and material additionally lose a trailing "s". Categories are
    checked in that order. A category whose mentioned set is empty is skipped;
    the first hit returns True.
    """
    normalized = {
        "color": normalize_token(obj.get("color", "")),
        "shape": normalize_token(obj.get("shape", "")).rstrip("s"),
        "material": normalize_token(obj.get("material", "")).rstrip("s"),
        "size": normalize_token(obj.get("size", "")),
    }
    for attr, value in normalized.items():
        wanted = mentioned[attr]
        if wanted and value in wanted:
            return True
    return False
|
|
|
|
def should_prune_row(
    row: Dict[str, str],
    scene: Dict,
    min_visible_pixels: int,
    min_visibility_fraction: Optional[float],
    question_conditioned: bool,
    keep_missing_visibility: bool,
) -> Tuple[bool, str]:
    """Decide whether a CSV row must be dropped because of its scene's visibility data.

    Args:
        row: The CSV row. Its question columns are read only when
            ``question_conditioned`` is set.
        scene: Parsed scene JSON; its ``"objects"`` list is inspected.
        min_visible_pixels: Forwarded to ``object_fails_visibility``.
        min_visibility_fraction: Forwarded to ``object_fails_visibility``.
        question_conditioned: If True, prune only when a failing object shares
            an attribute value with the row's question text.
        keep_missing_visibility: If True, keep rows where no object carries
            any visibility metadata.

    Returns:
        ``(drop, reason)``; ``reason`` is a short tag used for summary counts.
    """
    objects = scene.get("objects", []) if isinstance(scene, dict) else None
    # Every entry must be a dict: the visibility and attribute checks call
    # .get() on it, so a stray string/number/null would raise AttributeError
    # and abort the whole run instead of pruning this one row.
    if (
        not isinstance(objects, list)
        or not objects
        or not all(isinstance(obj, dict) for obj in objects)
    ):
        return True, "empty_or_invalid_scene_objects"

    failing_objects: List[Dict] = []
    any_metadata = False
    for obj in objects:
        fails, has_meta = object_fails_visibility(obj, min_visible_pixels, min_visibility_fraction)
        any_metadata = any_metadata or has_meta
        if fails:
            failing_objects.append(obj)

    # Safety default: with no visibility metadata on any object nothing proves
    # the scene is unoccluded, so drop unless the caller opted to keep it.
    if not any_metadata and not keep_missing_visibility:
        return True, "missing_visibility_metadata"
    if not failing_objects:
        return False, "ok"

    if not question_conditioned:
        return True, f"occluded_objects={len(failing_objects)}"

    mentioned = mentioned_attributes(row_question_text(row))
    # NOTE(review): if the question mentions no known attribute value, no
    # failing object can match and the row is kept, even though an occluded
    # object may still affect the answer — confirm this is intended.
    if any(obj_matches_mentioned_attrs(obj, mentioned) for obj in failing_objects):
        return True, "occluded_object_matches_question_attributes"
    return False, "occluded_but_no_question_attribute_overlap"
|
|
|
|
def _evaluate_row(
    row: Dict[str, str], scenes_dir: str, args: argparse.Namespace
) -> Tuple[bool, str]:
    """Return ``(drop, reason)`` for one CSV row, including file-level failures."""
    # csv.DictReader puts surplus fields under the key None. Such a row does
    # not fit the header, and DictWriter would raise ValueError when writing
    # it, losing all work, so prune it here with its own reason.
    if None in row:
        return True, "malformed_csv_row"

    scene_json = scene_json_from_row(row, scenes_dir)
    if not scene_json or not os.path.isfile(scene_json):
        return True, "scene_file_missing"

    scene = load_json(scene_json)
    # A JSON document whose top level is not an object is as unusable as an
    # unreadable one.
    if not isinstance(scene, dict):
        return True, "scene_json_unreadable"

    return should_prune_row(
        row=row,
        scene=scene,
        min_visible_pixels=args.min_visible_pixels,
        min_visibility_fraction=args.min_visibility_fraction,
        question_conditioned=args.question_conditioned,
        keep_missing_visibility=args.keep_missing_visibility,
    )


def main() -> None:
    """Prune the mapping CSV in ``--run_dir`` and write the surviving rows.

    Each row is resolved to its scene JSON and evaluated with
    ``should_prune_row``. Rows whose scene file is missing or unreadable, and
    rows with more fields than the header, are pruned too. Progress is printed
    every 50 rows, followed by a summary with per-reason prune counts.

    Raises:
        FileNotFoundError: if the input CSV or the scenes directory is missing.
        RuntimeError: if the input CSV has no header row.
    """
    args = parse_args()

    run_dir = os.path.abspath(args.run_dir)
    scenes_dir = os.path.abspath(args.scenes_dir) if args.scenes_dir else os.path.join(run_dir, "scenes")
    input_csv_path = os.path.join(run_dir, args.input_csv)
    output_csv_path = os.path.join(run_dir, args.output_csv)

    if not os.path.isfile(input_csv_path):
        raise FileNotFoundError(f"Input CSV not found: {input_csv_path}")
    if not os.path.isdir(scenes_dir):
        raise FileNotFoundError(f"Scenes directory not found: {scenes_dir}")

    total = 0
    kept = 0
    pruned = 0
    reasons: Dict[str, int] = {}
    out_rows: List[Dict[str, str]] = []

    # "utf-8-sig" strips a leading BOM (common in spreadsheet exports). With
    # plain "utf-8" the first header would start with "\ufeff" and never match
    # the column names scene_json_from_row looks up. Files without a BOM are
    # read identically.
    with open(input_csv_path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        if not fieldnames:
            raise RuntimeError("Input CSV has no header.")

        for row in reader:
            total += 1
            drop, reason = _evaluate_row(row, scenes_dir, args)
            if drop:
                pruned += 1
                reasons[reason] = reasons.get(reason, 0) + 1
            else:
                kept += 1
                out_rows.append(row)

            if total % 50 == 0:
                print(f"[PROGRESS] processed={total} kept={kept} pruned={pruned}")

    with open(output_csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, quoting=csv.QUOTE_ALL)
        writer.writeheader()
        writer.writerows(out_rows)

    print("\n[DONE] Pruning complete")
    print(f"Input CSV: {input_csv_path}")
    print(f"Output CSV: {output_csv_path}")
    print(f"Total rows: {total}")
    print(f"Kept rows: {kept}")
    print(f"Pruned: {pruned}")
    if reasons:
        print("Prune reasons:")
        for k in sorted(reasons.keys()):
            print(f" - {k}: {reasons[k]}")
|
|
|
|
# Run the pruning CLI only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|
|