# SegEarth-OV / test_demo.py
# BiliSakura's picture
# Update all files for SegEarth-OV
# 9785963 verified
"""
Test unified SegEarth pipeline on a sample image.
Usage:
python test_demo.py
python test_demo.py --variant OV-2
python test_demo.py --variant OV --featup bilinear
python test_demo.py --variant OV-3  # requires sam3 package
"""
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from pipeline import SegEarthPipeline
def _resolve_config_path(repo_dir, config, variant):
    """Resolve the class-name config file for *variant*.

    A relative *config* is taken relative to *repo_dir*; an absolute one is
    used as given. If that file does not exist, fall back to
    ``<repo_dir>/<variant>/configs/<basename of config>``. The returned path
    is not guaranteed to exist; callers must check.
    """
    config_arg = Path(config)
    candidate = config_arg if config_arg.is_absolute() else repo_dir / config_arg
    if candidate.exists():
        return candidate
    return repo_dir / variant / "configs" / config_arg.name


def _resolve_image_path(repo_dir, image):
    """Return the input image path, or ``None`` if no usable image exists.

    An explicit *image* path is used as given (``None`` if it does not exist).
    Otherwise the first existing default wins: ``demo_YESeg-OPT-SAR/sar.png``
    in the repo, the same file in a sibling ``AlignEarth-SAR-ViT-B-16``
    checkout, then ``demo.png`` in the repo.
    """
    if image:
        explicit = Path(image)
        return explicit if explicit.exists() else None
    defaults = (
        repo_dir / "demo_YESeg-OPT-SAR" / "sar.png",
        repo_dir.parent / "AlignEarth-SAR-ViT-B-16" / "demo_YESeg-OPT-SAR" / "sar.png",
        repo_dir / "demo.png",
    )
    return next((p for p in defaults if p.exists()), None)


def _label_colormap(n_classes):
    """Return ``(cmap, vmax)`` for ``imshow`` giving labels 0..n_classes-1 distinct colours.

    ``tab10`` has only 10 colours; with ``vmin=0`` and a larger ``vmax``
    several labels would fall into the same colour bin, so switch to
    ``tab20`` and then to a continuous colormap as the class count grows.
    """
    if n_classes <= 10:
        return "tab10", 9
    if n_classes <= 20:
        return "tab20", 19
    return "nipy_spectral", n_classes - 1


def main():
    """Run the SegEarth pipeline on one image and plot input vs. prediction.

    Command-line options select the variant, input image, class config,
    feature upsampler, low-confidence threshold, device and an optional
    output figure path.

    Returns:
        0 on success, 1 if no input image could be found.
    """
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--variant",
        default="OV-2",
        choices=["OV", "OV-2", "OV-3"],
        help="SegEarth variant (subfolder)",
    )
    parser.add_argument("--image", type=str, help="Input image path")
    parser.add_argument("--config", type=str, default="configs/cls_openearthmap_sar.txt", help="Class config")
    parser.add_argument("--featup", default=None, help="Override featup: jbu_one, bilinear, etc.")
    parser.add_argument("--prob-thd", type=float, default=0.0, help="Low-confidence threshold")
    # Previously hard-coded to "cuda"; keep that as the default.
    parser.add_argument("--device", default="cuda", help="Device passed to SegEarthPipeline (e.g. cuda, cpu)")
    parser.add_argument("--save", type=str, help="Save output figure")
    args = parser.parse_args()

    repo_dir = Path(__file__).resolve().parent
    config_path = _resolve_config_path(repo_dir, args.config, args.variant)
    image_path = _resolve_image_path(repo_dir, args.image)
    if image_path is None:
        if args.image:
            print(f"Image not found: {args.image}", file=sys.stderr)
        else:
            print("No image found. Use --image path/to/image.png", file=sys.stderr)
            print("Or place demo.png or demo_YESeg-OPT-SAR/sar.png in the repo.", file=sys.stderr)
        return 1

    kwargs = {"prob_thd": args.prob_thd, "device": args.device}
    if config_path.exists():
        kwargs["class_names_path"] = config_path
    else:
        # NOTE(review): the pipeline then presumably falls back to its own
        # class list — confirm in pipeline.py.
        print(f"Warning: class config not found: {config_path}; class_names_path not passed.", file=sys.stderr)
    if args.featup:
        kwargs["featup_model"] = args.featup

    print(f"Loading pipeline (variant={args.variant})...")
    pipe = SegEarthPipeline(variant=args.variant, **kwargs)

    print(f"Running segmentation on {image_path}...")
    # convert() returns an independent, loaded copy, so the file can be closed here.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    seg_pred = pipe(image)
    print(f"Output shape: {seg_pred.shape}")
    print(f"Classes present: {seg_pred.unique().tolist()}")

    seg_np = seg_pred.cpu().numpy()
    # Make the colour range cover every predicted label, even if it exceeds
    # the pipeline's reported class count.
    n_classes = max(int(getattr(pipe, "num_classes", 0) or 0), int(seg_np.max()) + 1)
    cmap, vmax = _label_colormap(n_classes)

    _, axes = plt.subplots(1, 2, figsize=(12, 6))
    axes[0].imshow(image)
    axes[0].set_title("Input")
    axes[0].axis("off")
    axes[1].imshow(seg_np, cmap=cmap, vmin=0, vmax=vmax)
    axes[1].set_title("Prediction")
    axes[1].axis("off")
    plt.tight_layout()

    if args.save:
        out_path = Path(args.save)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(out_path, bbox_inches="tight")
        print(f"Saved to {out_path}")
    plt.show()
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the exit status (None -> 0).
    raise SystemExit(main())