File size: 4,459 Bytes
534c64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import json
import os
import sys
import tempfile
from unittest.mock import patch

import pytest
from PIL import Image

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data_utils.chart.deplot_pipeline import (
    build_deplot_visual_fact,
    enrich_entries_with_deplot,
    format_deplot_for_teacher,
    has_real_deplot,
    is_deplot_placeholder,
    load_deplot_cache,
    placeholder_deplot_table,
    save_deplot_cache,
)
from opsd_utils.privileged.providers import VisualFactsProvider


def test_is_deplot_placeholder():
    """A placeholder table is flagged as a placeholder; a table built from text is not."""
    placeholder = placeholder_deplot_table({"question": "q"})
    assert is_deplot_placeholder(placeholder)

    genuine = build_deplot_visual_fact({"question": "q"}, "A | 1\nB | 2")
    assert not is_deplot_placeholder(genuine)
    assert has_real_deplot(genuine)


def test_format_deplot_for_teacher():
    """A real DePlot fact renders to its table text; placeholder/None/"" render empty."""
    expected_table = "Year | Value\n2020 | 10"
    fact = build_deplot_visual_fact({"question": "q"}, expected_table)
    assert format_deplot_for_teacher(fact) == expected_table

    # Every non-real input must collapse to the empty string.
    for value in (placeholder_deplot_table({"question": "q"}), None, ""):
        assert format_deplot_for_teacher(value) == ""


def test_visual_facts_provider_skips_placeholder_and_missing():
    """The hint section is kept while placeholder or absent DePlot data is omitted."""
    provider = VisualFactsProvider()

    with_placeholder = provider.build_teacher_suffix(
        {
            "visual_fact_deplot": placeholder_deplot_table({"question": "q"}),
            "visual_fact_hint": "hint text",
        }
    )
    assert "Visual Facts - Hint" in with_placeholder
    assert "Visual Facts - DePlot" not in with_placeholder

    # No DePlot key at all: only the hint section should appear.
    hint_only = provider.build_teacher_suffix({"visual_fact_hint": "only hint"})
    assert "Visual Facts - DePlot" not in hint_only
    assert "Visual Facts - Hint" in hint_only


def test_visual_facts_provider_real_deplot_table():
    """A real DePlot table is embedded verbatim, without a serialized "parsed_table" field."""
    provider = VisualFactsProvider()
    table_text = "Category | 2019 | 2020\nA | 1 | 2"
    fact = build_deplot_visual_fact({"question": "q"}, table_text)

    rendered = provider.build_teacher_suffix({"visual_fact_deplot": fact})

    assert "Visual Facts - DePlot" in rendered
    assert table_text in rendered
    assert '"parsed_table"' not in rendered


def test_deplot_cache_roundtrip():
    """An entry written by save_deplot_cache is read back by load_deplot_cache."""
    with tempfile.TemporaryDirectory() as workdir:
        cache_file = os.path.join(workdir, "cache.json")
        save_deplot_cache(cache_file, {"/a.png": "table text"})
        restored = load_deplot_cache(cache_file)
        assert restored["/a.png"] == "table text"


def test_enrich_disabled_uses_placeholder():
    """With enrichment disabled, the entry gets a placeholder DePlot fact."""
    records = [{"question": "q1", "image": "missing.png"}]

    result = enrich_entries_with_deplot(records, enabled=False)

    assert result["placeholder"] == 1
    assert is_deplot_placeholder(records[0]["visual_fact_deplot"])


def test_enrich_with_mock_runner():
    """A stubbed DePlotRunner yields a real fact and the cache file gets written."""

    class _StubRunner:
        """Stand-in for DePlotRunner: load() succeeds and every path maps to one fixed table."""

        def load(self):
            return True

        def generate_batch_with_oom_retry(self, paths, batch_size=8):
            return ["Col | Val\nA | 1" for _ in paths]

    with tempfile.TemporaryDirectory() as workdir:
        chart = os.path.join(workdir, "chart.png")
        Image.new("RGB", (32, 32)).save(chart)
        cache_file = os.path.join(workdir, "deplot_cache.json")
        records = [{"question": "What?", "image": chart}]

        # Swap the real runner class for a factory returning the stub instance.
        with patch("data_utils.chart.deplot_pipeline.DePlotRunner", return_value=_StubRunner()):
            result = enrich_entries_with_deplot(
                records,
                enabled=True,
                cache_path=cache_file,
            )

        assert result["real"] == 1
        fact = records[0]["visual_fact_deplot"]
        assert has_real_deplot(fact)
        assert "Col | Val" in format_deplot_for_teacher(fact)
        assert os.path.isfile(cache_file)


def test_build_script_disabled(tmp_path):
    """Run the ChartQA DePlot build script with DePlot disabled and check for placeholders.

    The script runs as a subprocess from the repository root. Its stdout and
    stderr are captured so that a failing run reports the script's own error
    output instead of a bare ``CalledProcessError``. A timeout keeps a hung
    script from stalling the whole suite.
    """
    import subprocess

    root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    script = os.path.join(root, "scripts", "build_visual_facts_chartqa_deplot.py")
    inp = tmp_path / "in.json"
    out = tmp_path / "out.json"
    inp.write_text(json.dumps([{"question": "q", "hint": "h"}]), encoding="utf-8")

    # Running `python scripts/x.py` puts scripts/ (not the cwd) at sys.path[0].
    # Prepend the repo root so project imports resolve even if the script does
    # not patch sys.path itself; any existing PYTHONPATH is preserved.
    env = dict(os.environ)
    env["PYTHONPATH"] = os.pathsep.join(
        part for part in (root, env.get("PYTHONPATH", "")) if part
    )

    proc = subprocess.run(
        [
            sys.executable,
            script,
            "--input",
            str(inp),
            "--output",
            str(out),
            "--no-enabled",
        ],
        cwd=root,
        env=env,
        capture_output=True,
        text=True,
        timeout=300,
    )
    assert proc.returncode == 0, (
        f"build script exited with status {proc.returncode}\n"
        f"--- stdout ---\n{proc.stdout}\n--- stderr ---\n{proc.stderr}"
    )

    data = json.loads(out.read_text(encoding="utf-8"))
    assert is_deplot_placeholder(data[0]["visual_fact_deplot"])