iljung1106 committed
Commit 4f81869 · 1 Parent(s): e6ecafe

Add some scripts.

scripts/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # This file makes `scripts/` a Python package so entrypoints can be imported reliably.
+
+
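With `scripts/` marked as a package, the wrappers below can be imported programmatically as well as run as files. A minimal sketch, assuming it runs from the repo root so the package resolves (module and function names are from this commit):

    import importlib

    # Resolves because scripts/__init__.py makes the directory a package.
    crawl = importlib.import_module("scripts.crawl_dataset")
    crawl.main()  # forwards to crawler_api.main_cli()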
scripts/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (156 Bytes).

scripts/__pycache__/crawl_dataset.cpython-310.pyc ADDED
Binary file (613 Bytes).

scripts/__pycache__/eval_prototypes_halfval.cpython-310.pyc ADDED
Binary file (10.6 kB).

scripts/__pycache__/eval_prototypes_strict_90_10.cpython-310.pyc ADDED
Binary file (11.1 kB).

scripts/__pycache__/extract_faces_eyes.cpython-310.pyc ADDED
Binary file (611 Bytes).

scripts/__pycache__/train_ddp.cpython-310.pyc ADDED
Binary file (552 Bytes).
scripts/crawl_dataset.py ADDED
@@ -0,0 +1,30 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Dataset crawling entrypoint (Danbooru artist list via Selenium + Gelbooru downloads).
+
+ This wraps `crawler_api.py` so you can run:
+     python scripts/crawl_dataset.py --help
+ """
+
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path
+
+
+ _ROOT = Path(__file__).resolve().parents[1]
+ if str(_ROOT) not in sys.path:
+     sys.path.insert(0, str(_ROOT))
+
+
+ def main() -> None:
+     import crawler_api
+
+     crawler_api.main_cli()
+
+
+ if __name__ == "__main__":
+     main()
+
+
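The same entrypoint can also be smoke-tested from Python via the stdlib. A sketch, assuming `crawler_api.main_cli()` is an argparse-style CLI (so `--help` prints usage and raises `SystemExit`):

    import runpy
    import sys

    sys.argv = ["crawl_dataset.py", "--help"]  # argv seen by crawler_api.main_cli()
    runpy.run_path("scripts/crawl_dataset.py", run_name="__main__")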
scripts/eval_prototypes_strict_90_10.py ADDED
@@ -0,0 +1,362 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Prototype evaluation (strict 90/10 split per view per artist, using merged train+val pools).
+
+ This script mirrors the "strict 90/10 full coverage" prototype-eval logic.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import random
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Sequence, Tuple
+
+ import torch
+ from PIL import Image, UnidentifiedImageError
+ from torch.utils.data import DataLoader, Dataset
+
+ import sys
+ from pathlib import Path
+
+ _ROOT = Path(__file__).resolve().parents[1]
+ if str(_ROOT) not in sys.path:
+     sys.path.insert(0, str(_ROOT))
+
+ import train_style_ddp as ts
+
+
+ TripletWithID = Tuple[str, str, str, int]
+
+
+ @dataclass
+ class Args:
+     ckpt: str
+     out: str
+     k_per_artist: int
+     build_ratio: float
+     batch_size: int
+     num_workers: int
+     seed: int
+     chunk_size: int
+
+
+ def parse_args() -> Args:
+     p = argparse.ArgumentParser(description="Eval prototypes (strict 90/10 split per view)")
+     p.add_argument("--ckpt", type=str, default="./checkpoints_style/stage3_epoch24.pt")
+     p.add_argument("--out", type=str, default="./checkpoints_style/per_artist_prototypes_90_10_full.pt")
+     p.add_argument("--k-per-artist", type=int, default=4)
+     p.add_argument("--build-ratio", type=float, default=0.9)
+     p.add_argument("--batch-size", type=int, default=64)
+     p.add_argument("--num-workers", type=int, default=0, help="0 is safest on Windows/spawn.")
+     p.add_argument("--seed", type=int, default=ts.cfg.seed)
+     p.add_argument("--chunk-size", type=int, default=5000)
+     a = p.parse_args()
+     return Args(
+         ckpt=a.ckpt,
+         out=a.out,
+         k_per_artist=a.k_per_artist,
+         build_ratio=a.build_ratio,
+         batch_size=a.batch_size,
+         num_workers=a.num_workers,
+         seed=a.seed,
+         chunk_size=a.chunk_size,
+     )
+
+
+ def kmeans_cosine(Z_cpu: torch.Tensor, K: int, *, iters: int = 20, seed: int = 1337, device: torch.device) -> torch.Tensor:
+     Z = torch.nn.functional.normalize(Z_cpu.to(device), dim=1)
+     N, D = Z.shape
+     if N <= K:
+         return Z.detach().cpu()
+     g = torch.Generator(device=device)
+     g.manual_seed(seed)
+     init_idx = torch.randperm(N, generator=g, device=device)[:K]
+     C = Z[init_idx].clone()
+     assign = torch.full((N,), -1, device=device, dtype=torch.long)
+     for _ in range(iters):
+         sim = Z @ C.t()
+         new_assign = sim.argmax(dim=1)
+         if (new_assign == assign).all():
+             assign = new_assign
+             break
+         assign = new_assign
+         C = torch.zeros(K, D, device=device, dtype=Z.dtype)
+         C.index_add_(0, assign, Z)
+         counts_raw = torch.bincount(assign, minlength=K)
+         empty = (counts_raw == 0)
+         counts = counts_raw.clamp_min(1).unsqueeze(1).to(Z.dtype)
+         C = C / counts
+         if empty.any():
+             ridx = torch.randperm(N, generator=g, device=device)[: int(empty.sum().item())]
+             C[empty] = Z[ridx]
+         C = torch.nn.functional.normalize(C, dim=1)
+     return C.detach().cpu()
+
+
+ class TripletDatasetWithID(Dataset):
+     def __init__(self, triplets: Sequence[TripletWithID], T_w, T_f, T_e):
+         self.triplets = list(triplets)
+         self.T_w = T_w
+         self.T_f = T_f
+         self.T_e = T_e
+
+     def __len__(self) -> int:
+         return len(self.triplets)
+
+     def __getitem__(self, idx: int):
+         pw, pf, pe, aid = self.triplets[idx]
+         try:
+             im_w = Image.open(pw).convert("RGB")
+             im_f = Image.open(pf).convert("RGB")
+             im_e = Image.open(pe).convert("RGB")
+         except (UnidentifiedImageError, OSError):
+             return None
+         return dict(whole=self.T_w(im_w), face=self.T_f(im_f), eyes=self.T_e(im_e), aid=int(aid))
+
+
+ def collate_triplets_with_id(batch):
+     batch = [b for b in batch if b is not None]
+     if not batch:
+         return None, None, None, None
+     Ws = torch.stack([b["whole"] for b in batch], dim=0)
+     Fs = torch.stack([b["face"] for b in batch], dim=0)
+     Es = torch.stack([b["eyes"] for b in batch], dim=0)
+     A = torch.tensor([b["aid"] for b in batch], dtype=torch.long)
+     return Ws, Fs, Es, A
+
+
+ def extract_embeddings_with_id(
+     *,
+     model: ts.TriViewStyleNet,
+     triplets: Sequence[TripletWithID],
+     T_w,
+     T_f,
+     T_e,
+     batch_size: int,
+     num_workers: int,
+     device: torch.device,
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+     if not triplets:
+         return None, None
+     ds = TripletDatasetWithID(triplets, T_w, T_f, T_e)
+
+     def _run_loader(nw: int, pin: bool):
+         dl = DataLoader(
+             ds,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=nw,
+             pin_memory=pin,
+             collate_fn=collate_triplets_with_id,
+         )
+         feats: List[torch.Tensor] = []
+         aids: List[torch.Tensor] = []
+         model.eval()
+         with torch.no_grad(), torch.amp.autocast("cuda", dtype=ts.amp_dtype, enabled=(device.type == "cuda")):
+             for Wb, Fb, Eb, Ab in dl:
+                 if Wb is None:
+                     continue
+                 Wb = Wb.to(device, non_blocking=True)
+                 Fb = Fb.to(device, non_blocking=True)
+                 Eb = Eb.to(device, non_blocking=True)
+                 views = {"whole": Wb, "face": Fb, "eyes": Eb}
+                 masks = {k: torch.ones(Wb.size(0), dtype=torch.bool, device=device) for k in views}
+                 z_fused, _, _ = model(views, masks)
+                 feats.append(z_fused.detach().cpu())
+                 aids.append(Ab.detach().cpu())
+         return feats, aids
+
+     try:
+         feats, aids = _run_loader(num_workers, pin=True)
+     except Exception:
+         feats, aids = _run_loader(0, pin=False)
+
+     if not feats:
+         return None, None
+     Z = torch.nn.functional.normalize(torch.cat(feats, dim=0), dim=1)
+     A = torch.cat(aids, dim=0).long()
+     return Z, A
+
+
+ def merge_dicts(d1: Dict[int, List], d2: Dict[int, List]) -> Dict[int, List]:
+     out = defaultdict(list)
+     for k, v in d1.items():
+         out[k].extend(list(v))
+     for k, v in d2.items():
+         out[k].extend(list(v))
+     return dict(out)
+
+
+ def main() -> None:
+     a = parse_args()
+     random.seed(a.seed)
+     torch.manual_seed(a.seed)
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print("device:", device)
+
+     ck = torch.load(a.ckpt, map_location="cpu")
+     meta = ck.get("meta", {})
+     stage_i = int(meta.get("stage", 1))
+     stage = ts.cfg.stages[stage_i - 1]
+     print(f"loaded ckpt={a.ckpt} (stage={stage_i})")
+
+     # use deterministic transforms for prototype building/eval
+     T_w_val, T_f_val, T_e_val = ts.make_val_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"])
+
+     train_ds = ts.TriViewDataset(ts.cfg.data_root, ts.cfg.folders, split="train", T_whole=T_w_val, T_face=T_f_val, T_eyes=T_e_val)
+     val_ds = ts.TriViewDataset(ts.cfg.data_root, ts.cfg.folders, split="val", T_whole=T_w_val, T_face=T_f_val, T_eyes=T_e_val)
+
+     # Persist artist names (label -> folder name) for downstream UX (e.g., Gradio UI).
+     # IDs are assigned by sorting artist directory names under dataset/.
+     label_names = [train_ds.id2artist[i] for i in range(train_ds.num_classes)]
+
+     # merge pools (train+val)
+     wholes_all = merge_dicts(train_ds.whole_paths_by_artist, val_ds.whole_paths_by_artist)
+     faces_all = merge_dicts(train_ds.face_paths_by_artist, val_ds.face_paths_by_artist)
+     eyes_all = merge_dicts(train_ds.eyes_paths_by_artist, val_ds.eyes_paths_by_artist)
+
+     build_data = {}
+     eval_data = {}
+     for aid in wholes_all.keys():
+         W_list = list({str(p) for p in wholes_all.get(aid, [])})
+         F_list = list({str(p) for p in faces_all.get(aid, [])})
+         E_list = list({str(p) for p in eyes_all.get(aid, [])})
+         random.shuffle(W_list)
+         random.shuffle(F_list)
+         random.shuffle(E_list)
+         if len(W_list) < 2 or len(F_list) < 2 or len(E_list) < 2:
+             continue
+         mw = max(1, int(len(W_list) * a.build_ratio))
+         mf = max(1, int(len(F_list) * a.build_ratio))
+         me = max(1, int(len(E_list) * a.build_ratio))
+         if mw == len(W_list):
+             mw -= 1
+         if mf == len(F_list):
+             mf -= 1
+         if me == len(E_list):
+             me -= 1
+         W_b, W_e = W_list[:mw], W_list[mw:]
+         F_b, F_e = F_list[:mf], F_list[mf:]
+         E_b, E_e = E_list[:me], E_list[me:]
+         if not (W_b and W_e and F_b and F_e and E_b and E_e):
+             continue
+         build_data[aid] = {"W": W_b, "F": F_b, "E": E_b}
+         eval_data[aid] = {"W": W_e, "F": F_e, "E": E_e}
+
+     print("valid artists:", len(build_data))
+
+     model = ts.TriViewStyleNet(out_dim=ts.cfg.embed_dim, mix_p=ts.cfg.mixstyle_p, share_backbone=True).to(device)
+     model = model.to(memory_format=torch.channels_last)
+     model.load_state_dict(ck["model"], strict=False)
+     model.eval()
+
+     # build triplets: use all build wholes once, random face/eyes from build pools
+     build_triplets: List[TripletWithID] = []
+     for aid, d in build_data.items():
+         for pw in d["W"]:
+             pf = random.choice(d["F"])
+             pe = random.choice(d["E"])
+             build_triplets.append((pw, pf, pe, int(aid)))
+     print("build triplets:", len(build_triplets))
+
+     Z_build, A_build = extract_embeddings_with_id(
+         model=model,
+         triplets=build_triplets,
+         T_w=T_w_val,
+         T_f=T_f_val,
+         T_e=T_e_val,
+         batch_size=a.batch_size,
+         num_workers=a.num_workers,
+         device=device,
+     )
+     if Z_build is None or A_build is None:
+         raise RuntimeError("No build embeddings extracted.")
+
+     # prototypes per artist
+     aid_to_idx = defaultdict(list)
+     for i, aid in enumerate(A_build.tolist()):
+         aid_to_idx[int(aid)].append(i)
+
+     proto_centers_list: List[torch.Tensor] = []
+     proto_labels_list: List[torch.Tensor] = []
+     for aid, idxs in aid_to_idx.items():
+         Zi = Z_build[torch.tensor(idxs, dtype=torch.long)]
+         if Zi.shape[0] <= a.k_per_artist:
+             proto_centers_list.append(Zi)
+             proto_labels_list.append(torch.full((Zi.shape[0],), aid, dtype=torch.long))
+         else:
+             centers = kmeans_cosine(Zi, K=a.k_per_artist, iters=20, seed=a.seed, device=device)
+             proto_centers_list.append(centers)
+             proto_labels_list.append(torch.full((a.k_per_artist,), aid, dtype=torch.long))
+
+     proto_centers = torch.cat(proto_centers_list, dim=0)
+     proto_labels = torch.cat(proto_labels_list, dim=0)
+     print("total prototypes:", proto_centers.shape[0])
+
+     # eval triplets: use all eval wholes once, random face/eyes from eval pools
+     eval_triplets: List[TripletWithID] = []
+     valid_proto_artists = set(proto_labels.unique().tolist())
+     for aid, d in eval_data.items():
+         if int(aid) not in valid_proto_artists:
+             continue
+         for pw in d["W"]:
+             pf = random.choice(d["F"])
+             pe = random.choice(d["E"])
+             eval_triplets.append((pw, pf, pe, int(aid)))
+     print("eval triplets:", len(eval_triplets))
+
+     Z_eval, Y_eval = extract_embeddings_with_id(
+         model=model,
+         triplets=eval_triplets,
+         T_w=T_w_val,
+         T_f=T_f_val,
+         T_e=T_e_val,
+         batch_size=a.batch_size,
+         num_workers=a.num_workers,
+         device=device,
+     )
+     if Z_eval is None or Y_eval is None:
+         raise RuntimeError("No eval embeddings extracted.")
+
+     # nearest-prototype classification (cosine)
+     with torch.no_grad():
+         C = torch.nn.functional.normalize(proto_centers.to(device), dim=1)
+         Z = torch.nn.functional.normalize(Z_eval.to(device), dim=1)
+         correct = 0
+         total = Z.shape[0]
+         for i in range(0, total, a.chunk_size):
+             zc = Z[i : i + a.chunk_size]
+             yc = Y_eval[i : i + a.chunk_size].to(device)
+             sim = zc @ C.t()
+             pred_idx = sim.argmax(dim=1)
+             pred_labels = proto_labels.to(device)[pred_idx]
+             correct += (pred_labels == yc).sum().item()
+     acc = correct / max(1, total)
+     print(f"prototype accuracy (strict 90/10): {acc:.4f}")
+
+     os.makedirs(os.path.dirname(a.out) or ".", exist_ok=True)
+     torch.save(
+         dict(
+             centers=proto_centers,
+             labels=proto_labels,
+             label_names=label_names,
+             k_per_artist=a.k_per_artist,
+             ckpt=a.ckpt,
+             split_method="90_10_strict_per_view_per_artist",
+             build_ratio=a.build_ratio,
+             acc=acc,
+         ),
+         a.out,
+     )
+     print("saved:", a.out)
+
+
+ if __name__ == "__main__":
+     main()
+
+
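The script's two core steps — per-artist spherical k-means and cosine nearest-prototype classification — can be exercised standalone. A toy sketch of the same technique on synthetic unit vectors (a simplified stand-in for `kmeans_cosine`, which additionally reseeds empty clusters and stops early on convergence):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    D, K = 128, 4

    def spherical_kmeans(Z: torch.Tensor, K: int, iters: int = 20) -> torch.Tensor:
        # Z is unit-normalized, so dot product == cosine similarity.
        C = Z[torch.randperm(Z.shape[0])[:K]].clone()
        for _ in range(iters):
            assign = (Z @ C.t()).argmax(dim=1)
            for k in range(K):
                m = assign == k
                if m.any():
                    C[k] = F.normalize(Z[m].mean(dim=0), dim=0)
        return C

    # Two synthetic "artists": embeddings clustered around opposite directions.
    Z_a = F.normalize(torch.randn(100, D) + 2.0, dim=1)
    Z_b = F.normalize(torch.randn(100, D) - 2.0, dim=1)
    centers = torch.cat([spherical_kmeans(Z_a, K), spherical_kmeans(Z_b, K)], dim=0)
    labels = torch.tensor([0] * K + [1] * K)

    # Nearest-prototype classification, as in the chunked eval loop above.
    q = F.normalize(torch.randn(5, D) + 2.0, dim=1)  # queries near "artist 0"
    pred = labels[(q @ centers.t()).argmax(dim=1)]
    print(pred)  # expected: mostly zeros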
scripts/extract_faces_eyes.py ADDED
@@ -0,0 +1,30 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Face -> (optional) eye extraction entrypoint.
+
+ This wraps `anime_face_eye_extract.py` so you can run:
+     python scripts/extract_faces_eyes.py --help
+ """
+
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path
+
+
+ _ROOT = Path(__file__).resolve().parents[1]
+ if str(_ROOT) not in sys.path:
+     sys.path.insert(0, str(_ROOT))
+
+
+ def main() -> None:
+     from anime_face_eye_extract import main as _main
+
+     _main()
+
+
+ if __name__ == "__main__":
+     main()
+
+
scripts/make_hf_space_bundle.py ADDED
@@ -0,0 +1,148 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Build a Hugging Face Spaces-ready bundle directory from this repo.
+
+ The output folder can be uploaded (or git-pushed) to a new Gradio Space.
+ We intentionally do NOT rename files in this repo. Instead, the Space README
+ will specify `app_file: webui_gradio.py` to avoid conflicts with the `app/` package.
+
+ Usage:
+     python scripts/make_hf_space_bundle.py --out hf_space
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import shutil
+ from pathlib import Path
+
+
+ ROOT = Path(__file__).resolve().parents[1]
+
+
+ def copy_file(src: Path, dst: Path) -> None:
+     dst.parent.mkdir(parents=True, exist_ok=True)
+     shutil.copy2(src, dst)
+
+
+ def copy_tree(src: Path, dst: Path, *, ignore_globs: list[str] | None = None) -> None:
+     ignore = None
+     if ignore_globs:
+         ignore = shutil.ignore_patterns(*ignore_globs)
+     if dst.exists():
+         shutil.rmtree(dst)
+     shutil.copytree(src, dst, ignore=ignore)
+
+
+ def write_space_readme(dst: Path) -> None:
+     text = """---
+ title: ArtistEmbeddingClassifier
+ sdk: gradio
+ app_file: webui_gradio.py
+ license: gpl-3.0
+ ---
+
+ ### ArtistEmbeddingClassifier (Gradio Space)
+
+ This Space bundles the model checkpoint + prototype DB and runs the Gradio UI.
+
+ Notes:
+ - This project is GPL-3.0.
+ - `yolov5_anime/` is from [zymk9/yolov5_anime](https://github.com/zymk9/yolov5_anime) (GPL-3.0).
+ - `anime-eyes-cascade.xml` is from [recette-lemon/Haar-Cascade-Anime-Eye-Detector](https://github.com/recette-lemon/Haar-Cascade-Anime-Eye-Detector) (GPL-3.0).
+ """
+     (dst / "README.md").write_text(text, encoding="utf-8")
+
+
+ def write_space_requirements(dst: Path) -> None:
+     # IMPORTANT for Spaces:
+     # - HF GPU base images already install torch + gradio + spaces.
+     # - If we pin/downgrade these here, pip will try to replace huge packages and may fail.
+     # Keep this list minimal and only add what is NOT guaranteed by the base image.
+     text = """pillow
+ pyyaml
+ tqdm
+
+ # OpenCV for face/eye extraction (headless build for Spaces)
+ opencv-python-headless
+ """
+     (dst / "requirements.txt").write_text(text, encoding="utf-8")
+
+
+ def write_space_packages(dst: Path) -> None:
+     # Helps OpenCV on Spaces.
+     text = """libgl1
+ libglib2.0-0
+ """
+     (dst / "packages.txt").write_text(text, encoding="utf-8")
+
+
+ def write_lfs_gitattributes(dst: Path) -> None:
+     # If you push via git, this ensures large weights are handled via LFS.
+     text = """*.pt filter=lfs diff=lfs merge=lfs -text
+ """
+     (dst / ".gitattributes").write_text(text, encoding="utf-8")
+
+
+ def main() -> None:
+     ap = argparse.ArgumentParser(description="Create Hugging Face Space bundle directory")
+     ap.add_argument("--out", type=str, default="hf_space", help="Output folder name")
+     args = ap.parse_args()
+
+     out_dir = (ROOT / args.out).resolve()
+     if out_dir.exists():
+         shutil.rmtree(out_dir)
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     # Core app code
+     copy_file(ROOT / "webui_gradio.py", out_dir / "webui_gradio.py")
+     copy_tree(ROOT / "app", out_dir / "app", ignore_globs=["__pycache__"])
+
+     # Assets required by the UI
+     copy_file(ROOT / "anime-eyes-cascade.xml", out_dir / "anime-eyes-cascade.xml")
+     copy_file(ROOT / "yolov5x_anime.pt", out_dir / "yolov5x_anime.pt")
+
+     # Bundle checkpoints/prototypes
+     (out_dir / "checkpoints_style").mkdir(exist_ok=True)
+     copy_file(ROOT / "checkpoints_style" / "stage3_epoch24.pt", out_dir / "checkpoints_style" / "stage3_epoch24.pt")
+     copy_file(
+         ROOT / "checkpoints_style" / "per_artist_prototypes_90_10_full.pt",
+         out_dir / "checkpoints_style" / "per_artist_prototypes_90_10_full.pt",
+     )
+
+     # Vendor yolov5_anime (strip heavy demo assets)
+     copy_tree(
+         ROOT / "yolov5_anime",
+         out_dir / "yolov5_anime",
+         ignore_globs=[
+             "__pycache__",
+             ".git",
+             "inference",
+             "tutorial.ipynb",
+             "Dockerfile",
+             # We bundle yolov5x_anime.pt at repo root; don't include extra weights.
+             "*.pt",
+             "weights/*.pt",
+         ],
+     )
+
+     # Licensing/attribution
+     for fn in ("LICENSE", "THIRD_PARTY_NOTICES.md"):
+         if (ROOT / fn).exists():
+             copy_file(ROOT / fn, out_dir / fn)
+
+     # Space metadata
+     write_space_readme(out_dir)
+     write_space_requirements(out_dir)
+     write_space_packages(out_dir)
+     write_lfs_gitattributes(out_dir)
+
+     print("✅ Created Space bundle at:", out_dir)
+     print("Next: upload/push the contents of that folder to your Hugging Face Space repo.")
+
+
+ if __name__ == "__main__":
+     main()
+
+
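A note on the `ignore_globs` mechanics: `shutil.ignore_patterns` matches basenames per directory, so the path-style pattern `weights/*.pt` never matches anything; the bare `*.pt` entry is what actually keeps extra weights out. A quick stdlib demonstration:

    import shutil

    ignore = shutil.ignore_patterns("__pycache__", "*.pt", "weights/*.pt")
    # copytree invokes the callable once per directory with that directory's entries:
    print(ignore("yolov5_anime", ["models", "__pycache__", "yolov5s.pt", "train.py"]))
    # -> {'__pycache__', 'yolov5s.pt'}   ("weights/*.pt" matches no basename)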
scripts/train_ddp.py ADDED
@@ -0,0 +1,30 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ DDP training entrypoint.
+
+ This wraps `train_style_ddp.py` so you can run:
+     python scripts/train_ddp.py
+ """
+
+ from __future__ import annotations
+
+ import sys
+ from pathlib import Path
+
+
+ _ROOT = Path(__file__).resolve().parents[1]
+ if str(_ROOT) not in sys.path:
+     sys.path.insert(0, str(_ROOT))
+
+
+ def main() -> None:
+     import train_style_ddp as ts
+
+     ts.run_ddp_training()
+
+
+ if __name__ == "__main__":
+     main()
+
+
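One reason these wrappers keep all work inside main() under the `__main__` guard: spawn-based processes (DDP workers, DataLoader workers on Windows) re-import the entrypoint module, so top-level work would run once per process. A minimal illustration with torch.multiprocessing (hypothetical worker, not this repo's training loop):

    import torch.multiprocessing as mp

    def worker(rank: int) -> None:
        # Each spawned process re-imports this module, then calls worker(rank).
        print("hello from rank", rank)

    if __name__ == "__main__":
        mp.spawn(worker, nprocs=2)  # without the guard, spawning would recurse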
scripts/upgrade_proto_db_add_names.py ADDED
@@ -0,0 +1,64 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Upgrade an existing prototype DB .pt file to include artist names (label_names).
+
+ This is useful for older prototype files that only store:
+ - centers: [N, D]
+ - labels: [N]
+
+ We infer label_names from the `dataset/` folder (sorted artist directories), matching
+ `train_style_ddp.TriViewDataset` label assignment.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ from pathlib import Path
+
+ import torch
+
+
+ def infer_label_names(dataset_dir: Path) -> list[str]:
+     if not dataset_dir.exists():
+         raise FileNotFoundError(f"dataset dir not found: {dataset_dir}")
+     names = sorted([p.name for p in dataset_dir.iterdir() if p.is_dir()])
+     if not names:
+         raise RuntimeError(f"No artist folders found under: {dataset_dir}")
+     return names
+
+
+ def main() -> None:
+     p = argparse.ArgumentParser(description="Add label_names to an existing prototype DB .pt")
+     p.add_argument("--in", dest="in_path", required=True, help="Input .pt prototype file")
+     p.add_argument("--out", dest="out_path", default=None, help="Output .pt (default: overwrite input)")
+     p.add_argument("--dataset-dir", type=str, default="dataset", help="Dataset root to infer artist names from")
+     args = p.parse_args()
+
+     in_path = Path(args.in_path)
+     out_path = Path(args.out_path) if args.out_path else in_path
+     dataset_dir = Path(args.dataset_dir)
+
+     obj = torch.load(str(in_path), map_location="cpu")
+     if not isinstance(obj, dict) or "centers" not in obj or "labels" not in obj:
+         raise ValueError("Unsupported prototype file format (expected dict with centers+labels).")
+
+     if "label_names" in obj and isinstance(obj["label_names"], list) and obj["label_names"]:
+         print("label_names already present; nothing to do.")
+         if out_path != in_path:
+             torch.save(obj, str(out_path))
+             print("saved:", out_path)
+         return
+
+     label_names = infer_label_names(dataset_dir)
+     obj["label_names"] = label_names
+
+     out_path.parent.mkdir(parents=True, exist_ok=True)
+     torch.save(obj, str(out_path))
+     print("saved:", out_path)
+
+
+ if __name__ == "__main__":
+     main()
+
+
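After upgrading, the DB can be spot-checked in a few lines. The keys are the ones this script reads and writes (`centers`, `labels`, `label_names`); the path below is illustrative:

    import torch

    obj = torch.load("checkpoints_style/per_artist_prototypes_90_10_full.pt", map_location="cpu")
    print(tuple(obj["centers"].shape))  # (N_prototypes, embed_dim)
    print(tuple(obj["labels"].shape))   # (N_prototypes,)
    print(len(obj["label_names"]))      # artist count; label i -> sorted folder name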