(devops) automatic pull and preprocess of datasets
Browse files- .windsurf/rules/executing-python-files.md +5 -0
- gdrive_pull.py +0 -75
- requirements.txt +1 -0
- scripts/pull_and_preprocess_wireseghr_dataset.py +290 -0
- scripts/pull_ttpla.sh +0 -0
- scripts/setup_script.sh +136 -0
- scripts/ttpla_to_masks.py +14 -0
- src/wireseghr/data/dataset.py +2 -2
- src/wireseghr/data/ttpla_to_masks.py +116 -0
- tests/test_ttpla_to_masks.py +59 -0
- train.py +6 -6
.windsurf/rules/executing-python-files.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
trigger: always_on
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
When executing Python files, use `python3` instead of `python`, because `python3` adheres to the project's venv. Additionally, if you haven't activated the venv yet, activate it first — otherwise execution will fail with a module-not-found exception.
|
gdrive_pull.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import argparse
|
| 3 |
-
from pydrive2.auth import GoogleAuth
|
| 4 |
-
from pydrive2.drive import GoogleDrive
|
| 5 |
-
from tqdm import tqdm
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def authenticate(service_account_json):
|
| 10 |
-
"""Authenticate PyDrive2 with a service account."""
|
| 11 |
-
gauth = GoogleAuth()
|
| 12 |
-
# Configure PyDrive2 to use service account credentials directly
|
| 13 |
-
gauth.settings["client_config_backend"] = "service"
|
| 14 |
-
gauth.settings["service_config"] = {
|
| 15 |
-
"client_json_file_path": service_account_json,
|
| 16 |
-
# Provide the key to satisfy PyDrive2 even if not impersonating
|
| 17 |
-
"client_user_email": "drive-bot@web-design-396514.iam.gserviceaccount.com",
|
| 18 |
-
}
|
| 19 |
-
gauth.ServiceAuth()
|
| 20 |
-
drive = GoogleDrive(gauth)
|
| 21 |
-
return drive
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def list_files_with_paths(drive, folder_id, prefix=""):
|
| 25 |
-
"""Recursively collect all files with their relative paths from a folder."""
|
| 26 |
-
items = []
|
| 27 |
-
query = f"'{folder_id}' in parents and trashed=false"
|
| 28 |
-
for file in drive.ListFile({"q": query, "maxResults": 1000}).GetList():
|
| 29 |
-
if file["mimeType"] == "application/vnd.google-apps.folder":
|
| 30 |
-
sub_prefix = (
|
| 31 |
-
os.path.join(prefix, file["title"]) if prefix else file["title"]
|
| 32 |
-
)
|
| 33 |
-
items += list_files_with_paths(drive, file["id"], sub_prefix)
|
| 34 |
-
else:
|
| 35 |
-
rel_path = os.path.join(prefix, file["title"]) if prefix else file["title"]
|
| 36 |
-
items.append((file, rel_path))
|
| 37 |
-
return items
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def download_folder(folder_id, dest, service_account_json):
|
| 41 |
-
drive = authenticate(service_account_json)
|
| 42 |
-
os.makedirs(dest, exist_ok=True)
|
| 43 |
-
|
| 44 |
-
print(f"Listing files in folder {folder_id}...")
|
| 45 |
-
files_with_paths = list_files_with_paths(drive, folder_id)
|
| 46 |
-
print(f"Found {len(files_with_paths)} files. Downloading...")
|
| 47 |
-
|
| 48 |
-
for file, rel_path in tqdm(files_with_paths, desc="Downloading", unit="file"):
|
| 49 |
-
out_path = os.path.join(dest, rel_path)
|
| 50 |
-
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
| 51 |
-
file.GetContentFile(out_path)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def main():
|
| 55 |
-
parser = argparse.ArgumentParser(
|
| 56 |
-
description="Download a full Google Drive folder using a service account"
|
| 57 |
-
)
|
| 58 |
-
parser.add_argument("folder_id", help="Google Drive folder ID")
|
| 59 |
-
parser.add_argument("output_dir", help="Directory to save files")
|
| 60 |
-
parser.add_argument(
|
| 61 |
-
"--service-account",
|
| 62 |
-
default="service_account.json",
|
| 63 |
-
help="Path to your Google service account JSON key file",
|
| 64 |
-
)
|
| 65 |
-
args = parser.parse_args()
|
| 66 |
-
|
| 67 |
-
download_folder(args.folder_id, args.output_dir, args.service_account)
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
if __name__ == "__main__":
|
| 71 |
-
# also, mkdir -p dataset/
|
| 72 |
-
path = Path("./dataset")
|
| 73 |
-
path.mkdir(exists_ok=True)
|
| 74 |
-
|
| 75 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -8,3 +8,4 @@ Pillow>=9.5.0
|
|
| 8 |
PyYAML>=6.0.1
|
| 9 |
tqdm>=4.65.0
|
| 10 |
gdown>=5.1.0
|
|
|
|
|
|
| 8 |
PyYAML>=6.0.1
|
| 9 |
tqdm>=4.65.0
|
| 10 |
gdown>=5.1.0
|
| 11 |
+
pydrive2
|
scripts/pull_and_preprocess_wireseghr_dataset.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import threading
|
| 4 |
+
import random
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 6 |
+
from pydrive2.auth import GoogleAuth
|
| 7 |
+
from pydrive2.drive import GoogleDrive
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
thread_local = threading.local()


def _get_thread_drive(service_account_json: str) -> GoogleDrive:
    """Return this thread's GoogleDrive client, authenticating lazily on first use."""
    drive = getattr(thread_local, "drive", None)
    if drive is None:
        # Each worker thread owns one authenticated client; pydrive2 clients
        # are not guaranteed to be safe to share across threads.
        drive = authenticate(service_account_json)
        thread_local.drive = drive
    return drive
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def authenticate(service_account_json):
    """Build a GoogleDrive client authenticated via a service-account key file."""
    gauth = GoogleAuth()
    # Point PyDrive2 at the service-account backend instead of the OAuth flow.
    gauth.settings["client_config_backend"] = "service"
    gauth.settings["service_config"] = {
        "client_json_file_path": service_account_json,
        # Provide the key to satisfy PyDrive2 even if not impersonating
        "client_user_email": "drive-bot@web-design-396514.iam.gserviceaccount.com",
    }
    gauth.ServiceAuth()
    return GoogleDrive(gauth)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def list_files_with_paths(drive, folder_id, prefix=""):
    """Recursively collect metadata for every file under a Drive folder.

    Returns a list of dicts with keys: id, rel_path (relative to the root
    folder), size (bytes, 0 when Drive reports none), md5, mimeType.
    Subfolders are descended into depth-first; the folder entries themselves
    are not returned.
    """
    collected = []
    query = f"'{folder_id}' in parents and trashed=false"
    request = {
        "q": query,
        "maxResults": 1000,
        # Request only needed fields (Drive API v2 uses 'items')
        "fields": "items(id,title,mimeType,fileSize,md5Checksum),nextPageToken",
    }
    for entry in drive.ListFile(request).GetList():
        child_path = os.path.join(prefix, entry["title"]) if prefix else entry["title"]
        if entry["mimeType"] == "application/vnd.google-apps.folder":
            collected.extend(list_files_with_paths(drive, entry["id"], child_path))
        else:
            collected.append(
                {
                    "id": entry["id"],
                    "rel_path": child_path,
                    # Google-native docs have no fileSize; report 0 for them.
                    "size": int(entry.get("fileSize", 0)) if "fileSize" in entry else 0,
                    "md5": entry.get("md5Checksum", ""),
                    "mimeType": entry["mimeType"],
                }
            )
    return collected
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def download_folder(folder_id, dest, service_account_json, workers: int):
    """Download every file under a Drive folder into dest, in parallel.

    Files already on disk whose byte size matches the Drive listing are
    skipped; remaining files are fetched by a thread pool of `workers`.
    """
    drive = authenticate(service_account_json)
    os.makedirs(dest, exist_ok=True)

    print(f"Listing files in folder {folder_id}...")
    files_with_paths = list_files_with_paths(drive, folder_id)
    total = len(files_with_paths)
    print(f"Found {total} files. Planning downloads...")

    # Prepare tasks and skip already downloaded files by size
    tasks = []
    skipped = 0
    for meta in files_with_paths:
        out_path = os.path.join(dest, meta["rel_path"])
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        already_present = (
            meta["size"] > 0
            and os.path.exists(out_path)
            and os.path.getsize(out_path) == meta["size"]
        )
        if already_present:
            skipped += 1
        else:
            tasks.append((meta["id"], out_path))

    print(f"Skipping {skipped} existing files; {len(tasks)} to download.")

    def _download_one(file_id: str, out_path: str):
        # One authenticated client per worker thread (see _get_thread_drive).
        client = _get_thread_drive(service_account_json)
        client.CreateFile({"id": file_id}).GetContentFile(out_path)

    if not tasks:
        print("All files are up to date.")
        return

    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = [pool.submit(_download_one, fid, path) for fid, path in tasks]
        # Drive the progress bar as futures complete; exceptions propagate here.
        for _ in tqdm(as_completed(pending), total=len(pending), desc="Downloading", unit="file"):
            pass
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def pull(args=None):
    """CLI entry point: download the dataset folder from Google Drive.

    `args` is an optional argv-style list; None means read sys.argv.
    """
    parser = argparse.ArgumentParser(
        description="Download a full Google Drive folder using a service account"
    )
    parser.add_argument(
        "--folder-id",
        dest="folder_id",
        default="1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p",
        help="Google Drive folder ID",
    )
    parser.add_argument(
        "--output-dir",
        dest="output_dir",
        default="dataset/",
        help="Directory to save files",
    )
    parser.add_argument(
        "--service-account",
        default="secrets/drive-json.json",
        help="Path to your Google service account JSON key file",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=8,
        help="Number of parallel download workers",
    )
    opts = parser.parse_args(args=args)
    download_folder(opts.folder_id, opts.output_dir, opts.service_account, opts.workers)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _index_numeric_pairs(images_dir: Path, masks_dir: Path):
|
| 140 |
+
assert images_dir.exists() and images_dir.is_dir(), f"Missing images_dir: {images_dir}"
|
| 141 |
+
assert masks_dir.exists() and masks_dir.is_dir(), f"Missing masks_dir: {masks_dir}"
|
| 142 |
+
img_files = sorted([p for p in images_dir.glob("*.jpg") if p.is_file()])
|
| 143 |
+
img_files += sorted([p for p in images_dir.glob("*.jpeg") if p.is_file()])
|
| 144 |
+
assert len(img_files) > 0, f"No .jpg/.jpeg images in {images_dir}"
|
| 145 |
+
ids = []
|
| 146 |
+
for p in img_files:
|
| 147 |
+
stem = p.stem
|
| 148 |
+
assert stem.isdigit(), f"Non-numeric filename encountered: {p.name}"
|
| 149 |
+
ids.append(int(stem))
|
| 150 |
+
ids = sorted(ids)
|
| 151 |
+
pairs = []
|
| 152 |
+
for i in ids:
|
| 153 |
+
ip_jpg = images_dir / f"{i}.jpg"
|
| 154 |
+
ip_jpeg = images_dir / f"{i}.jpeg"
|
| 155 |
+
ip = ip_jpg if ip_jpg.exists() else ip_jpeg
|
| 156 |
+
assert ip.exists(), f"Missing image for {i}: {ip_jpg} or {ip_jpeg}"
|
| 157 |
+
mp = masks_dir / f"{i}.png"
|
| 158 |
+
assert mp.exists(), f"Missing mask for {i}: {mp}"
|
| 159 |
+
pairs.append((ip, mp))
|
| 160 |
+
assert len(pairs) > 0, "No numeric pairs found"
|
| 161 |
+
return pairs
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def split_test_train_val(args=None):
    """CLI entry point: split numeric image/mask pairs into train/val/test = 85/5/10.

    Creates <out-dir>/{train,val,test}/{images,gts} and places each pair via
    symlink (default) or hardlink-with-copy-fallback ("copy"). The split is
    deterministic for a given --seed.
    """
    parser = argparse.ArgumentParser(
        description="Split dataset into train/val/test = 85/5/10 with numeric pairs"
    )
    parser.add_argument("--images-dir", required=True, help="Path to images directory")
    parser.add_argument("--masks-dir", required=True, help="Path to masks directory")
    parser.add_argument(
        "--out-dir",
        required=True,
        help="Output root dir where train/ val/ test/ will be created",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--link-method",
        choices=["symlink", "copy"],
        default="symlink",
        help="How to place files into splits",
    )
    parsed = parser.parse_args(args=args)

    images_dir = Path(parsed.images_dir)
    masks_dir = Path(parsed.masks_dir)
    out_root = Path(parsed.out_dir)
    pairs = _index_numeric_pairs(images_dir, masks_dir)

    # 85/5/10 split; test gets whatever remains after integer truncation.
    n = len(pairs)
    n_train = int(0.85 * n)
    n_val = int(0.05 * n)
    rng = random.Random(parsed.seed)
    idxs = list(range(n))
    rng.shuffle(idxs)
    train_idx = idxs[:n_train]
    val_idx = idxs[n_train : n_train + n_val]
    test_idx = idxs[n_train + n_val :]

    def _ensure_dirs(root: Path):
        # One images/ and one gts/ directory per split.
        (root / "images").mkdir(parents=True, exist_ok=True)
        (root / "gts").mkdir(parents=True, exist_ok=True)

    def _place(src: Path, dst: Path):
        if parsed.link_method == "symlink":
            try:
                if dst.exists() or dst.is_symlink():
                    dst.unlink()
                # BUGFIX: resolve src to an absolute path. A relative target is
                # interpreted relative to dst's directory, which produced
                # dangling symlinks whenever src was given as a relative path.
                os.symlink(src.resolve(), dst)
            except FileExistsError:
                pass
        else:  # copy
            if dst.exists():
                dst.unlink()
            # use hardlink if possible to be fast and space efficient
            try:
                os.link(src, dst)
            except OSError:
                import shutil

                shutil.copy2(src, dst)

    for split_name, split_ids in (
        ("train", train_idx),
        ("val", val_idx),
        ("test", test_idx),
    ):
        root = out_root / split_name
        _ensure_dirs(root)
        for k in split_ids:
            img_p, mask_p = pairs[k]
            _place(img_p, root / "images" / img_p.name)
            _place(mask_p, root / "gts" / mask_p.name)
    print(
        f"Split written to {out_root} | train={len(train_idx)} val={len(val_idx)} test={len(test_idx)}"
    )
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
if __name__ == "__main__":
    # Make sure dataset/ exists before any subcommand runs.
    # (equivalent of: mkdir -p dataset/)
    Path("./dataset").mkdir(exist_ok=True)

    # Top-level CLI with two subcommands; each forwards its parsed options as
    # an argv-style list to the matching entry-point function.
    top = argparse.ArgumentParser(description="WireSegHR data utilities")
    subs = top.add_subparsers(dest="cmd", required=True)

    sp_pull = subs.add_parser("pull", help="Download dataset from Google Drive")
    sp_pull.add_argument("--folder-id", dest="folder_id", default="1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p")
    sp_pull.add_argument("--output-dir", dest="output_dir", default="dataset/")
    sp_pull.add_argument("--service-account", default="secrets/drive-json.json")
    sp_pull.add_argument("--workers", type=int, default=8)

    sp_split = subs.add_parser(
        "split_test_train_val", help="Create 85/5/10 train/val/test split"
    )
    sp_split.add_argument("--images-dir", required=True)
    sp_split.add_argument("--masks-dir", required=True)
    sp_split.add_argument("--out-dir", required=True)
    sp_split.add_argument("--seed", type=int, default=42)
    sp_split.add_argument(
        "--link-method", choices=["symlink", "copy"], default="symlink"
    )

    ns = top.parse_args()
    if ns.cmd == "pull":
        forwarded = [
            "--folder-id", ns.folder_id,
            "--output-dir", ns.output_dir,
            "--service-account", ns.service_account,
            "--workers", str(ns.workers),
        ]
        pull(forwarded)
    elif ns.cmd == "split_test_train_val":
        forwarded = [
            "--images-dir", ns.images_dir,
            "--masks-dir", ns.masks_dir,
            "--out-dir", ns.out_dir,
            "--seed", str(ns.seed),
            "--link-method", ns.link_method,
        ]
        split_test_train_val(forwarded)
|
scripts/pull_ttpla.sh
CHANGED
|
File without changes
|
scripts/setup_script.sh
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
set -euo pipefail
# This script downloads WireSegHR and TTPLA, converts TTPLA to masks, combines both,
# and creates an 85/5/10 train/val/test split under dataset/.

# 0) Setup env (includes gdown used by scripts/pull_ttpla.sh)
pip install uv
uv venv || true
source .venv/bin/activate
# Re-install uv inside the venv so subsequent `uv pip` calls resolve from .venv/bin
pip install uv
uv pip install -r requirements.txt
uv pip install gdown

# 1) Pull WireSegHR dataset from Google Drive (default folder-id provided in script)
# This writes under dataset/wireseghr_raw/ (adjust if you want another dir)
python3 scripts/pull_and_preprocess_wireseghr_dataset.py pull \
    --output-dir dataset/wireseghr_raw

# 2) Pull TTPLA dataset zip and unzip under dataset/ttpla_dataset/
# Pass OUT_DIR explicitly to avoid nested dataset/dataset/ttpla_dataset
bash scripts/pull_ttpla.sh "" "" ttpla_dataset

# 3) Convert TTPLA JSON annotations to binary masks with numeric-only filenames
# Set these two to your actual TTPLA paths (after unzip).
TTPLA_JSON_ROOT="dataset/ttpla_dataset" # directory containing LabelMe-style JSONs (recursively)
mkdir -p dataset/ttpla_flat/gts
python3 scripts/ttpla_to_masks.py \
    --input "$TTPLA_JSON_ROOT" \
    --output dataset/ttpla_flat/gts \
    --label cable

# 4) Flatten TTPLA images to numeric-only stems to match the masks
# Set TTPLA_IMG_ROOT to the folder under which all TTPLA images can be found (recursively).
# BUGFIX: `export` the variable. A plain shell assignment is invisible to the
# child python3 process, so the heredoc's os.environ lookup silently fell back
# to its hard-coded default.
export TTPLA_IMG_ROOT="dataset/ttpla_dataset" # directory where the images referenced by JSONs reside (recursively)
mkdir -p dataset/ttpla_flat/images
python3 - <<'PY'
from pathlib import Path
import json, os, shutil
ttpla_json_root = Path("dataset/ttpla_dataset")
img_root = Path(os.environ.get("TTPLA_IMG_ROOT","dataset/ttpla_dataset"))
out_img = Path("dataset/ttpla_flat/images")
out_img.mkdir(parents=True, exist_ok=True)

jsons = sorted(ttpla_json_root.rglob("*.json"))
assert len(jsons) > 0, f"No JSONs under {ttpla_json_root}"
for jp in jsons:
    data = json.loads(jp.read_text())
    image_path = Path(data["imagePath"])  # e.g. "1_00186.jpg"
    stem_raw = image_path.stem
    num = "".join([c for c in stem_raw if c.isdigit()])
    assert num.isdigit() and len(num) > 0, f"Non-numeric from {stem_raw}"
    # locate the actual image file somewhere under img_root by filename
    cands = list(img_root.rglob(image_path.name))
    assert len(cands) == 1, f"Ambiguous or missing image for {image_path.name}: {cands}"
    src = cands[0]
    ext = src.suffix.lower()  # keep original .jpg/.jpeg
    dst = out_img / f"{num}{ext}"
    if dst.exists() or dst.is_symlink():
        dst.unlink()
    # Prefer hardlink for speed and space efficiency; fallback to copy
    try:
        os.link(src, dst)
    except OSError:
        shutil.copy2(src, dst)
print(f"TTPLA flat images written to: {out_img}")
PY

# 5) Point to WireSegHR raw images/masks (adjust these to match what was downloaded in step 1)
# After the Drive pull, inspect to find these two folders:
# They must contain numeric-only image stems (.jpg/.jpeg) and PNG masks.
# Example placeholders below — update them to your actual locations:
export WSHR_IMAGES="dataset/wireseghr_raw/images"
export WSHR_MASKS="dataset/wireseghr_raw/gts"

# 6) Build a combined pool (WireSegHR + TTPLA) and reindex to a single contiguous numeric ID space
mkdir -p dataset/combined_pool_fix/images dataset/combined_pool_fix/gts
python3 - <<'PY'
import os
import shutil
from pathlib import Path

def index_pairs(images_dir: Path, masks_dir: Path):
    # Pair numeric-stem images with same-stem PNG masks, sorted by id.
    imgs = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.jpeg"))
    pairs = {}
    for ip in imgs:
        assert ip.stem.isdigit(), f"Non-numeric image name: {ip.name}"
        mp = masks_dir / f"{ip.stem}.png"
        assert mp.exists(), f"Missing mask for {ip.stem}: {mp}"
        pairs[int(ip.stem)] = (ip, mp)
    return [pairs[k] for k in sorted(pairs.keys())]

def link_or_copy(src: Path, dst: Path):
    # Prefer hardlinks; fallback to copy if cross-device or unsupported
    if dst.exists() or dst.is_symlink():
        dst.unlink()
    try:
        os.link(src, dst)
    except OSError:
        shutil.copy2(src, dst)

w_images = Path(os.environ["WSHR_IMAGES"])
w_masks = Path(os.environ["WSHR_MASKS"])
t_images = Path("dataset/ttpla_flat/images")
t_masks = Path("dataset/ttpla_flat/gts")

w_pairs = index_pairs(w_images, w_masks)
t_pairs = index_pairs(t_images, t_masks)
print("w_pairs:", len(w_pairs), "t_pairs:", len(t_pairs))

all_pairs = w_pairs + t_pairs  # deterministic order: WireSegHR first, then TTPLA
out_img = Path("dataset/combined_pool_fix/images")
out_msk = Path("dataset/combined_pool_fix/gts")
out_img.mkdir(parents=True, exist_ok=True)
out_msk.mkdir(parents=True, exist_ok=True)

# Reindex to 1..N, preserving each image's original extension
i = 1
for ip, mp in all_pairs:
    ext = ip.suffix.lower()  # .jpg or .jpeg
    link_or_copy(ip, out_img / f"{i}{ext}")
    link_or_copy(mp, out_msk / f"{i}.png")
    i += 1

print(f"Combined pool: {i-1} pairs -> {out_img} and {out_msk}")
PY

# 7) Split the combined pool into train/val/test = 85/5/10
python3 scripts/pull_and_preprocess_wireseghr_dataset.py split_test_train_val \
    --images-dir dataset/combined_pool_fix/images \
    --masks-dir dataset/combined_pool_fix/gts \
    --out-dir dataset \
    --seed 42 \
    --link-method copy

# Done. Your config at configs/default.yaml already points to dataset/train|val|test.
|
scripts/ttpla_to_masks.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Thin CLI shim: run wireseghr.data.ttpla_to_masks directly from a checkout."""
import sys
from pathlib import Path

# Ensure local package under src/ is importable when running this script directly
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SRC_PATH = PROJECT_ROOT / "src"
if str(SRC_PATH) not in sys.path:
    # Prepend so the in-repo package wins over any installed copy.
    sys.path.insert(0, str(SRC_PATH))

from wireseghr.data.ttpla_to_masks import main

if __name__ == "__main__":
    main()
|
src/wireseghr/data/dataset.py
CHANGED
|
@@ -46,8 +46,8 @@ class WireSegDataset:
|
|
| 46 |
|
| 47 |
def _index_pairs(self) -> List[Tuple[Path, Path]]:
|
| 48 |
# Convention: numeric filenames; images are .jpg/.jpeg; masks (gts) are .png
|
| 49 |
-
img_files = sorted([p for p in self.images_dir.glob("*.jpg") if p.
|
| 50 |
-
img_files += sorted([p for p in self.images_dir.glob("*.jpeg") if p.
|
| 51 |
assert len(img_files) > 0, f"No .jpg/.jpeg images in {self.images_dir}"
|
| 52 |
pairs: List[Tuple[Path, Path]] = []
|
| 53 |
ids: List[int] = []
|
|
|
|
| 46 |
|
| 47 |
def _index_pairs(self) -> List[Tuple[Path, Path]]:
|
| 48 |
# Convention: numeric filenames; images are .jpg/.jpeg; masks (gts) are .png
|
| 49 |
+
img_files = sorted([p for p in self.images_dir.glob("*.jpg") if p.exists()])
|
| 50 |
+
img_files += sorted([p for p in self.images_dir.glob("*.jpeg") if p.exists()])
|
| 51 |
assert len(img_files) > 0, f"No .jpg/.jpeg images in {self.images_dir}"
|
| 52 |
pairs: List[Tuple[Path, Path]] = []
|
| 53 |
ids: List[int] = []
|
src/wireseghr/data/ttpla_to_masks.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Iterable, List
|
| 7 |
+
|
| 8 |
+
from PIL import Image, ImageDraw
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _rasterize_cable_mask(shapes: List[dict], height: int, width: int, label: str) -> np.ndarray:
    """Rasterize polygons with given label into a binary mask of shape (H, W), values {0,255}.

    Expects LabelMe-style annotations with shape entries containing keys:
    - label: str
    - shape_type: "polygon"
    - points: [[x,y], ...]
    """
    assert height > 0 and width > 0
    # PIL image size is (W, H); the returned numpy array is (H, W).
    canvas = Image.new("L", (width, height), 0)
    painter = ImageDraw.Draw(canvas)

    for shape in shapes:
        if shape.get("label") != label:
            continue
        assert shape.get("shape_type") == "polygon", "Only polygon shapes are supported"
        points = np.asarray(shape.get("points"), dtype=np.float32)
        assert points.ndim == 2 and points.shape[1] == 2, "Invalid points array"
        # Snap vertices to the pixel grid and keep them inside the canvas.
        points = np.rint(points)
        points[:, 0] = np.clip(points[:, 0], 0, width - 1)
        points[:, 1] = np.clip(points[:, 1], 0, height - 1)
        # PIL wants a flat list of (x, y) tuples.
        painter.polygon([(int(x), int(y)) for x, y in points], outline=255, fill=255)

    return np.asarray(canvas, dtype=np.uint8)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _convert_one(json_path: Path, out_dir: Path, label: str) -> Path | None:
    """Convert one LabelMe JSON annotation into a binary PNG mask.

    The output filename keeps only the digits of the source image's stem
    (WireSegDataset requires numeric stems). Returns the written mask path.
    """
    with open(json_path, "r") as f:
        data = json.load(f)

    shapes = data["shapes"]
    height = int(data["imageHeight"])  # required by given JSON
    width = int(data["imageWidth"])  # required by given JSON
    image_path = Path(data["imagePath"])  # e.g. "1_00186.jpg"
    # WireSegDataset expects numeric filename stems. Derive a numeric-only stem.
    stem_raw = image_path.stem
    digits_only = "".join(c for c in stem_raw if c.isdigit())
    assert digits_only.isdigit() and len(digits_only) > 0, f"Non-numeric stem derived from {stem_raw}"

    mask = _rasterize_cable_mask(shapes, height, width, label)

    out_dir.mkdir(parents=True, exist_ok=True)
    destination = out_dir / f"{digits_only}.png"
    # Write with Pillow
    Image.fromarray(mask, mode="L").save(str(destination))
    return destination
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def convert_ttpla_jsons_to_masks(input_path: str | Path, output_dir: str | Path, label: str = "cable", recursive: bool = True) -> List[Path]:
    """Convert TTPLA LabelMe JSON annotations into binary masks matching WireSegHR conventions.

    - input_path: directory containing JSONs (or a single .json file)
    - output_dir: directory where .png masks will be written
    - label: which label to rasterize (default: "cable")
    - recursive: when input_path is a directory, whether to search recursively

    Returns a list of written mask paths.
    """
    src = Path(input_path)
    dst = Path(output_dir)

    # Single-file mode: convert exactly one annotation and return its path.
    if src.is_file():
        assert src.suffix.lower() == ".json", f"Expected a .json file, got: {src}"
        result = _convert_one(src, dst, label)
        return [] if result is None else [result]

    assert src.is_dir(), f"Input path must be a directory or a .json file: {src}"

    # Directory mode: pick the glob flavor, then convert in sorted order
    # so output ordering is deterministic across filesystems.
    finder = src.rglob if recursive else src.glob
    outputs: List[Path] = []
    for json_file in sorted(finder("*.json")):
        produced = _convert_one(json_file, dst, label)
        if produced is not None:
            outputs.append(produced)
    return outputs
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def main(argv: List[str] | None = None) -> None:
    """CLI entry point: parse arguments and run the TTPLA -> mask conversion."""
    cli = argparse.ArgumentParser(description="Convert TTPLA LabelMe JSONs to WireSegHR-style binary masks")
    cli.add_argument("--input", required=True, help="Path to a directory of JSONs or a single JSON file")
    cli.add_argument("--output", required=True, help="Output directory for PNG masks")
    cli.add_argument("--label", default="cable", help="Label to rasterize (default: cable)")
    cli.add_argument("--no-recursive", action="store_true", help="Do not search subdirectories")
    ns = cli.parse_args(argv)

    # --no-recursive flips the default recursive directory walk off.
    convert_ttpla_jsons_to_masks(
        ns.input,
        ns.output,
        label=ns.label,
        recursive=not ns.no_recursive,
    )


if __name__ == "__main__":
    main()
|
tests/test_ttpla_to_masks.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
from wireseghr.data.ttpla_to_masks import convert_ttpla_jsons_to_masks
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _read_dims(json_path: Path):
|
| 10 |
+
with open(json_path, "r") as f:
|
| 11 |
+
data = json.load(f)
|
| 12 |
+
return int(data["imageHeight"]), int(data["imageWidth"]), Path(data["imagePath"]).stem
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_convert_single_json_cable_only(tmp_path: Path):
    """End-to-end: a single TTPLA JSON yields a binary {0, 255} cable mask.

    Fix: the fixture JSON is located relative to the repository root
    (tests/ is one level below it) instead of the previous hard-coded
    absolute path /workspace/wire-seg-hr-impl/1_00186.json, which only
    existed on one machine.
    """
    src_json = Path(__file__).resolve().parents[1] / "1_00186.json"
    assert src_json.exists()

    H, W, stem = _read_dims(src_json)

    out_dir = tmp_path / "masks"
    written = convert_ttpla_jsons_to_masks(src_json, out_dir, label="cable")

    assert len(written) == 1
    out_path = written[0]
    # Converter writes numeric-only stems (e.g. "1_00186" -> "100186")
    expected_stem = "".join([c for c in stem if c.isdigit()])
    assert out_path.name == f"{expected_stem}.png"
    assert out_path.exists()

    mask = np.array(Image.open(out_path).convert("L"))
    assert mask is not None
    assert mask.shape == (H, W)
    assert mask.dtype == np.uint8

    # Binary with values in {0,255}
    uniq = np.unique(mask)
    assert all(int(v) in (0, 255) for v in uniq)
    assert (mask > 0).any(), "Expected some positive pixels for cable"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_convert_different_labels(tmp_path: Path):
    """Masks rasterized for different labels are both non-empty and differ.

    Fix: use a repo-relative fixture path (was a hard-coded absolute
    /workspace/... path that breaks on any other checkout).
    """
    src_json = Path(__file__).resolve().parents[1] / "1_00186.json"
    assert src_json.exists()

    out_dir_cable = tmp_path / "masks_cable"
    out_dir_tower = tmp_path / "masks_tower"

    written_cable = convert_ttpla_jsons_to_masks(src_json, out_dir_cable, label="cable")
    written_tower = convert_ttpla_jsons_to_masks(src_json, out_dir_tower, label="tower_wooden")

    mc = np.array(Image.open(written_cable[0]).convert("L"))
    mt = np.array(Image.open(written_tower[0]).convert("L"))

    # Both masks should have some positives and should not be identical
    assert (mc > 0).any()
    assert (mt > 0).any()
    assert not np.array_equal(mc, mt)
|
train.py
CHANGED
|
@@ -15,12 +15,12 @@ import random
|
|
| 15 |
import torch.backends.cudnn as cudnn
|
| 16 |
import cv2
|
| 17 |
|
| 18 |
-
from wireseghr.model import WireSegHR
|
| 19 |
-
from wireseghr.model.minmax import MinMaxLuminance
|
| 20 |
-
from wireseghr.data.dataset import WireSegDataset
|
| 21 |
-
from wireseghr.model.label_downsample import downsample_label_maxpool
|
| 22 |
-
from wireseghr.data.sampler import BalancedPatchSampler
|
| 23 |
-
from wireseghr.metrics import compute_metrics
|
| 24 |
|
| 25 |
|
| 26 |
def main():
|
|
|
|
| 15 |
import torch.backends.cudnn as cudnn
|
| 16 |
import cv2
|
| 17 |
|
| 18 |
+
from src.wireseghr.model import WireSegHR
|
| 19 |
+
from src.wireseghr.model.minmax import MinMaxLuminance
|
| 20 |
+
from src.wireseghr.data.dataset import WireSegDataset
|
| 21 |
+
from src.wireseghr.model.label_downsample import downsample_label_maxpool
|
| 22 |
+
from src.wireseghr.data.sampler import BalancedPatchSampler
|
| 23 |
+
from src.wireseghr.metrics import compute_metrics
|
| 24 |
|
| 25 |
|
| 26 |
def main():
|