"""
Test unified SegEarth pipeline on a sample image.

Usage:
    python test_demo.py
    python test_demo.py --variant OV-2 --image path/to/sar.png
    python test_demo.py --variant OV --featup bilinear
    python test_demo.py --variant OV-3  # requires sam3 package
    python test_demo.py --device cpu --save out/prediction.png
"""

import argparse
import sys
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from pipeline import SegEarthPipeline


def _resolve_config_path(repo_dir: Path, config: str, variant: str) -> Path:
    """Locate the class-name config file for ``--config``.

    A relative ``config`` is resolved against ``repo_dir``; an absolute one is
    used as-is. If that file does not exist, fall back to
    ``<repo_dir>/<variant>/configs/<basename of config>``. The returned path
    is not guaranteed to exist.
    """
    config_arg = Path(config)
    candidate = config_arg if config_arg.is_absolute() else repo_dir / config_arg
    if candidate.exists():
        return candidate
    return repo_dir / variant / "configs" / config_arg.name


def _find_image(repo_dir: Path, image: Optional[str]) -> Path:
    """Return the image to segment.

    Uses ``image`` when given. Otherwise returns the first existing
    ``demo_YESeg-OPT-SAR/sar.png`` (next to this script, then in a sibling
    ``AlignEarth-SAR-ViT-B-16`` checkout), falling back to
    ``<repo_dir>/demo.png``. The returned path is not guaranteed to exist.
    """
    if image:
        return Path(image)
    demo_dirs = (
        repo_dir / "demo_YESeg-OPT-SAR",
        repo_dir.parent / "AlignEarth-SAR-ViT-B-16" / "demo_YESeg-OPT-SAR",
    )
    for demo_dir in demo_dirs:
        candidate = demo_dir / "sar.png"
        if candidate.exists():
            return candidate
    return repo_dir / "demo.png"


def _default_device() -> str:
    """Return "cuda" if torch reports a usable GPU, otherwise "cpu".

    Falls back to "cuda" (the script's historical default) when torch cannot
    be imported here, leaving the device decision to the pipeline.
    """
    try:
        import torch  # local import: only needed to probe for a GPU
    except ImportError:
        return "cuda"
    return "cuda" if torch.cuda.is_available() else "cpu"


def _class_cmap(n_classes: int):
    """Pick ``(cmap_name, vmax)`` so every class index gets a distinct color.

    tab10 and tab20 are qualitative maps with 10 and 20 entries. With vmin=0
    and vmax=entries-1, class index i lands on entry i. Above 20 classes a
    continuous map spread over [0, n_classes - 1] is used, so indices stay
    distinct instead of wrapping onto the same qualitative color.
    """
    if n_classes <= 10:
        return "tab10", 9
    if n_classes <= 20:
        return "tab20", 19
    return "turbo", n_classes - 1


def main() -> int:
    """Run SegEarthPipeline on one image and plot input vs. prediction.

    Returns:
        Process exit status: 0 on success, 1 if the input image is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--variant",
        default="OV-2",
        choices=["OV", "OV-2", "OV-3"],
        help="SegEarth variant (subfolder)",
    )
    parser.add_argument("--image", type=str, help="Input image path")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/cls_openearthmap_sar.txt",
        help="Class config",
    )
    parser.add_argument("--featup", default=None, help="Override featup: jbu_one, bilinear, etc.")
    parser.add_argument("--prob-thd", type=float, default=0.0, help="Low-confidence threshold")
    parser.add_argument(
        "--device",
        default=None,
        help="Torch device, e.g. cuda, cuda:1, cpu (default: cuda if available, else cpu)",
    )
    parser.add_argument("--save", type=str, help="Save output figure")
    args = parser.parse_args()

    repo_dir = Path(__file__).resolve().parent
    config_path = _resolve_config_path(repo_dir, args.config, args.variant)
    image_path = _find_image(repo_dir, args.image)

    if not image_path.exists():
        # Name the path we tried so a mistyped --image is easy to spot.
        print(f"No image found ({image_path}). Use --image path/to/image.png", file=sys.stderr)
        print("Or place demo.png or demo_YESeg-OPT-SAR/sar.png in the repo.", file=sys.stderr)
        return 1

    device = args.device or _default_device()
    kwargs = {"prob_thd": args.prob_thd, "device": device}
    if config_path.exists():
        kwargs["class_names_path"] = config_path
    else:
        print(
            f"Warning: class config not found ({config_path}); continuing without class_names_path.",
            file=sys.stderr,
        )
    if args.featup:
        kwargs["featup_model"] = args.featup

    print(f"Loading pipeline (variant={args.variant}, device={device})...")
    pipe = SegEarthPipeline(variant=args.variant, **kwargs)

    print(f"Running segmentation on {image_path}...")
    # convert() returns a fully loaded copy, so the file handle can be closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    seg_pred = pipe(image)
    print(f"Output shape: {seg_pred.shape}")
    print(f"Classes present: {seg_pred.unique().tolist()}")

    seg_np = seg_pred.cpu().numpy()
    n_classes = getattr(pipe, "num_classes", None)
    if n_classes is None:
        # No class count exposed: infer it from the largest predicted index.
        n_classes = int(seg_np.max()) + 1
    cmap, vmax = _class_cmap(n_classes)

    _, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(image)
    axes[0].set_title("Input")
    axes[0].axis("off")
    # "nearest" keeps label colors unblended when the map is resampled for display.
    axes[1].imshow(seg_np, cmap=cmap, vmin=0, vmax=vmax, interpolation="nearest")
    axes[1].set_title("Prediction")
    axes[1].axis("off")
    plt.tight_layout()

    if args.save:
        out_path = Path(args.save)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(out_path, bbox_inches="tight")
        print(f"Saved to {out_path}")
    plt.show()
    return 0


if __name__ == "__main__":
    sys.exit(main())