"""
Test the unified SegEarth pipeline on a sample image.

Usage:
    python test_demo.py
    python test_demo.py --variant OV-2
    python test_demo.py --variant OV --featup bilinear
    python test_demo.py --variant OV-3   # requires the sam3 package

--variant must be one of OV, OV-2, OV-3 (the variant subfolder name).
"""
| | import argparse |
| | from pathlib import Path |
| |
|
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | from PIL import Image |
| |
|
| | from pipeline import SegEarthPipeline |
| |
|
| |
|
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Run the unified SegEarth pipeline on a sample image."
    )
    parser.add_argument(
        "--variant",
        default="OV-2",
        choices=["OV", "OV-2", "OV-3"],
        help="SegEarth variant (subfolder)",
    )
    parser.add_argument("--image", type=str, help="Input image path")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/cls_openearthmap_sar.txt",
        help="Class config",
    )
    parser.add_argument("--featup", default=None, help="Override featup: jbu_one, bilinear, etc.")
    parser.add_argument("--prob-thd", type=float, default=0.0, help="Low-confidence threshold")
    parser.add_argument(
        "--device",
        default="cuda",
        help="Device passed to the pipeline (default: cuda; e.g. cpu when no GPU is available)",
    )
    parser.add_argument("--save", type=str, help="Save output figure")
    return parser.parse_args()


def _resolve_config(config_arg, repo_dir, variant):
    """Return the first existing class-config path, or None if none exists.

    Candidates are tried in this order:
      1. the path as given (relative paths are taken relative to ``repo_dir``),
      2. ``<repo_dir>/<variant>/configs/<basename>``,
      3. the path relative to the current working directory (last resort only,
         so it never overrides the two lookups above).
    """
    given = Path(config_arg)
    candidates = [
        given if given.is_absolute() else repo_dir / given,
        repo_dir / variant / "configs" / given.name,
        given,
    ]
    return next((c for c in candidates if c.exists()), None)


def _resolve_image(image_arg, repo_dir):
    """Return the image path to segment.

    Uses ``--image`` when given. Otherwise it returns the first bundled
    ``sar.png`` demo that exists, falling back to ``<repo_dir>/demo.png``.
    The returned path is not guaranteed to exist; the caller checks.
    """
    if image_arg:
        return Path(image_arg)
    demo_dirs = [
        repo_dir / "demo_YESeg-OPT-SAR",
        repo_dir.parent / "AlignEarth-SAR-ViT-B-16" / "demo_YESeg-OPT-SAR",
    ]
    for demo_dir in demo_dirs:
        candidate = demo_dir / "sar.png"
        if candidate.exists():
            return candidate
    return repo_dir / "demo.png"


def _plot_prediction(image, seg_np, n_classes):
    """Draw the input image beside the predicted label map and return the figure."""
    # Pick a qualitative colormap large enough that every class ID gets its own
    # colour. With vmin=0 and vmax=N-1, class i maps to entry i of an N-colour map.
    if n_classes <= 10:
        cmap, vmax = "tab10", 9
    elif n_classes <= 20:
        cmap, vmax = "tab20", 19
    else:
        cmap, vmax = "nipy_spectral", n_classes - 1

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(image)
    axes[0].set_title("Input")
    axes[0].axis("off")
    # Nearest-neighbour sampling keeps label IDs intact. Smoothing interpolation
    # would blend neighbouring IDs into colours that belong to other classes.
    axes[1].imshow(seg_np, cmap=cmap, vmin=0, vmax=vmax, interpolation="nearest")
    axes[1].set_title("Prediction")
    axes[1].axis("off")
    fig.tight_layout()
    return fig


def main():
    """Segment one image with SegEarthPipeline and show or save the result.

    Exits with a non-zero status (via SystemExit) when no input image is found.
    """
    args = _parse_args()
    repo_dir = Path(__file__).resolve().parent

    config_path = _resolve_config(args.config, repo_dir, args.variant)
    image_path = _resolve_image(args.image, repo_dir)

    if not image_path.exists():
        if args.image:
            raise SystemExit(f"Image not found: {image_path}")
        raise SystemExit(
            "No image found. Use --image path/to/image.png\n"
            "Or place demo.png or demo_YESeg-OPT-SAR/sar.png in the repo."
        )

    kwargs = {"prob_thd": args.prob_thd, "device": args.device}
    if config_path is not None:
        kwargs["class_names_path"] = config_path
    else:
        print(
            f"Warning: class config {args.config!r} not found; "
            "class_names_path will not be passed to the pipeline."
        )
    if args.featup:
        kwargs["featup_model"] = args.featup

    print(f"Loading pipeline (variant={args.variant})...")
    pipe = SegEarthPipeline(variant=args.variant, **kwargs)

    print(f"Running segmentation on {image_path}...")
    # Image.open is lazy and keeps the file open; convert() inside the context
    # loads the pixels into a new image so the handle can be closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    seg_pred = pipe(image)

    print(f"Output shape: {seg_pred.shape}")
    print(f"Classes present: {seg_pred.unique().tolist()}")

    seg_np = seg_pred.cpu().numpy()
    # imshow needs a 2-D label map. The pipeline's output layout is not visible
    # here, so only drop a leading singleton axis (e.g. a batch of 1) if present.
    if seg_np.ndim == 3 and seg_np.shape[0] == 1:
        seg_np = seg_np[0]

    n_classes = getattr(pipe, "num_classes", None)
    if n_classes is None:
        n_classes = int(seg_np.max()) + 1

    fig = _plot_prediction(image, seg_np, n_classes)

    if args.save:
        out_path = Path(args.save)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight")
        print(f"Saved to {out_path}")
    plt.show()
| |
|
| |
|
# Script entry point: run main() and use its return value as the exit status
# (None → 0).
if __name__ == "__main__":
    raise SystemExit(main())
| |
|