Spaces:
Sleeping
Sleeping
iljung1106
committed on
Commit
·
4f81869
1
Parent(s):
e6ecafe
Add some scripts.
Browse files- scripts/__init__.py +3 -0
- scripts/__pycache__/__init__.cpython-310.pyc +0 -0
- scripts/__pycache__/crawl_dataset.cpython-310.pyc +0 -0
- scripts/__pycache__/eval_prototypes_halfval.cpython-310.pyc +0 -0
- scripts/__pycache__/eval_prototypes_strict_90_10.cpython-310.pyc +0 -0
- scripts/__pycache__/extract_faces_eyes.cpython-310.pyc +0 -0
- scripts/__pycache__/train_ddp.cpython-310.pyc +0 -0
- scripts/crawl_dataset.py +30 -0
- scripts/eval_prototypes_strict_90_10.py +362 -0
- scripts/extract_faces_eyes.py +30 -0
- scripts/make_hf_space_bundle.py +148 -0
- scripts/train_ddp.py +30 -0
- scripts/upgrade_proto_db_add_names.py +64 -0
scripts/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file makes `scripts/` a Python package so entrypoints can be imported reliably.
|
| 2 |
+
|
| 3 |
+
|
scripts/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (156 Bytes). View file
|
|
|
scripts/__pycache__/crawl_dataset.cpython-310.pyc
ADDED
|
Binary file (613 Bytes). View file
|
|
|
scripts/__pycache__/eval_prototypes_halfval.cpython-310.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
scripts/__pycache__/eval_prototypes_strict_90_10.cpython-310.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
scripts/__pycache__/extract_faces_eyes.cpython-310.pyc
ADDED
|
Binary file (611 Bytes). View file
|
|
|
scripts/__pycache__/train_ddp.cpython-310.pyc
ADDED
|
Binary file (552 Bytes). View file
|
|
|
scripts/crawl_dataset.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Dataset crawling entrypoint (Danbooru artist list via Selenium + Gelbooru downloads).
|
| 5 |
+
|
| 6 |
+
This wraps `crawler_api.py` so you can run:
|
| 7 |
+
python scripts/crawl_dataset.py --help
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
_ROOT = Path(__file__).resolve().parents[1]
|
| 17 |
+
if str(_ROOT) not in sys.path:
|
| 18 |
+
sys.path.insert(0, str(_ROOT))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main() -> None:
    """Run the dataset-crawling CLI defined in the repo-level `crawler_api` module."""
    # Deferred import: the sys.path bootstrap above must have run first so the
    # repo root is importable.
    from crawler_api import main_cli

    main_cli()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if __name__ == "__main__":
|
| 28 |
+
main()
|
| 29 |
+
|
| 30 |
+
|
scripts/eval_prototypes_strict_90_10.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Prototype evaluation (strict 90/10 split per view per artist, using merged train+val pools).
|
| 5 |
+
|
| 6 |
+
This script mirrors the "strict 90/10 full coverage" prototype-eval logic.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import os
|
| 13 |
+
import random
|
| 14 |
+
from collections import defaultdict
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import Dict, List, Optional, Sequence, Tuple
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from PIL import Image, UnidentifiedImageError
|
| 20 |
+
from torch.utils.data import DataLoader, Dataset
|
| 21 |
+
|
| 22 |
+
import sys
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
_ROOT = Path(__file__).resolve().parents[1]
|
| 26 |
+
if str(_ROOT) not in sys.path:
|
| 27 |
+
sys.path.insert(0, str(_ROOT))
|
| 28 |
+
|
| 29 |
+
import train_style_ddp as ts
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
TripletWithID = Tuple[str, str, str, int]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class Args:
    """Parsed CLI options for the strict 90/10 prototype build + evaluation run."""

    ckpt: str  # path to the trained model checkpoint (.pt)
    out: str  # where the prototype DB (.pt) is saved
    k_per_artist: int  # max k-means prototypes kept per artist
    build_ratio: float  # fraction of each artist's per-view pool used to build prototypes
    batch_size: int  # batch size for embedding extraction
    num_workers: int  # DataLoader workers (0 is safest on Windows/spawn)
    seed: int  # RNG seed for splits, triplet pairing, and k-means init
    chunk_size: int  # eval rows scored per chunk during nearest-prototype search
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def parse_args() -> Args:
    """Parse CLI options into an :class:`Args` record.

    Defaults mirror the training setup (seed is taken from train_style_ddp's cfg).
    """
    parser = argparse.ArgumentParser(description="Eval prototypes (strict 90/10 split per view)")
    parser.add_argument("--ckpt", type=str, default="./checkpoints_style/stage3_epoch24.pt")
    parser.add_argument("--out", type=str, default="./checkpoints_style/per_artist_prototypes_90_10_full.pt")
    parser.add_argument("--k-per-artist", type=int, default=4)
    parser.add_argument("--build-ratio", type=float, default=0.9)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--num-workers", type=int, default=0, help="0 is safest on Windows/spawn.")
    parser.add_argument("--seed", type=int, default=ts.cfg.seed)
    parser.add_argument("--chunk-size", type=int, default=5000)
    ns = parser.parse_args()
    # argparse dest names (dashes -> underscores) line up 1:1 with the Args fields.
    return Args(**vars(ns))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def kmeans_cosine(Z_cpu: torch.Tensor, K: int, *, iters: int = 20, seed: int = 1337, device: torch.device) -> torch.Tensor:
    """Spherical k-means (cosine similarity) over L2-normalized embeddings.

    Returns up to K unit-norm centroids on CPU. When there are no more than K
    points, the normalized points themselves are returned unchanged.
    """
    feats = torch.nn.functional.normalize(Z_cpu.to(device), dim=1)
    n_pts, dim = feats.shape
    if n_pts <= K:
        return feats.detach().cpu()
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    # Init: K distinct random data points as the starting centroids.
    centroids = feats[torch.randperm(n_pts, generator=rng, device=device)[:K]].clone()
    labels = torch.full((n_pts,), -1, device=device, dtype=torch.long)
    for _ in range(iters):
        nearest = (feats @ centroids.t()).argmax(dim=1)
        if (nearest == labels).all():
            # Assignments converged; keep the previous centroids.
            labels = nearest
            break
        labels = nearest
        # Recompute each centroid as the mean of its assigned points.
        centroids = torch.zeros(K, dim, device=device, dtype=feats.dtype)
        centroids.index_add_(0, labels, feats)
        member_counts = torch.bincount(labels, minlength=K)
        vacant = member_counts == 0
        centroids = centroids / member_counts.clamp_min(1).unsqueeze(1).to(feats.dtype)
        if vacant.any():
            # Re-seed empty clusters from random data points.
            refill = torch.randperm(n_pts, generator=rng, device=device)[: int(vacant.sum().item())]
            centroids[vacant] = feats[refill]
        centroids = torch.nn.functional.normalize(centroids, dim=1)
    return centroids.detach().cpu()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class TripletDatasetWithID(Dataset):
    """Dataset yielding transformed (whole, face, eyes) tensors plus the artist id.

    Unreadable/corrupt image files make ``__getitem__`` return ``None``; the
    collate function is expected to drop those entries.
    """

    def __init__(self, triplets: Sequence[TripletWithID], T_w, T_f, T_e):
        self.triplets = list(triplets)
        self.T_w = T_w
        self.T_f = T_f
        self.T_e = T_e

    def __len__(self) -> int:
        return len(self.triplets)

    def __getitem__(self, idx: int):
        path_whole, path_face, path_eyes, artist_id = self.triplets[idx]
        try:
            images = [Image.open(p).convert("RGB") for p in (path_whole, path_face, path_eyes)]
        except (UnidentifiedImageError, OSError):
            # Skip broken files; downstream collate filters out None samples.
            return None
        return dict(
            whole=self.T_w(images[0]),
            face=self.T_f(images[1]),
            eyes=self.T_e(images[2]),
            aid=int(artist_id),
        )
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def collate_triplets_with_id(batch):
    """Stack per-sample dicts into batched tensors, dropping failed (None) samples.

    Returns (whole, face, eyes, artist_ids); all four are None when every sample
    in the batch failed to load.
    """
    kept = [item for item in batch if item is not None]
    if not kept:
        return None, None, None, None
    stacked = {key: torch.stack([item[key] for item in kept], dim=0) for key in ("whole", "face", "eyes")}
    artist_ids = torch.tensor([item["aid"] for item in kept], dtype=torch.long)
    return stacked["whole"], stacked["face"], stacked["eyes"], artist_ids
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def extract_embeddings_with_id(
    *,
    model: ts.TriViewStyleNet,
    triplets: Sequence[TripletWithID],
    T_w,
    T_f,
    T_e,
    batch_size: int,
    num_workers: int,
    device: torch.device,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Embed (whole, face, eyes) triplets and return (Z, A).

    Z is an [N, D] L2-normalized fused-embedding tensor on CPU and A the
    matching [N] artist-id tensor. Returns (None, None) when `triplets` is
    empty or no sample could be loaded.
    """
    if not triplets:
        return None, None
    ds = TripletDatasetWithID(triplets, T_w, T_f, T_e)

    def _run_loader(nw: int, pin: bool):
        # Inner helper so the whole pass can be retried with num_workers=0 when
        # multiprocessing data loading fails (see the try/except below).
        dl = DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=nw,
            pin_memory=pin,
            collate_fn=collate_triplets_with_id,
        )
        feats: List[torch.Tensor] = []
        aids: List[torch.Tensor] = []
        model.eval()
        # autocast is only enabled when running on CUDA; on CPU the context is a no-op.
        with torch.no_grad(), torch.amp.autocast("cuda", dtype=ts.amp_dtype, enabled=(device.type == "cuda")):
            for Wb, Fb, Eb, Ab in dl:
                if Wb is None:
                    # Entire batch failed to load (collate returned all-None); skip it.
                    continue
                Wb = Wb.to(device, non_blocking=True)
                Fb = Fb.to(device, non_blocking=True)
                Eb = Eb.to(device, non_blocking=True)
                views = {"whole": Wb, "face": Fb, "eyes": Eb}
                # All three views are always present here, so every mask is all-True.
                masks = {k: torch.ones(Wb.size(0), dtype=torch.bool, device=device) for k in views}
                z_fused, _, _ = model(views, masks)
                feats.append(z_fused.detach().cpu())
                aids.append(Ab.detach().cpu())
        return feats, aids

    try:
        feats, aids = _run_loader(num_workers, pin=True)
    except Exception:
        # Best-effort fallback: single-process loading without pinned memory
        # (worker processes can fail on Windows/spawn setups).
        feats, aids = _run_loader(0, pin=False)

    if not feats:
        return None, None
    Z = torch.nn.functional.normalize(torch.cat(feats, dim=0), dim=1)
    A = torch.cat(aids, dim=0).long()
    return Z, A
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def merge_dicts(d1: Dict[int, List], d2: Dict[int, List]) -> Dict[int, List]:
    """Union of two {key: list} mappings; lists for shared keys are concatenated.

    The result always holds fresh lists, so neither input is mutated or aliased.
    """
    merged: Dict[int, List] = {}
    for source in (d1, d2):
        for key, values in source.items():
            merged.setdefault(key, []).extend(values)
    return merged
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def main() -> None:
    """Build per-artist prototypes from 90% of each view pool and evaluate on the rest.

    Pipeline: load checkpoint -> merge train+val path pools -> strict per-view
    90/10 split per artist -> embed build triplets -> k-means prototypes ->
    nearest-prototype (cosine) classification of eval triplets -> save the
    prototype DB (centers, labels, label_names, accuracy).
    """
    a = parse_args()
    # Seed both RNGs so splits, triplet pairing, and k-means init are reproducible.
    random.seed(a.seed)
    torch.manual_seed(a.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    ck = torch.load(a.ckpt, map_location="cpu")
    meta = ck.get("meta", {})
    stage_i = int(meta.get("stage", 1))  # stage numbers are 1-indexed in ckpt metadata
    stage = ts.cfg.stages[stage_i - 1]
    print(f"loaded ckpt={a.ckpt} (stage={stage_i})")

    # use deterministic transforms for prototype building/eval
    T_w_val, T_f_val, T_e_val = ts.make_val_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"])

    train_ds = ts.TriViewDataset(ts.cfg.data_root, ts.cfg.folders, split="train", T_whole=T_w_val, T_face=T_f_val, T_eyes=T_e_val)
    val_ds = ts.TriViewDataset(ts.cfg.data_root, ts.cfg.folders, split="val", T_whole=T_w_val, T_face=T_f_val, T_eyes=T_e_val)

    # Persist artist names (label -> folder name) for downstream UX (e.g., Gradio UI)
    # IDs are assigned by sorting artist directory names under dataset/.
    label_names = [train_ds.id2artist[i] for i in range(train_ds.num_classes)]

    # merge pools (train+val)
    wholes_all = merge_dicts(train_ds.whole_paths_by_artist, val_ds.whole_paths_by_artist)
    faces_all = merge_dicts(train_ds.face_paths_by_artist, val_ds.face_paths_by_artist)
    eyes_all = merge_dicts(train_ds.eyes_paths_by_artist, val_ds.eyes_paths_by_artist)

    build_data = {}
    eval_data = {}
    for aid in wholes_all.keys():
        # De-duplicate paths per view, then shuffle before the ratio split.
        W_list = list({str(p) for p in wholes_all.get(aid, [])})
        F_list = list({str(p) for p in faces_all.get(aid, [])})
        E_list = list({str(p) for p in eyes_all.get(aid, [])})
        random.shuffle(W_list)
        random.shuffle(F_list)
        random.shuffle(E_list)
        # Need at least 2 items per view so both split sides can be non-empty.
        if len(W_list) < 2 or len(F_list) < 2 or len(E_list) < 2:
            continue
        mw = max(1, int(len(W_list) * a.build_ratio))
        mf = max(1, int(len(F_list) * a.build_ratio))
        me = max(1, int(len(E_list) * a.build_ratio))
        # Guarantee the eval side keeps at least one item per view.
        if mw == len(W_list):
            mw -= 1
        if mf == len(F_list):
            mf -= 1
        if me == len(E_list):
            me -= 1
        W_b, W_e = W_list[:mw], W_list[mw:]
        F_b, F_e = F_list[:mf], F_list[mf:]
        E_b, E_e = E_list[:me], E_list[me:]
        if not (W_b and W_e and F_b and F_e and E_b and E_e):
            continue
        build_data[aid] = {"W": W_b, "F": F_b, "E": E_b}
        eval_data[aid] = {"W": W_e, "F": F_e, "E": E_e}

    print("valid artists:", len(build_data))

    model = ts.TriViewStyleNet(out_dim=ts.cfg.embed_dim, mix_p=ts.cfg.mixstyle_p, share_backbone=True).to(device)
    model = model.to(memory_format=torch.channels_last)
    # strict=False: tolerate checkpoint keys not needed for embedding extraction.
    model.load_state_dict(ck["model"], strict=False)
    model.eval()

    # build triplets: use all build wholes once, random face/eyes from build pools
    build_triplets: List[TripletWithID] = []
    for aid, d in build_data.items():
        for pw in d["W"]:
            pf = random.choice(d["F"])
            pe = random.choice(d["E"])
            build_triplets.append((pw, pf, pe, int(aid)))
    print("build triplets:", len(build_triplets))

    Z_build, A_build = extract_embeddings_with_id(
        model=model,
        triplets=build_triplets,
        T_w=T_w_val,
        T_f=T_f_val,
        T_e=T_e_val,
        batch_size=a.batch_size,
        num_workers=a.num_workers,
        device=device,
    )
    if Z_build is None or A_build is None:
        raise RuntimeError("No build embeddings extracted.")

    # prototypes per artist
    aid_to_idx = defaultdict(list)
    for i, aid in enumerate(A_build.tolist()):
        aid_to_idx[int(aid)].append(i)

    proto_centers_list: List[torch.Tensor] = []
    proto_labels_list: List[torch.Tensor] = []
    for aid, idxs in aid_to_idx.items():
        Zi = Z_build[torch.tensor(idxs, dtype=torch.long)]
        if Zi.shape[0] <= a.k_per_artist:
            # Too few samples to cluster: each embedding becomes its own prototype.
            proto_centers_list.append(Zi)
            proto_labels_list.append(torch.full((Zi.shape[0],), aid, dtype=torch.long))
        else:
            centers = kmeans_cosine(Zi, K=a.k_per_artist, iters=20, seed=a.seed, device=device)
            proto_centers_list.append(centers)
            proto_labels_list.append(torch.full((a.k_per_artist,), aid, dtype=torch.long))

    proto_centers = torch.cat(proto_centers_list, dim=0)
    proto_labels = torch.cat(proto_labels_list, dim=0)
    print("total prototypes:", proto_centers.shape[0])

    # eval triplets: use all eval wholes once, random face/eyes from eval pools
    eval_triplets: List[TripletWithID] = []
    valid_proto_artists = set(proto_labels.unique().tolist())
    for aid, d in eval_data.items():
        # Only evaluate artists that actually received prototypes.
        if int(aid) not in valid_proto_artists:
            continue
        for pw in d["W"]:
            pf = random.choice(d["F"])
            pe = random.choice(d["E"])
            eval_triplets.append((pw, pf, pe, int(aid)))
    print("eval triplets:", len(eval_triplets))

    Z_eval, Y_eval = extract_embeddings_with_id(
        model=model,
        triplets=eval_triplets,
        T_w=T_w_val,
        T_f=T_f_val,
        T_e=T_e_val,
        batch_size=a.batch_size,
        num_workers=a.num_workers,
        device=device,
    )
    if Z_eval is None or Y_eval is None:
        raise RuntimeError("No eval embeddings extracted.")

    # nearest-prototype classification (cosine)
    with torch.no_grad():
        C = torch.nn.functional.normalize(proto_centers.to(device), dim=1)
        Z = torch.nn.functional.normalize(Z_eval.to(device), dim=1)
        correct = 0
        total = Z.shape[0]
        # Score in chunks to bound the peak size of the similarity matrix.
        for i in range(0, total, a.chunk_size):
            zc = Z[i : i + a.chunk_size]
            yc = Y_eval[i : i + a.chunk_size].to(device)
            sim = zc @ C.t()
            pred_idx = sim.argmax(dim=1)
            pred_labels = proto_labels.to(device)[pred_idx]
            correct += (pred_labels == yc).sum().item()
    acc = correct / max(1, total)
    print(f"prototype accuracy (strict 90/10): {acc:.4f}")

    os.makedirs(os.path.dirname(a.out) or ".", exist_ok=True)
    torch.save(
        dict(
            centers=proto_centers,
            labels=proto_labels,
            label_names=label_names,
            k_per_artist=a.k_per_artist,
            ckpt=a.ckpt,
            split_method="90_10_strict_per_view_per_artist",
            build_ratio=a.build_ratio,
            acc=acc,
        ),
        a.out,
    )
    print("saved:", a.out)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
if __name__ == "__main__":
|
| 360 |
+
main()
|
| 361 |
+
|
| 362 |
+
|
scripts/extract_faces_eyes.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Face -> (optional) eye extraction entrypoint.
|
| 5 |
+
|
| 6 |
+
This wraps `anime_face_eye_extract.py` so you can run:
|
| 7 |
+
python scripts/extract_faces_eyes.py --help
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
_ROOT = Path(__file__).resolve().parents[1]
|
| 17 |
+
if str(_ROOT) not in sys.path:
|
| 18 |
+
sys.path.insert(0, str(_ROOT))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main() -> None:
    """Run the face/eye extraction CLI from `anime_face_eye_extract`."""
    # Deferred import: the sys.path bootstrap above must have run first so the
    # repo root is importable.
    import anime_face_eye_extract

    anime_face_eye_extract.main()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if __name__ == "__main__":
|
| 28 |
+
main()
|
| 29 |
+
|
| 30 |
+
|
scripts/make_hf_space_bundle.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Build a Hugging Face Spaces-ready bundle directory from this repo.
|
| 5 |
+
|
| 6 |
+
The output folder can be uploaded (or git-pushed) to a new Gradio Space.
|
| 7 |
+
We intentionally do NOT rename files in this repo. Instead, the Space README
|
| 8 |
+
will specify `app_file: webui_gradio.py` to avoid conflicts with the `app/` package.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python scripts/make_hf_space_bundle.py --out hf_space
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import shutil
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def copy_file(src: Path, dst: Path) -> None:
    """Copy one file (with metadata), creating the destination directory if needed."""
    parent = dst.parent
    parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def copy_tree(src: Path, dst: Path, *, ignore_globs: list[str] | None = None) -> None:
    """Mirror a directory tree, replacing any existing destination directory.

    ``ignore_globs`` are fnmatch-style patterns forwarded to
    ``shutil.ignore_patterns``.
    """
    ignore = shutil.ignore_patterns(*ignore_globs) if ignore_globs else None
    if dst.exists():
        shutil.rmtree(dst)
    shutil.copytree(src, dst, ignore=ignore)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def write_space_readme(dst: Path) -> None:
    """Write the Space README, whose YAML front matter configures the Space.

    `app_file: webui_gradio.py` points HF at the UI script so it does not
    conflict with the `app/` package directory.
    """
    text = """---
title: ArtistEmbeddingClassifier
sdk: gradio
app_file: webui_gradio.py
license: gpl-3.0
---

### ArtistEmbeddingClassifier (Gradio Space)

This Space bundles the model checkpoint + prototype DB and runs the Gradio UI.

Notes:
- This project is GPL-3.0.
- `yolov5_anime/` is from [zymk9/yolov5_anime](https://github.com/zymk9/yolov5_anime) (GPL-3.0).
- `anime-eyes-cascade.xml` is from [recette-lemon/Haar-Cascade-Anime-Eye-Detector](https://github.com/recette-lemon/Haar-Cascade-Anime-Eye-Detector) (GPL-3.0).
"""
    (dst / "README.md").write_text(text, encoding="utf-8")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def write_space_requirements(dst: Path) -> None:
    """Write a deliberately minimal requirements.txt for the Space."""
    # IMPORTANT for Spaces:
    # - HF GPU base images already install torch + gradio + spaces.
    # - If we pin/downgrade these here, pip will try to replace huge packages and may fail.
    # Keep this list minimal and only add what is NOT guaranteed by the base image.
    text = """pillow
pyyaml
tqdm

# OpenCV for face/eye extraction (headless build for Spaces)
opencv-python-headless
"""
    (dst / "requirements.txt").write_text(text, encoding="utf-8")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def write_space_packages(dst: Path) -> None:
    """Write packages.txt (apt packages the Space builder installs)."""
    # Helps OpenCV on Spaces.
    text = """libgl1
libglib2.0-0
"""
    (dst / "packages.txt").write_text(text, encoding="utf-8")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def write_lfs_gitattributes(dst: Path) -> None:
    """Write a .gitattributes routing *.pt weights through Git LFS (for git pushes)."""
    content = "*.pt filter=lfs diff=lfs merge=lfs -text\n"
    (dst / ".gitattributes").write_text(content, encoding="utf-8")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def main() -> None:
    """Assemble a self-contained HF Space bundle under --out (repo-relative).

    The output directory is recreated from scratch on every run; everything the
    Space needs (UI code, model weights, prototype DB, vendored detector, and
    Space metadata files) is copied into it.
    """
    ap = argparse.ArgumentParser(description="Create Hugging Face Space bundle directory")
    ap.add_argument("--out", type=str, default="hf_space", help="Output folder name")
    args = ap.parse_args()

    out_dir = (ROOT / args.out).resolve()
    # Start clean so stale files never linger in the bundle.
    if out_dir.exists():
        shutil.rmtree(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Core app code
    copy_file(ROOT / "webui_gradio.py", out_dir / "webui_gradio.py")
    copy_tree(ROOT / "app", out_dir / "app", ignore_globs=["__pycache__"])

    # Assets required by the UI
    copy_file(ROOT / "anime-eyes-cascade.xml", out_dir / "anime-eyes-cascade.xml")
    copy_file(ROOT / "yolov5x_anime.pt", out_dir / "yolov5x_anime.pt")

    # Bundle checkpoints/prototypes
    (out_dir / "checkpoints_style").mkdir(exist_ok=True)
    copy_file(ROOT / "checkpoints_style" / "stage3_epoch24.pt", out_dir / "checkpoints_style" / "stage3_epoch24.pt")
    copy_file(
        ROOT / "checkpoints_style" / "per_artist_prototypes_90_10_full.pt",
        out_dir / "checkpoints_style" / "per_artist_prototypes_90_10_full.pt",
    )

    # Vendor yolov5_anime (strip heavy demo assets)
    copy_tree(
        ROOT / "yolov5_anime",
        out_dir / "yolov5_anime",
        ignore_globs=[
            "__pycache__",
            ".git",
            "inference",
            "tutorial.ipynb",
            "Dockerfile",
            # We bundle yolov5x_anime.pt at repo root; don't include extra weights.
            "*.pt",
            "weights/*.pt",
        ],
    )

    # Licensing/attribution (copied only when present in the repo)
    for fn in ("LICENSE", "THIRD_PARTY_NOTICES.md"):
        if (ROOT / fn).exists():
            copy_file(ROOT / fn, out_dir / fn)

    # Space metadata
    write_space_readme(out_dir)
    write_space_requirements(out_dir)
    write_space_packages(out_dir)
    write_lfs_gitattributes(out_dir)

    print("✅ Created Space bundle at:", out_dir)
    print("Next: upload/push the contents of that folder to your Hugging Face Space repo.")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
main()
|
| 147 |
+
|
| 148 |
+
|
scripts/train_ddp.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
DDP training entrypoint.
|
| 5 |
+
|
| 6 |
+
This wraps `train_style_ddp.py` so you can run:
|
| 7 |
+
python scripts/train_ddp.py
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
_ROOT = Path(__file__).resolve().parents[1]
|
| 17 |
+
if str(_ROOT) not in sys.path:
|
| 18 |
+
sys.path.insert(0, str(_ROOT))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main() -> None:
    """Kick off DDP training via `train_style_ddp.run_ddp_training`."""
    # Deferred import: the sys.path bootstrap above must have run first so the
    # repo root is importable.
    from train_style_ddp import run_ddp_training

    run_ddp_training()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if __name__ == "__main__":
|
| 28 |
+
main()
|
| 29 |
+
|
| 30 |
+
|
scripts/upgrade_proto_db_add_names.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Upgrade an existing prototype DB .pt file to include artist names (label_names).
|
| 5 |
+
|
| 6 |
+
This is useful for older prototype files that only store:
|
| 7 |
+
- centers: [N, D]
|
| 8 |
+
- labels: [N]
|
| 9 |
+
|
| 10 |
+
We infer label_names from `dataset/` folder (sorted artist directories), matching
|
| 11 |
+
`train_style_ddp.TriViewDataset` label assignment.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def infer_label_names(dataset_dir: Path) -> list[str]:
    """Return sorted artist folder names under ``dataset_dir``.

    Sorting matches the label assignment used by
    ``train_style_ddp.TriViewDataset``.

    Raises:
        FileNotFoundError: ``dataset_dir`` does not exist.
        RuntimeError: ``dataset_dir`` contains no subdirectories.
    """
    if not dataset_dir.exists():
        raise FileNotFoundError(f"dataset dir not found: {dataset_dir}")
    names = sorted(entry.name for entry in dataset_dir.iterdir() if entry.is_dir())
    if not names:
        raise RuntimeError(f"No artist folders found under: {dataset_dir}")
    return names
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def main() -> None:
    """Load a prototype DB, attach inferred ``label_names``, and save it back."""
    p = argparse.ArgumentParser(description="Add label_names to an existing prototype DB .pt")
    p.add_argument("--in", dest="in_path", required=True, help="Input .pt prototype file")
    p.add_argument("--out", dest="out_path", default=None, help="Output .pt (default: overwrite input)")
    p.add_argument("--dataset-dir", type=str, default="dataset", help="Dataset root to infer artist names from")
    args = p.parse_args()

    in_path = Path(args.in_path)
    out_path = Path(args.out_path) if args.out_path else in_path
    dataset_dir = Path(args.dataset_dir)

    obj = torch.load(str(in_path), map_location="cpu")
    if not isinstance(obj, dict) or "centers" not in obj or "labels" not in obj:
        raise ValueError("Unsupported prototype file format (expected dict with centers+labels).")

    # Already upgraded: optionally copy to --out, then stop.
    if "label_names" in obj and isinstance(obj["label_names"], list) and obj["label_names"]:
        print("label_names already present; nothing to do.")
        if out_path != in_path:
            torch.save(obj, str(out_path))
            print("saved:", out_path)
        return

    label_names = infer_label_names(dataset_dir)
    obj["label_names"] = label_names

    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(obj, str(out_path))
    print("saved:", out_path)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
main()
|
| 63 |
+
|
| 64 |
+
|