File size: 4,107 Bytes
f1fb42f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Chart-to-CSV extraction using Granite Vision.

Converts chart images to tabular CSV data using ibm-granite/granite-vision-4.1-4b.
"""

import threading
from collections.abc import Generator
from PIL import Image
from typing import Any

import torch
from PIL import Image
from transformers import TextIteratorStreamer
from model_loader import load_model, use_api_mode, use_mlx_mode


def extract_csv(image: Image.Image, max_new_tokens: int = 4096) -> str:
    """Extract CSV data from a chart image using Granite Vision.

    Dispatches to the API backend or the MLX backend when those modes are
    enabled (``use_api_mode()`` / ``use_mlx_mode()``); otherwise runs the
    locally loaded transformers model.

    Args:
        image: PIL Image of a chart or table. It is converted to RGB before
            being passed to the processor.
        max_new_tokens: Upper bound on generated tokens for the local
            transformers path. Defaults to 4096. Not forwarded to the API or
            MLX backends.

    Returns:
        CSV-formatted text extracted from the chart. Special cases on the
        local path:

        * If ``load_model()`` returns ``None`` for the processor or the model,
          a fixed placeholder CSV is returned.
        * If generation used the full ``max_new_tokens`` budget, a truncation
          notice is appended to the decoded text.
        * If any exception is raised during local inference, its traceback is
          printed and ``"Error: <message>"`` is returned instead of raising.
    """
    if use_api_mode():
        from infer_api import extract_csv_api
        return extract_csv_api(image)

    if use_mlx_mode():
        from infer_mlx import extract_csv_mlx
        return extract_csv_mlx(image)

    processor, model = load_model()

    if processor is None or model is None:
        # Placeholder output when no local model is available.
        return "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"

    try:
        image = image.convert("RGB")
        # One user turn: the image plus the "<chart2csv>" text prompt.
        conversation = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "<chart2csv>"},
        ]}]
        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)

        # torch is the module-level import; no local re-import needed.
        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)

        # The output sequence starts with the prompt ids; keep only new tokens.
        gen = outputs[0, inputs["input_ids"].shape[1]:]
        result = processor.decode(gen, skip_special_tokens=True)
        # Heuristic: hitting the budget usually means the answer was cut off
        # (it can also fire if the model happened to stop exactly at the limit).
        if len(gen) >= max_new_tokens:
            result += "\n\n[Max token limit reached — response may be truncated]"
        return result

    except Exception as e:  # noqa: BLE001
        import traceback

        traceback.print_exc()
        return f"Error: {e!s}"


def _run_generate(model: Any, generation_kwargs: dict[str, Any]) -> None:
    """Run model.generate in a thread (used by the streaming variant)."""
    with torch.inference_mode():
        model.generate(**generation_kwargs)


def extract_csv_stream(
    image: Image.Image, max_new_tokens: int = 4096,
) -> Generator[str, None, None]:
    """Stream CSV extraction from a chart image using Granite Vision.

    Streaming counterpart of ``extract_csv()``: same backend dispatch (API,
    MLX, or the local transformers model) and same prompt. Each yielded value
    is the *cumulative* text decoded so far, not just the newest fragment, so
    a consumer can simply display the latest value.

    Args:
        image: PIL Image of a chart or table. Converted to RGB before being
            passed to the processor.
        max_new_tokens: Upper bound on generated tokens for the local
            transformers path. Defaults to 4096. Not forwarded to the API or
            MLX backends.

    Yields:
        The accumulated decoded text after each non-empty streamed fragment.
        Special cases on the local path:

        * A single placeholder CSV if ``load_model()`` returns ``None`` for
          the processor or the model.
        * ``"Error: <message>"`` (after printing the traceback) if setup
          fails, or if ``model.generate`` raises in the background thread.
    """
    if use_api_mode():
        from infer_api import extract_csv_stream_api
        yield from extract_csv_stream_api(image)
        return

    if use_mlx_mode():
        from infer_mlx import extract_csv_stream_mlx
        yield from extract_csv_stream_mlx(image)
        return

    processor, model = load_model()

    if processor is None or model is None:
        # Placeholder output when no local model is available.
        yield "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"
        return

    try:
        image = image.convert("RGB")
        # One user turn: the image plus the "<chart2csv>" text prompt.
        conversation = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "<chart2csv>"},
        ]}]
        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)

        streamer = TextIteratorStreamer(
            processor.tokenizer, skip_prompt=True, skip_special_tokens=True,
        )
        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_new_tokens,
            "use_cache": True,
            "streamer": streamer,
        }

        worker_errors: list[Exception] = []

        def _generate_and_signal() -> None:
            """Run generation; on failure record the error and close the stream."""
            try:
                _run_generate(model, generation_kwargs)
            except Exception as exc:  # noqa: BLE001
                worker_errors.append(exc)
                # If generate() raises, the streamer may never receive its
                # end-of-stream signal, and the consumer loop below would
                # block forever; end() pushes that signal.
                streamer.end()

        # Generation runs in a background thread while this generator drains
        # the streamer. NOTE(review): if the caller stops iterating early, the
        # thread keeps generating until it finishes on its own.
        thread = threading.Thread(target=_generate_and_signal)
        thread.start()

        accumulated = ""
        for token_text in streamer:
            if token_text:
                accumulated += token_text
                yield accumulated

        thread.join()

        # Surface failures from the worker thread instead of silently
        # ending with partial (or no) output.
        if worker_errors:
            import traceback

            err = worker_errors[0]
            traceback.print_exception(type(err), err, err.__traceback__)
            yield f"Error: {err!s}"

    except Exception as e:  # noqa: BLE001
        import traceback

        traceback.print_exc()
        yield f"Error: {e!s}"