BiliSakura commited on
Commit
9785963
·
verified ·
1 Parent(s): ee5c69c

Update all files for SegEarth-OV

Browse files
Files changed (1) hide show
  1. test_demo.py +96 -0
test_demo.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test unified SegEarth pipeline on a sample image.
3
+ Usage:
4
+ python test_demo.py
5
+ python test_demo.py --variant ov2_alignearth_sar
6
+ python test_demo.py --variant ov_clip_openai_vitb16 --featup bilinear
7
+ python test_demo.py --variant ov3_sam3 # requires sam3 package
8
+ """
9
+ import argparse
10
+ from pathlib import Path
11
+
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ from PIL import Image
15
+
16
+ from pipeline import SegEarthPipeline
17
+
18
+
19
+ def main():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument(
22
+ "--variant",
23
+ default="OV-2",
24
+ choices=["OV", "OV-2", "OV-3"],
25
+ help="SegEarth variant (subfolder)",
26
+ )
27
+ parser.add_argument("--image", type=str, help="Input image path")
28
+ parser.add_argument("--config", type=str, default="configs/cls_openearthmap_sar.txt", help="Class config")
29
+ parser.add_argument("--featup", default=None, help="Override featup: jbu_one, bilinear, etc.")
30
+ parser.add_argument("--prob-thd", type=float, default=0.0, help="Low-confidence threshold")
31
+ parser.add_argument("--save", type=str, help="Save output figure")
32
+ args = parser.parse_args()
33
+
34
+ repo_dir = Path(__file__).resolve().parent
35
+ config_path = repo_dir / args.config if not Path(args.config).is_absolute() else Path(args.config)
36
+ config_path = config_path if config_path.exists() else repo_dir / args.variant / "configs" / Path(args.config).name
37
+
38
+ # Default demo image: use AlignEarth demo if available
39
+ demo_dirs = [
40
+ repo_dir / "demo_YESeg-OPT-SAR",
41
+ repo_dir.parent / "AlignEarth-SAR-ViT-B-16" / "demo_YESeg-OPT-SAR",
42
+ ]
43
+ image_path = None
44
+ if args.image:
45
+ image_path = Path(args.image)
46
+ else:
47
+ for d in demo_dirs:
48
+ p = d / "sar.png"
49
+ if p.exists():
50
+ image_path = p
51
+ break
52
+ if image_path is None:
53
+ image_path = repo_dir / "demo.png"
54
+
55
+ if not image_path or not image_path.exists():
56
+ print("No image found. Use --image path/to/image.png")
57
+ print("Or place demo.png or demo_YESeg-OPT-SAR/sar.png in the repo.")
58
+ return
59
+
60
+ kwargs = {"prob_thd": args.prob_thd, "device": "cuda"}
61
+ if config_path.exists():
62
+ kwargs["class_names_path"] = config_path
63
+ if args.featup:
64
+ kwargs["featup_model"] = args.featup
65
+
66
+ print(f"Loading pipeline (variant={args.variant})...")
67
+ pipe = SegEarthPipeline(variant=args.variant, **kwargs)
68
+
69
+ print(f"Running segmentation on {image_path}...")
70
+ image = Image.open(image_path).convert("RGB")
71
+ seg_pred = pipe(image)
72
+
73
+ print(f"Output shape: {seg_pred.shape}")
74
+ print(f"Classes present: {seg_pred.unique().tolist()}")
75
+
76
+ seg_np = seg_pred.cpu().numpy()
77
+ n_classes = pipe.num_classes if hasattr(pipe, "num_classes") else int(seg_np.max()) + 1
78
+
79
+ fig, axes = plt.subplots(1, 2, figsize=(12, 6))
80
+ axes[0].imshow(image)
81
+ axes[0].set_title("Input")
82
+ axes[0].axis("off")
83
+ axes[1].imshow(seg_np, cmap="tab10", vmin=0, vmax=max(n_classes - 1, 9))
84
+ axes[1].set_title("Prediction")
85
+ axes[1].axis("off")
86
+ plt.tight_layout()
87
+
88
+ if args.save:
89
+ out_path = Path(args.save)
90
+ plt.savefig(out_path, bbox_inches="tight")
91
+ print(f"Saved to {out_path}")
92
+ plt.show()
93
+
94
+
95
+ if __name__ == "__main__":
96
+ main()