MRiabov commited on
Commit
953508f
·
1 Parent(s): e8ba7db

(devops) automatic pull and preprocess of datasets

Browse files
.windsurf/rules/executing-python-files.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ---
2
+ trigger: always_on
3
+ ---
4
+
5
+ When executing Python files, use `python3` instead of `python`, because that adheres to the project's venv. Additionally, if you haven't activated the venv yet, you must activate it first, or the execution will fail with a module-not-found exception.
gdrive_pull.py DELETED
@@ -1,75 +0,0 @@
1
- import os
2
- import argparse
3
- from pydrive2.auth import GoogleAuth
4
- from pydrive2.drive import GoogleDrive
5
- from tqdm import tqdm
6
- from pathlib import Path
7
-
8
-
9
- def authenticate(service_account_json):
10
- """Authenticate PyDrive2 with a service account."""
11
- gauth = GoogleAuth()
12
- # Configure PyDrive2 to use service account credentials directly
13
- gauth.settings["client_config_backend"] = "service"
14
- gauth.settings["service_config"] = {
15
- "client_json_file_path": service_account_json,
16
- # Provide the key to satisfy PyDrive2 even if not impersonating
17
- "client_user_email": "drive-bot@web-design-396514.iam.gserviceaccount.com",
18
- }
19
- gauth.ServiceAuth()
20
- drive = GoogleDrive(gauth)
21
- return drive
22
-
23
-
24
- def list_files_with_paths(drive, folder_id, prefix=""):
25
- """Recursively collect all files with their relative paths from a folder."""
26
- items = []
27
- query = f"'{folder_id}' in parents and trashed=false"
28
- for file in drive.ListFile({"q": query, "maxResults": 1000}).GetList():
29
- if file["mimeType"] == "application/vnd.google-apps.folder":
30
- sub_prefix = (
31
- os.path.join(prefix, file["title"]) if prefix else file["title"]
32
- )
33
- items += list_files_with_paths(drive, file["id"], sub_prefix)
34
- else:
35
- rel_path = os.path.join(prefix, file["title"]) if prefix else file["title"]
36
- items.append((file, rel_path))
37
- return items
38
-
39
-
40
- def download_folder(folder_id, dest, service_account_json):
41
- drive = authenticate(service_account_json)
42
- os.makedirs(dest, exist_ok=True)
43
-
44
- print(f"Listing files in folder {folder_id}...")
45
- files_with_paths = list_files_with_paths(drive, folder_id)
46
- print(f"Found {len(files_with_paths)} files. Downloading...")
47
-
48
- for file, rel_path in tqdm(files_with_paths, desc="Downloading", unit="file"):
49
- out_path = os.path.join(dest, rel_path)
50
- os.makedirs(os.path.dirname(out_path), exist_ok=True)
51
- file.GetContentFile(out_path)
52
-
53
-
54
- def main():
55
- parser = argparse.ArgumentParser(
56
- description="Download a full Google Drive folder using a service account"
57
- )
58
- parser.add_argument("folder_id", help="Google Drive folder ID")
59
- parser.add_argument("output_dir", help="Directory to save files")
60
- parser.add_argument(
61
- "--service-account",
62
- default="service_account.json",
63
- help="Path to your Google service account JSON key file",
64
- )
65
- args = parser.parse_args()
66
-
67
- download_folder(args.folder_id, args.output_dir, args.service_account)
68
-
69
-
70
- if __name__ == "__main__":
71
- # also, mkdir -p dataset/
72
- path = Path("./dataset")
73
- path.mkdir(exists_ok=True)
74
-
75
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -8,3 +8,4 @@ Pillow>=9.5.0
8
  PyYAML>=6.0.1
9
  tqdm>=4.65.0
10
  gdown>=5.1.0
 
 
8
  PyYAML>=6.0.1
9
  tqdm>=4.65.0
10
  gdown>=5.1.0
11
+ pydrive2
scripts/pull_and_preprocess_wireseghr_dataset.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import threading
4
+ import random
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+ from pydrive2.auth import GoogleAuth
7
+ from pydrive2.drive import GoogleDrive
8
+ from tqdm import tqdm
9
+ from pathlib import Path
10
+
11
# One Drive client per worker thread; PyDrive2 clients are cached here because
# they are not shared safely across threads.
thread_local = threading.local()


def _get_thread_drive(service_account_json: str) -> GoogleDrive:
    """Return the calling thread's cached GoogleDrive client, creating it lazily."""
    drive = getattr(thread_local, "drive", None)
    if drive is None:
        drive = authenticate(service_account_json)
        thread_local.drive = drive
    return drive
20
+
21
+
22
def authenticate(service_account_json):
    """Build a GoogleDrive client authenticated with a service-account key.

    service_account_json: path to the Google service-account JSON key file.
    Returns a ready-to-use GoogleDrive instance.
    """
    gauth = GoogleAuth()
    # Switch PyDrive2 from the interactive OAuth flow to the service backend.
    gauth.settings["client_config_backend"] = "service"
    gauth.settings["service_config"] = {
        "client_json_file_path": service_account_json,
        # Provide the key to satisfy PyDrive2 even if not impersonating
        "client_user_email": "drive-bot@web-design-396514.iam.gserviceaccount.com",
    }
    gauth.ServiceAuth()
    return GoogleDrive(gauth)
35
+
36
+
37
def list_files_with_paths(drive, folder_id, prefix=""):
    """Recursively collect all files with their relative paths from a folder.

    drive: an authenticated GoogleDrive client.
    folder_id: Drive folder ID to walk.
    prefix: relative path accumulated from parent folders.
    Returns a list of dicts with keys id, rel_path, size, md5, mimeType.
    """
    collected = []
    params = {
        "q": f"'{folder_id}' in parents and trashed=false",
        "maxResults": 1000,
        # Request only needed fields (Drive API v2 uses 'items')
        "fields": "items(id,title,mimeType,fileSize,md5Checksum),nextPageToken",
    }
    for entry in drive.ListFile(params).GetList():
        rel = os.path.join(prefix, entry["title"]) if prefix else entry["title"]
        if entry["mimeType"] == "application/vnd.google-apps.folder":
            # Descend into the subfolder, carrying the accumulated path.
            collected.extend(list_files_with_paths(drive, entry["id"], rel))
        else:
            # Google-native docs have no fileSize; record 0 for them.
            size = int(entry.get("fileSize", 0)) if "fileSize" in entry else 0
            collected.append(
                {
                    "id": entry["id"],
                    "rel_path": rel,
                    "size": size,
                    "md5": entry.get("md5Checksum", ""),
                    "mimeType": entry["mimeType"],
                }
            )
    return collected
66
+
67
+
68
def download_folder(folder_id, dest, service_account_json, workers: int):
    """Mirror a Drive folder tree into `dest` with parallel downloads.

    Files whose on-disk size already matches the Drive-reported size are
    skipped, so re-running resumes an interrupted pull.

    folder_id: Drive folder ID to mirror.
    dest: local destination directory (created if missing).
    service_account_json: path to the service-account key file.
    workers: number of parallel download threads.
    Raises: any exception raised by a worker download (see BUGFIX below).
    """
    drive = authenticate(service_account_json)
    os.makedirs(dest, exist_ok=True)

    print(f"Listing files in folder {folder_id}...")
    files_with_paths = list_files_with_paths(drive, folder_id)
    total = len(files_with_paths)
    print(f"Found {total} files. Planning downloads...")

    # Prepare tasks and skip already downloaded files by size
    tasks = []
    skipped = 0
    for meta in files_with_paths:
        out_path = os.path.join(dest, meta["rel_path"])
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        if (
            meta["size"] > 0
            and os.path.exists(out_path)
            and os.path.getsize(out_path) == meta["size"]
        ):
            skipped += 1
            continue
        tasks.append((meta["id"], out_path))

    print(f"Skipping {skipped} existing files; {len(tasks)} to download.")

    def _download_one(file_id: str, out_path: str):
        # Each worker thread authenticates once and reuses its own client.
        d = _get_thread_drive(service_account_json)
        f = d.CreateFile({"id": file_id})
        f.GetContentFile(out_path)

    if len(tasks) == 0:
        print("All files are up to date.")
        return

    with ThreadPoolExecutor(max_workers=workers) as ex:
        futures = [ex.submit(_download_one, fid, path) for fid, path in tasks]
        for fut in tqdm(
            as_completed(futures), total=len(futures), desc="Downloading", unit="file"
        ):
            # BUGFIX: .result() re-raises exceptions from the worker; the
            # previous `pass` silently discarded failed downloads, leaving
            # missing/truncated files with a successful exit status.
            fut.result()
103
+
104
+
105
def pull(args=None):
    """CLI entry point: download the WireSegHR Drive folder.

    args: optional argv list; None means argparse reads sys.argv.
    """
    parser = argparse.ArgumentParser(
        description="Download a full Google Drive folder using a service account"
    )
    # (flags, keyword arguments) for each CLI option, added in one pass below.
    flag_specs = [
        (
            "--folder-id",
            dict(
                dest="folder_id",
                default="1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p",
                help="Google Drive folder ID",
            ),
        ),
        (
            "--output-dir",
            dict(dest="output_dir", default="dataset/", help="Directory to save files"),
        ),
        (
            "--service-account",
            dict(
                default="secrets/drive-json.json",
                help="Path to your Google service account JSON key file",
            ),
        ),
        (
            "--workers",
            dict(type=int, default=8, help="Number of parallel download workers"),
        ),
    ]
    for flag, kwargs in flag_specs:
        parser.add_argument(flag, **kwargs)

    opts = parser.parse_args(args=args)
    download_folder(opts.folder_id, opts.output_dir, opts.service_account, opts.workers)
137
+
138
+
139
+ def _index_numeric_pairs(images_dir: Path, masks_dir: Path):
140
+ assert images_dir.exists() and images_dir.is_dir(), f"Missing images_dir: {images_dir}"
141
+ assert masks_dir.exists() and masks_dir.is_dir(), f"Missing masks_dir: {masks_dir}"
142
+ img_files = sorted([p for p in images_dir.glob("*.jpg") if p.is_file()])
143
+ img_files += sorted([p for p in images_dir.glob("*.jpeg") if p.is_file()])
144
+ assert len(img_files) > 0, f"No .jpg/.jpeg images in {images_dir}"
145
+ ids = []
146
+ for p in img_files:
147
+ stem = p.stem
148
+ assert stem.isdigit(), f"Non-numeric filename encountered: {p.name}"
149
+ ids.append(int(stem))
150
+ ids = sorted(ids)
151
+ pairs = []
152
+ for i in ids:
153
+ ip_jpg = images_dir / f"{i}.jpg"
154
+ ip_jpeg = images_dir / f"{i}.jpeg"
155
+ ip = ip_jpg if ip_jpg.exists() else ip_jpeg
156
+ assert ip.exists(), f"Missing image for {i}: {ip_jpg} or {ip_jpeg}"
157
+ mp = masks_dir / f"{i}.png"
158
+ assert mp.exists(), f"Missing mask for {i}: {mp}"
159
+ pairs.append((ip, mp))
160
+ assert len(pairs) > 0, "No numeric pairs found"
161
+ return pairs
162
+
163
+
164
+ def split_test_train_val(args=None):
165
+ parser = argparse.ArgumentParser(
166
+ description="Split dataset into train/val/test = 85/5/10 with numeric pairs"
167
+ )
168
+ parser.add_argument("--images-dir", required=True, help="Path to images directory")
169
+ parser.add_argument("--masks-dir", required=True, help="Path to masks directory")
170
+ parser.add_argument(
171
+ "--out-dir",
172
+ required=True,
173
+ help="Output root dir where train/ val/ test/ will be created",
174
+ )
175
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
176
+ parser.add_argument(
177
+ "--link-method",
178
+ choices=["symlink", "copy"],
179
+ default="symlink",
180
+ help="How to place files into splits",
181
+ )
182
+ parsed = parser.parse_args(args=args)
183
+
184
+ images_dir = Path(parsed.images_dir)
185
+ masks_dir = Path(parsed.masks_dir)
186
+ out_root = Path(parsed.out_dir)
187
+ pairs = _index_numeric_pairs(images_dir, masks_dir)
188
+
189
+ n = len(pairs)
190
+ n_train = int(0.85 * n)
191
+ n_val = int(0.05 * n)
192
+ rng = random.Random(parsed.seed)
193
+ idxs = list(range(n))
194
+ rng.shuffle(idxs)
195
+ train_idx = idxs[:n_train]
196
+ val_idx = idxs[n_train : n_train + n_val]
197
+ test_idx = idxs[n_train + n_val :]
198
+
199
+ def _ensure_dirs(root: Path):
200
+ (root / "images").mkdir(parents=True, exist_ok=True)
201
+ (root / "gts").mkdir(parents=True, exist_ok=True)
202
+
203
+ def _place(src: Path, dst: Path):
204
+ if parsed.link_method == "symlink":
205
+ try:
206
+ if dst.exists() or dst.is_symlink():
207
+ dst.unlink()
208
+ os.symlink(src, dst)
209
+ except FileExistsError:
210
+ pass
211
+ else: # copy
212
+ if dst.exists():
213
+ dst.unlink()
214
+ # use hardlink if possible to be fast and space efficient
215
+ try:
216
+ os.link(src, dst)
217
+ except OSError:
218
+ import shutil
219
+
220
+ shutil.copy2(src, dst)
221
+
222
+ for split_name, split_ids in (
223
+ ("train", train_idx),
224
+ ("val", val_idx),
225
+ ("test", test_idx),
226
+ ):
227
+ root = out_root / split_name
228
+ _ensure_dirs(root)
229
+ for k in split_ids:
230
+ img_p, mask_p = pairs[k]
231
+ (root / "images" / img_p.name).parent.mkdir(parents=True, exist_ok=True)
232
+ (root / "gts" / mask_p.name).parent.mkdir(parents=True, exist_ok=True)
233
+ _place(img_p, root / "images" / img_p.name)
234
+ _place(mask_p, root / "gts" / mask_p.name)
235
+ print(
236
+ f"Split written to {out_root} | train={len(train_idx)} val={len(val_idx)} test={len(test_idx)}"
237
+ )
238
+
239
+
240
if __name__ == "__main__":
    # Ensure dataset/ exists before any subcommand runs (mkdir -p dataset/).
    Path("./dataset").mkdir(exist_ok=True)

    # Top-level CLI with one subparser per utility.
    top = argparse.ArgumentParser(description="WireSegHR data utilities")
    subs = top.add_subparsers(dest="cmd", required=True)

    sp_pull = subs.add_parser("pull", help="Download dataset from Google Drive")
    sp_pull.add_argument(
        "--folder-id", dest="folder_id", default="1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p"
    )
    sp_pull.add_argument("--output-dir", dest="output_dir", default="dataset/")
    sp_pull.add_argument("--service-account", default="secrets/drive-json.json")
    sp_pull.add_argument("--workers", type=int, default=8)

    sp_split = subs.add_parser(
        "split_test_train_val", help="Create 85/5/10 train/val/test split"
    )
    sp_split.add_argument("--images-dir", required=True)
    sp_split.add_argument("--masks-dir", required=True)
    sp_split.add_argument("--out-dir", required=True)
    sp_split.add_argument("--seed", type=int, default=42)
    sp_split.add_argument(
        "--link-method", choices=["symlink", "copy"], default="symlink"
    )

    ns = top.parse_args()
    # Re-serialize the parsed options and forward them to the subcommand's
    # own parser, so each entry point stays independently callable.
    if ns.cmd == "pull":
        forwarded = [
            "--folder-id", ns.folder_id,
            "--output-dir", ns.output_dir,
            "--service-account", ns.service_account,
            "--workers", str(ns.workers),
        ]
        pull(forwarded)
    elif ns.cmd == "split_test_train_val":
        forwarded = [
            "--images-dir", ns.images_dir,
            "--masks-dir", ns.masks_dir,
            "--out-dir", ns.out_dir,
            "--seed", str(ns.seed),
            "--link-method", ns.link_method,
        ]
        split_test_train_val(forwarded)
scripts/pull_ttpla.sh CHANGED
File without changes
scripts/setup_script.sh ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
set -euo pipefail
# This script downloads WireSegHR and TTPLA, converts TTPLA to masks, combines both,
# and creates an 85/5/10 train/val/test split under dataset/.

# 0) Setup env (includes gdown used by scripts/pull_ttpla.sh)
pip install uv
uv venv || true
source .venv/bin/activate
pip install uv
uv pip install -r requirements.txt
uv pip install gdown

# 1) Pull WireSegHR dataset from Google Drive (default folder-id provided in script)
#    This writes under dataset/wireseghr_raw/ (adjust if you want another dir)
python3 scripts/pull_and_preprocess_wireseghr_dataset.py pull \
  --output-dir dataset/wireseghr_raw

# 2) Pull TTPLA dataset zip and unzip under dataset/ttpla_dataset/
#    Pass OUT_DIR explicitly to avoid nested dataset/dataset/ttpla_dataset
bash scripts/pull_ttpla.sh "" "" ttpla_dataset

# 3) Convert TTPLA JSON annotations to binary masks with numeric-only filenames
#    Set these two to your actual TTPLA paths (after unzip).
TTPLA_JSON_ROOT="dataset/ttpla_dataset" # directory containing LabelMe-style JSONs (recursively)
mkdir -p dataset/ttpla_flat/gts
python3 scripts/ttpla_to_masks.py \
  --input "$TTPLA_JSON_ROOT" \
  --output dataset/ttpla_flat/gts \
  --label cable

# 4) Flatten TTPLA images to numeric-only stems to match the masks
#    Set TTPLA_IMG_ROOT to the folder under which all TTPLA images can be found (recursively).
# BUGFIX: export the variable — a plain shell assignment is NOT visible to the
# child python3 process, so os.environ.get() below always fell back to its default.
export TTPLA_IMG_ROOT="dataset/ttpla_dataset" # directory where the images referenced by JSONs reside (recursively)
mkdir -p dataset/ttpla_flat/images
python3 - <<'PY'
from pathlib import Path
import json, os, shutil

ttpla_json_root = Path("dataset/ttpla_dataset")
img_root = Path(os.environ.get("TTPLA_IMG_ROOT", "dataset/ttpla_dataset"))
out_img = Path("dataset/ttpla_flat/images")
out_img.mkdir(parents=True, exist_ok=True)

jsons = sorted(ttpla_json_root.rglob("*.json"))
assert len(jsons) > 0, f"No JSONs under {ttpla_json_root}"
for jp in jsons:
    data = json.loads(jp.read_text())
    image_path = Path(data["imagePath"])  # e.g. "1_00186.jpg"
    stem_raw = image_path.stem
    num = "".join([c for c in stem_raw if c.isdigit()])
    assert num.isdigit() and len(num) > 0, f"Non-numeric from {stem_raw}"
    # locate the actual image file somewhere under img_root by filename
    cands = list(img_root.rglob(image_path.name))
    assert len(cands) == 1, f"Ambiguous or missing image for {image_path.name}: {cands}"
    src = cands[0]
    ext = src.suffix.lower()  # keep original .jpg/.jpeg
    dst = out_img / f"{num}{ext}"
    if dst.exists() or dst.is_symlink():
        dst.unlink()
    # Prefer hardlink for speed and space efficiency; fallback to copy
    try:
        os.link(src, dst)
    except OSError:
        shutil.copy2(src, dst)
print(f"TTPLA flat images written to: {out_img}")
PY

# 5) Point to WireSegHR raw images/masks (adjust these to match what was downloaded in step 1)
#    After the Drive pull, inspect to find these two folders:
#    They must contain numeric-only image stems (.jpg/.jpeg) and PNG masks.
#    Example placeholders below — update them to your actual locations:
export WSHR_IMAGES="dataset/wireseghr_raw/images"
export WSHR_MASKS="dataset/wireseghr_raw/gts"

# 6) Build a combined pool (WireSegHR + TTPLA) and reindex to a single contiguous numeric ID space
mkdir -p dataset/combined_pool_fix/images dataset/combined_pool_fix/gts
python3 - <<'PY'
import os
from pathlib import Path

def index_pairs(images_dir: Path, masks_dir: Path):
    imgs = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.jpeg"))
    pairs = {}
    for ip in imgs:
        assert ip.stem.isdigit(), f"Non-numeric image name: {ip.name}"
        mp = masks_dir / f"{ip.stem}.png"
        assert mp.exists(), f"Missing mask for {ip.stem}: {mp}"
        pairs[int(ip.stem)] = (ip, mp)
    return [pairs[k] for k in sorted(pairs.keys())]

w_images = Path(os.environ["WSHR_IMAGES"])
w_masks = Path(os.environ["WSHR_MASKS"])
t_images = Path("dataset/ttpla_flat/images")
t_masks = Path("dataset/ttpla_flat/gts")

w_pairs = index_pairs(w_images, w_masks)
t_pairs = index_pairs(t_images, t_masks)
print("w_pairs:", len(w_pairs), "t_pairs:", len(t_pairs))

all_pairs = w_pairs + t_pairs  # deterministic order: WireSegHR first, then TTPLA
out_img = Path("dataset/combined_pool_fix/images")
out_msk = Path("dataset/combined_pool_fix/gts")
out_img.mkdir(parents=True, exist_ok=True)
out_msk.mkdir(parents=True, exist_ok=True)

# Reindex to 1..N, preserving each image's original extension
i = 1
for ip, mp in all_pairs:
    ext = ip.suffix.lower()  # .jpg or .jpeg
    dst_i = out_img / f"{i}{ext}"
    dst_m = out_msk / f"{i}.png"
    if dst_i.exists() or dst_i.is_symlink(): dst_i.unlink()
    if dst_m.exists() or dst_m.is_symlink(): dst_m.unlink()
    # Prefer hardlinks; fallback to copy if cross-device or unsupported
    try:
        os.link(ip, dst_i)
    except OSError:
        import shutil; shutil.copy2(ip, dst_i)
    try:
        os.link(mp, dst_m)
    except OSError:
        import shutil; shutil.copy2(mp, dst_m)
    i += 1

print(f"Combined pool: {i-1} pairs -> {out_img} and {out_msk}")
PY

# 7) Split the combined pool into train/val/test = 85/5/10
python3 scripts/pull_and_preprocess_wireseghr_dataset.py split_test_train_val \
  --images-dir dataset/combined_pool_fix/images \
  --masks-dir dataset/combined_pool_fix/gts \
  --out-dir dataset \
  --seed 42 \
  --link-method copy

# Done. Your config at configs/default.yaml already points to dataset/train|val|test.
scripts/ttpla_to_masks.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Runner shim for the TTPLA-to-mask converter.

Puts the repo's src/ directory on sys.path so the `wireseghr` package is
importable when this script is executed directly, then hands off to the
package's CLI entry point.
"""
import sys
from pathlib import Path

# Ensure local package under src/ is importable when running this script directly
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SRC_PATH = PROJECT_ROOT / "src"
_src_str = str(SRC_PATH)
if _src_str not in sys.path:
    sys.path.insert(0, _src_str)

from wireseghr.data.ttpla_to_masks import main

if __name__ == "__main__":
    main()
src/wireseghr/data/dataset.py CHANGED
@@ -46,8 +46,8 @@ class WireSegDataset:
46
 
47
  def _index_pairs(self) -> List[Tuple[Path, Path]]:
48
  # Convention: numeric filenames; images are .jpg/.jpeg; masks (gts) are .png
49
- img_files = sorted([p for p in self.images_dir.glob("*.jpg") if p.is_file()])
50
- img_files += sorted([p for p in self.images_dir.glob("*.jpeg") if p.is_file()])
51
  assert len(img_files) > 0, f"No .jpg/.jpeg images in {self.images_dir}"
52
  pairs: List[Tuple[Path, Path]] = []
53
  ids: List[int] = []
 
46
 
47
  def _index_pairs(self) -> List[Tuple[Path, Path]]:
48
  # Convention: numeric filenames; images are .jpg/.jpeg; masks (gts) are .png
49
+ img_files = sorted([p for p in self.images_dir.glob("*.jpg") if p.exists()])
50
+ img_files += sorted([p for p in self.images_dir.glob("*.jpeg") if p.exists()])
51
  assert len(img_files) > 0, f"No .jpg/.jpeg images in {self.images_dir}"
52
  pairs: List[Tuple[Path, Path]] = []
53
  ids: List[int] = []
src/wireseghr/data/ttpla_to_masks.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Iterable, List
7
+
8
+ from PIL import Image, ImageDraw
9
+ import numpy as np
10
+
11
+
12
def _rasterize_cable_mask(shapes: List[dict], height: int, width: int, label: str) -> np.ndarray:
    """Rasterize polygons carrying `label` into a (H, W) uint8 mask of {0, 255}.

    Expects LabelMe-style shape dicts with keys:
    - label: str
    - shape_type: "polygon"
    - points: [[x, y], ...]
    """
    assert height > 0 and width > 0
    # PIL image size is (W, H), unlike the numpy (H, W) convention.
    canvas = Image.new("L", (width, height), 0)
    painter = ImageDraw.Draw(canvas)

    for shape in shapes:
        if shape.get("label") != label:
            continue
        assert shape.get("shape_type") == "polygon", "Only polygon shapes are supported"
        coords = np.asarray(shape.get("points"), dtype=np.float32)
        assert coords.ndim == 2 and coords.shape[1] == 2, "Invalid points array"
        # Round to nearest pixel and clip into the image bounds.
        coords = np.rint(coords)
        coords[:, 0] = np.clip(coords[:, 0], 0, width - 1)
        coords[:, 1] = np.clip(coords[:, 1], 0, height - 1)
        # ImageDraw.polygon takes a list of (x, y) integer tuples.
        painter.polygon([(int(x), int(y)) for x, y in coords], outline=255, fill=255)

    return np.asarray(canvas, dtype=np.uint8)
41
+
42
+
43
def _convert_one(json_path: Path, out_dir: Path, label: str) -> Path | None:
    """Convert one LabelMe JSON into a binary PNG mask named by its numeric stem.

    Returns the written mask path.
    """
    with open(json_path, "r") as f:
        data = json.load(f)

    shapes = data["shapes"]
    height = int(data["imageHeight"])  # required by given JSON
    width = int(data["imageWidth"])  # required by given JSON
    image_path = Path(data["imagePath"])  # e.g. "1_00186.jpg"
    stem_raw = image_path.stem
    # WireSegDataset expects numeric filename stems: keep only the digits.
    digits = "".join(ch for ch in stem_raw if ch.isdigit())
    assert digits.isdigit() and len(digits) > 0, f"Non-numeric stem derived from {stem_raw}"

    mask = _rasterize_cable_mask(shapes, height, width, label)

    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{digits}.png"
    # Write with Pillow
    Image.fromarray(mask, mode="L").save(str(out_path))
    return out_path
63
+
64
+
65
def convert_ttpla_jsons_to_masks(input_path: str | Path, output_dir: str | Path, label: str = "cable", recursive: bool = True) -> List[Path]:
    """Convert TTPLA LabelMe JSON annotations into binary masks matching WireSegHR conventions.

    input_path: a directory containing JSONs, or a single .json file.
    output_dir: directory where .png masks will be written.
    label: which label to rasterize (default "cable").
    recursive: when input_path is a directory, whether to search subdirectories.

    Returns the list of written mask paths.
    """
    src = Path(input_path)
    dst = Path(output_dir)

    # Single-file mode: convert exactly one annotation.
    if src.is_file():
        assert src.suffix.lower() == ".json", f"Expected a .json file, got: {src}"
        result = _convert_one(src, dst, label)
        return [result] if result else []

    assert src.is_dir(), f"Input path must be a directory or a .json file: {src}"

    json_iter: Iterable[Path] = src.rglob("*.json") if recursive else src.glob("*.json")

    written: List[Path] = []
    # Sort for a deterministic conversion order.
    for json_file in sorted(json_iter):
        mask_path = _convert_one(json_file, dst, label)
        if mask_path is not None:
            written.append(mask_path)
    return written
97
+
98
+
99
def main(argv: List[str] | None = None) -> None:
    """CLI wrapper around convert_ttpla_jsons_to_masks."""
    parser = argparse.ArgumentParser(description="Convert TTPLA LabelMe JSONs to WireSegHR-style binary masks")
    parser.add_argument("--input", required=True, help="Path to a directory of JSONs or a single JSON file")
    parser.add_argument("--output", required=True, help="Output directory for PNG masks")
    parser.add_argument("--label", default="cable", help="Label to rasterize (default: cable)")
    parser.add_argument("--no-recursive", action="store_true", help="Do not search subdirectories")
    ns = parser.parse_args(argv)

    search_recursively = not ns.no_recursive
    convert_ttpla_jsons_to_masks(
        ns.input,
        ns.output,
        label=ns.label,
        recursive=search_recursively,
    )


if __name__ == "__main__":
    main()
tests/test_ttpla_to_masks.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import json
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ from wireseghr.data.ttpla_to_masks import convert_ttpla_jsons_to_masks
7
+
8
+
9
def _read_dims(json_path: Path):
    """Return (imageHeight, imageWidth, original image filename stem) from a LabelMe JSON."""
    with open(json_path, "r") as f:
        data = json.load(f)
    return int(data["imageHeight"]), int(data["imageWidth"]), Path(data["imagePath"]).stem


# BUGFIX: locate the example JSON relative to this test file (tests/ is one
# level below the repo root) instead of the hard-coded absolute path
# /workspace/wire-seg-hr-impl/1_00186.json, so the suite runs from any checkout.
_EXAMPLE_JSON = Path(__file__).resolve().parents[1] / "1_00186.json"


def test_convert_single_json_cable_only(tmp_path: Path):
    # Use the provided example JSON at repo root
    src_json = _EXAMPLE_JSON
    assert src_json.exists()

    H, W, stem = _read_dims(src_json)

    out_dir = tmp_path / "masks"
    written = convert_ttpla_jsons_to_masks(src_json, out_dir, label="cable")

    assert len(written) == 1
    out_path = written[0]
    # Converter writes numeric-only stems
    expected_stem = "".join([c for c in stem if c.isdigit()])
    assert out_path.name == f"{expected_stem}.png"
    assert out_path.exists()

    mask = np.array(Image.open(out_path).convert("L"))
    assert mask is not None
    assert mask.shape == (H, W)
    assert mask.dtype == np.uint8

    # Binary with values in {0,255}
    uniq = np.unique(mask)
    assert all(int(v) in (0, 255) for v in uniq)
    assert (mask > 0).any(), "Expected some positive pixels for cable"


def test_convert_different_labels(tmp_path: Path):
    src_json = _EXAMPLE_JSON
    assert src_json.exists()

    out_dir_cable = tmp_path / "masks_cable"
    out_dir_tower = tmp_path / "masks_tower"

    written_cable = convert_ttpla_jsons_to_masks(src_json, out_dir_cable, label="cable")
    written_tower = convert_ttpla_jsons_to_masks(src_json, out_dir_tower, label="tower_wooden")

    mc = np.array(Image.open(written_cable[0]).convert("L"))
    mt = np.array(Image.open(written_tower[0]).convert("L"))

    # Both masks should have some positives and should not be identical
    assert (mc > 0).any()
    assert (mt > 0).any()
    assert not np.array_equal(mc, mt)
train.py CHANGED
@@ -15,12 +15,12 @@ import random
15
  import torch.backends.cudnn as cudnn
16
  import cv2
17
 
18
- from wireseghr.model import WireSegHR
19
- from wireseghr.model.minmax import MinMaxLuminance
20
- from wireseghr.data.dataset import WireSegDataset
21
- from wireseghr.model.label_downsample import downsample_label_maxpool
22
- from wireseghr.data.sampler import BalancedPatchSampler
23
- from wireseghr.metrics import compute_metrics
24
 
25
 
26
  def main():
 
15
  import torch.backends.cudnn as cudnn
16
  import cv2
17
 
18
+ from src.wireseghr.model import WireSegHR
19
+ from src.wireseghr.model.minmax import MinMaxLuminance
20
+ from src.wireseghr.data.dataset import WireSegDataset
21
+ from src.wireseghr.model.label_downsample import downsample_label_maxpool
22
+ from src.wireseghr.data.sampler import BalancedPatchSampler
23
+ from src.wireseghr.metrics import compute_metrics
24
 
25
 
26
  def main():