Update all files for SegEarth-OV
Browse files- test_demo.py +96 -0
test_demo.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test unified SegEarth pipeline on a sample image.
|
| 3 |
+
Usage:
|
| 4 |
+
python test_demo.py
|
| 5 |
+
python test_demo.py --variant ov2_alignearth_sar
|
| 6 |
+
python test_demo.py --variant ov_clip_openai_vitb16 --featup bilinear
|
| 7 |
+
python test_demo.py --variant ov3_sam3 # requires sam3 package
|
| 8 |
+
"""
|
| 9 |
+
import argparse
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import numpy as np
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
from pipeline import SegEarthPipeline
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main():
|
| 20 |
+
parser = argparse.ArgumentParser()
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"--variant",
|
| 23 |
+
default="OV-2",
|
| 24 |
+
choices=["OV", "OV-2", "OV-3"],
|
| 25 |
+
help="SegEarth variant (subfolder)",
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument("--image", type=str, help="Input image path")
|
| 28 |
+
parser.add_argument("--config", type=str, default="configs/cls_openearthmap_sar.txt", help="Class config")
|
| 29 |
+
parser.add_argument("--featup", default=None, help="Override featup: jbu_one, bilinear, etc.")
|
| 30 |
+
parser.add_argument("--prob-thd", type=float, default=0.0, help="Low-confidence threshold")
|
| 31 |
+
parser.add_argument("--save", type=str, help="Save output figure")
|
| 32 |
+
args = parser.parse_args()
|
| 33 |
+
|
| 34 |
+
repo_dir = Path(__file__).resolve().parent
|
| 35 |
+
config_path = repo_dir / args.config if not Path(args.config).is_absolute() else Path(args.config)
|
| 36 |
+
config_path = config_path if config_path.exists() else repo_dir / args.variant / "configs" / Path(args.config).name
|
| 37 |
+
|
| 38 |
+
# Default demo image: use AlignEarth demo if available
|
| 39 |
+
demo_dirs = [
|
| 40 |
+
repo_dir / "demo_YESeg-OPT-SAR",
|
| 41 |
+
repo_dir.parent / "AlignEarth-SAR-ViT-B-16" / "demo_YESeg-OPT-SAR",
|
| 42 |
+
]
|
| 43 |
+
image_path = None
|
| 44 |
+
if args.image:
|
| 45 |
+
image_path = Path(args.image)
|
| 46 |
+
else:
|
| 47 |
+
for d in demo_dirs:
|
| 48 |
+
p = d / "sar.png"
|
| 49 |
+
if p.exists():
|
| 50 |
+
image_path = p
|
| 51 |
+
break
|
| 52 |
+
if image_path is None:
|
| 53 |
+
image_path = repo_dir / "demo.png"
|
| 54 |
+
|
| 55 |
+
if not image_path or not image_path.exists():
|
| 56 |
+
print("No image found. Use --image path/to/image.png")
|
| 57 |
+
print("Or place demo.png or demo_YESeg-OPT-SAR/sar.png in the repo.")
|
| 58 |
+
return
|
| 59 |
+
|
| 60 |
+
kwargs = {"prob_thd": args.prob_thd, "device": "cuda"}
|
| 61 |
+
if config_path.exists():
|
| 62 |
+
kwargs["class_names_path"] = config_path
|
| 63 |
+
if args.featup:
|
| 64 |
+
kwargs["featup_model"] = args.featup
|
| 65 |
+
|
| 66 |
+
print(f"Loading pipeline (variant={args.variant})...")
|
| 67 |
+
pipe = SegEarthPipeline(variant=args.variant, **kwargs)
|
| 68 |
+
|
| 69 |
+
print(f"Running segmentation on {image_path}...")
|
| 70 |
+
image = Image.open(image_path).convert("RGB")
|
| 71 |
+
seg_pred = pipe(image)
|
| 72 |
+
|
| 73 |
+
print(f"Output shape: {seg_pred.shape}")
|
| 74 |
+
print(f"Classes present: {seg_pred.unique().tolist()}")
|
| 75 |
+
|
| 76 |
+
seg_np = seg_pred.cpu().numpy()
|
| 77 |
+
n_classes = pipe.num_classes if hasattr(pipe, "num_classes") else int(seg_np.max()) + 1
|
| 78 |
+
|
| 79 |
+
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
|
| 80 |
+
axes[0].imshow(image)
|
| 81 |
+
axes[0].set_title("Input")
|
| 82 |
+
axes[0].axis("off")
|
| 83 |
+
axes[1].imshow(seg_np, cmap="tab10", vmin=0, vmax=max(n_classes - 1, 9))
|
| 84 |
+
axes[1].set_title("Prediction")
|
| 85 |
+
axes[1].axis("off")
|
| 86 |
+
plt.tight_layout()
|
| 87 |
+
|
| 88 |
+
if args.save:
|
| 89 |
+
out_path = Path(args.save)
|
| 90 |
+
plt.savefig(out_path, bbox_inches="tight")
|
| 91 |
+
print(f"Saved to {out_path}")
|
| 92 |
+
plt.show()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
main()
|