File size: 2,531 Bytes
7451c3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python3
"""
F2: ChartQA visual_fact_deplot via offline DePlot (google/deplot) batch pipeline.
Falls back to placeholder when disabled, on missing images, or inference failure.
"""
from __future__ import annotations

import argparse
import json
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data_utils.chart.deplot_pipeline import enrich_entries_with_deplot
from data_utils.paths import project_path


def main(argv=None):
    """Run DePlot enrichment over a ChartQA JSON file and write the result.

    Reads the JSON document at ``--input`` and passes it to
    ``enrich_entries_with_deplot``. The document is then written to ``--output``
    and a one-line summary of the returned stats is printed.

    The return value of ``enrich_entries_with_deplot`` is only used for stats,
    so the entries are presumably updated in place. NOTE(review): confirm
    against ``data_utils.chart.deplot_pipeline``.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``.
            ``None`` (the default) keeps the normal command-line behavior.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument(
        "--enabled",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Run real DePlot inference (default: enabled)",
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--max-new-tokens", type=int, default=384)
    parser.add_argument(
        "--cache",
        default=project_path("data/chartqa/deplot_cache.json"),
        help="Incremental cache keyed by resolved image path",
    )
    parser.add_argument(
        "--replace-placeholder",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument("--only-missing", action="store_true", default=False)
    parser.add_argument("--max-samples", type=int, default=0, help="0 = all entries")
    parser.add_argument("--model-id", default="google/deplot")
    parser.add_argument("--device", default="auto")
    args = parser.parse_args(argv)

    # Reject nonsensical sizes up front rather than failing deep inside the
    # pipeline (possibly after a model has already been loaded).
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    if args.max_new_tokens < 1:
        parser.error("--max-new-tokens must be a positive integer")
    if args.max_samples < 0:
        parser.error("--max-samples must be >= 0 (0 = all entries)")

    with open(args.input, encoding="utf-8") as f:
        data = json.load(f)

    stats = enrich_entries_with_deplot(
        data,
        enabled=args.enabled,
        model_id=args.model_id,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
        cache_path=args.cache,
        replace_placeholder=args.replace_placeholder,
        only_missing=args.only_missing,
        max_samples=args.max_samples,
        # "auto" is mapped to None, leaving device selection to the pipeline.
        device=None if args.device == "auto" else args.device,
    )

    output_path = os.path.abspath(args.output)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    # Write to a sibling temp file, then atomically swap it into place. A crash
    # or serialization error mid-dump therefore never leaves a truncated output.
    # This matters when --output points at the same file as --input.
    tmp_path = output_path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, output_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    print(
        f"Wrote {len(data)} records to {args.output} | "
        f"real={stats['real']} cached={stats['cached']} "
        f"placeholder={stats['placeholder']} skipped={stats['skipped']} failed={stats['failed']}"
    )


if __name__ == "__main__":
    main()