"""Command-line entry point for judgment_partition_infer: read records from a JSONL file and write one prediction per line to an output JSONL file."""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

from tqdm import tqdm

# Allow running without installation: `python infer_cli.py ...`
BUNDLE_ROOT = Path(__file__).resolve().parent
SRC_DIR = BUNDLE_ROOT / "src"
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))

from judgment_partition_infer.infer import Predictor, default_run_dir, write_run_meta  # noqa: E402


def main() -> int:
    parser = argparse.ArgumentParser(description="judgment_partition_infer (JSONL -> JSONL)")
    parser.add_argument("--input", type=str, required=True, help="Input jsonl")
    parser.add_argument(
        "--output-root",
        type=str,
        default=None,
        help="Root output dir. Default: ./output/<timestamp>/",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Explicit output jsonl path (overrides output-root/timestamp).",
    )
    parser.add_argument("--model", type=str, default=None, help="Model checkpoint (.pt)")
    parser.add_argument("--vocab", type=str, default=None, help="Vocab json")
    parser.add_argument("--device", type=str, default="cuda", help="cuda|cpu (cuda falls back to cpu)")
    parser.add_argument("--anchor", type=str, default="auto", choices=["auto", "off"])
    parser.add_argument("--max-samples", type=int, default=None, help="Process at most N samples")
    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        raise FileNotFoundError(f"Missing input: {input_path}")

    output_root = Path(args.output_root) if args.output_root else (BUNDLE_ROOT / "output")
    run_dir = default_run_dir(output_root)
    run_dir.mkdir(parents=True, exist_ok=True)

    output_path = Path(args.output) if args.output else (run_dir / "predictions.jsonl")
    output_path.parent.mkdir(parents=True, exist_ok=True)

    predictor = Predictor(
        model_path=Path(args.model) if args.model else None,
        vocab_path=Path(args.vocab) if args.vocab else None,
        device=args.device,
        anchor=args.anchor,
    )

    meta = {
        "input": str(input_path),
        "output": str(output_path),
        "run_dir": str(run_dir),
        "device_requested": args.device,
        "device_used": str(predictor.torch_device),
        "anchor": args.anchor,
        "model_path": str(predictor.model_path),
        "vocab_path": str(predictor.vocab_path),
    }
    write_run_meta(run_dir / "run_meta.json", meta)

    written = 0
    with input_path.open("r", encoding="utf-8") as f_in, output_path.open("w", encoding="utf-8") as f_out:
        for line in tqdm(f_in, desc="Infer", unit="line"):
            if args.max_samples is not None and written >= args.max_samples:
                break
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Skip malformed lines rather than aborting the whole run.
                continue
            out = predictor.predict_record(record)
            f_out.write(json.dumps(out, ensure_ascii=False) + "\n")
            written += 1

    print(f"[DONE] samples={written} -> {output_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())