File size: 1,755 Bytes
8153a62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import annotations

import argparse
import shutil
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class Config:
    """Immutable pair of directories taken from the command line.

    Attributes:
        outputs_dir: Training output directory whose top-level files are the
            artifacts to publish.
        model_dir: Destination folder that gets emptied and refilled with
            those artifacts.
    """

    # Where training wrote its results (source of the copy).
    outputs_dir: Path
    # Folder to be (re)built for committing/pushing (target of the copy).
    model_dir: Path


def _parse_args(argv: list[str] | None = None) -> Config:
    """Parse command-line options into a :class:`Config`.

    Args:
        argv: Argument list to parse (without the program name). ``None``
            (the default) makes argparse read ``sys.argv[1:]``, matching the
            previous behavior; passing a list allows programmatic use and tests.

    Returns:
        A ``Config`` whose paths default to ``outputs`` and ``model`` relative
        to the current working directory.

    Raises:
        SystemExit: raised by argparse on invalid options or ``--help``.
    """
    parser = argparse.ArgumentParser(description="Prepare model/ folder from an outputs/ training directory.")
    parser.add_argument("--outputs-dir", default="outputs", help="Training output directory (from mailsort.train).")
    parser.add_argument("--model-dir", default="model", help="Target folder to commit/push to Hugging Face.")
    args = parser.parse_args(argv)
    return Config(outputs_dir=Path(args.outputs_dir), model_dir=Path(args.model_dir))


def main() -> int:
    """Rebuild ``model_dir`` from the top-level files of a training ``outputs_dir``.

    All validation happens before anything is deleted, so a bad invocation
    leaves an existing ``model_dir`` untouched. On success ``model_dir``
    contains exactly the non-directory entries found at the top level of
    ``outputs_dir``; every subdirectory (e.g. trainer ``checkpoint-*``
    folders) is skipped.

    Returns:
        0, used as the process exit status.

    Raises:
        SystemExit: if ``outputs_dir`` does not exist or is not a directory;
            if ``model_dir`` is the same as, or an ancestor of,
            ``outputs_dir`` (cleaning it would delete the training outputs);
            or if any required file is missing from ``outputs_dir``.
    """
    cfg = _parse_args()

    if not cfg.outputs_dir.exists():
        raise SystemExit(f"outputs-dir not found: {cfg.outputs_dir}")
    if not cfg.outputs_dir.is_dir():
        raise SystemExit(f"outputs-dir is not a directory: {cfg.outputs_dir}")

    _ensure_safe_target(cfg.outputs_dir, cfg.model_dir)

    # Minimum artifact set: ALL of these must be present, not just any one.
    required = ("config.json", "tokenizer.json", "tokenizer_config.json")
    # Check the source before cleaning, so a failed run does not leave
    # model_dir wiped or half-populated.
    missing = [name for name in required if not (cfg.outputs_dir / name).is_file()]
    if missing:
        raise SystemExit(f"Missing expected files in {cfg.outputs_dir}: {missing}")

    cfg.model_dir.mkdir(parents=True, exist_ok=True)
    # Clean target first so the result is explicit and predictable (no stale files).
    _empty_dir(cfg.model_dir)
    _copy_root_files(cfg.outputs_dir, cfg.model_dir)

    print(f"Prepared {cfg.model_dir} from {cfg.outputs_dir}")
    return 0


def _ensure_safe_target(outputs_dir: Path, model_dir: Path) -> None:
    """Exit if emptying *model_dir* would delete *outputs_dir* itself."""
    src = outputs_dir.resolve()
    dst = model_dir.resolve()
    # model_dir nested inside outputs_dir is harmless (it is a subdirectory and
    # therefore skipped by the copy); the reverse, or equality, is destructive.
    if dst == src or dst in src.parents:
        raise SystemExit(
            f"model-dir {model_dir} must not be the same as, or contain, outputs-dir {outputs_dir}"
        )


def _empty_dir(directory: Path) -> None:
    """Delete every entry inside *directory*, keeping the directory itself."""
    for entry in directory.iterdir():
        # is_dir() follows symlinks and shutil.rmtree refuses symlinks, so a
        # link (to a directory or a file) is removed with unlink() instead.
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)
        else:
            entry.unlink()


def _copy_root_files(src_dir: Path, dst_dir: Path) -> None:
    """Copy each non-directory entry at the top level of *src_dir* into *dst_dir*."""
    for entry in src_dir.iterdir():
        # All subdirectories are skipped, not only trainer checkpoint-* folders.
        if entry.is_dir():
            continue
        shutil.copy2(entry, dst_dir / entry.name)


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)