Spaces:
Running on Zero
Running on Zero
File size: 4,107 Bytes
"""Chart-to-CSV extraction using Granite Vision.
Converts chart images to tabular CSV data using ibm-granite/granite-vision-4.1-4b.
"""
import threading
from collections.abc import Generator
from PIL import Image
from typing import Any
import torch
from PIL import Image
from transformers import TextIteratorStreamer
from model_loader import load_model, use_api_mode, use_mlx_mode
def extract_csv(image: Image.Image) -> str:
    """Turn a chart image into CSV text with Granite Vision.

    The call is delegated to the API backend when ``use_api_mode()`` is true
    and to the MLX backend when ``use_mlx_mode()`` is true. Otherwise the
    locally loaded model answers the ``<chart2csv>`` prompt for the image.

    Args:
        image: PIL Image of a chart or table.

    Returns:
        The decoded model output. A fixed placeholder CSV is returned when
        ``load_model()`` provides no processor or no model, a truncation
        notice is appended when generation used its whole token budget, and
        an ``"Error: ..."`` string is returned if local inference raises.
    """
    if use_api_mode():
        from infer_api import extract_csv_api

        return extract_csv_api(image)

    if use_mlx_mode():
        from infer_mlx import extract_csv_mlx

        return extract_csv_mlx(image)

    processor, model = load_model()
    if processor is None or model is None:
        # Nothing to run locally: hand back a fixed placeholder table.
        return "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"

    token_budget = 4096
    try:
        import torch

        rgb_image = image.convert("RGB")
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "<chart2csv>"},
            ],
        }
        prompt = processor.apply_chat_template(
            [user_turn],
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = processor(
            text=prompt,
            images=rgb_image,
            return_tensors="pt",
        ).to(model.device)

        with torch.inference_mode():
            generated = model.generate(
                **model_inputs,
                max_new_tokens=token_budget,
                use_cache=True,
            )

        # Drop the prompt tokens so only the newly generated part is decoded.
        prompt_len = model_inputs["input_ids"].shape[1]
        new_tokens = generated[0, prompt_len:]
        csv_text = processor.decode(new_tokens, skip_special_tokens=True)

        if len(new_tokens) < token_budget:
            return csv_text
        # The whole budget was consumed, so the output may have been cut off.
        return csv_text + "\n\n[Max token limit reached — response may be truncated]"
    except Exception as exc:  # noqa: BLE001
        import traceback

        traceback.print_exc()
        return f"Error: {exc!s}"
def _run_generate(model: Any, generation_kwargs: dict[str, Any]) -> None:
"""Run model.generate in a thread (used by the streaming variant)."""
with torch.inference_mode():
model.generate(**generation_kwargs)
def extract_csv_stream(image: Image.Image) -> Generator[str, None, None]:
    """Streaming counterpart of extract_csv() for a chart image.

    The call is delegated to the API backend when ``use_api_mode()`` is true
    and to the MLX backend when ``use_mlx_mode()`` is true. Otherwise
    generation runs in a background thread and its output is relayed through
    a ``TextIteratorStreamer``.

    Args:
        image: PIL Image of a chart or table.

    Yields:
        On the local path, the text accumulated so far, emitted again each
        time the streamer delivers a non-empty chunk. A single placeholder CSV
        is yielded when ``load_model()`` provides no processor or no model,
        and an ``"Error: ..."`` string is yielded if the local path raises.
    """
    if use_api_mode():
        from infer_api import extract_csv_stream_api

        yield from extract_csv_stream_api(image)
        return

    if use_mlx_mode():
        from infer_mlx import extract_csv_stream_mlx

        yield from extract_csv_stream_mlx(image)
        return

    processor, model = load_model()
    if processor is None or model is None:
        # Nothing to run locally: emit a fixed placeholder table once.
        yield "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"
        return

    try:
        rgb_image = image.convert("RGB")
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "<chart2csv>"},
            ],
        }
        prompt = processor.apply_chat_template(
            [user_turn],
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = processor(
            text=prompt,
            images=rgb_image,
            return_tensors="pt",
        ).to(model.device)

        streamer = TextIteratorStreamer(
            processor.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generation_kwargs = {
            **model_inputs,
            "max_new_tokens": 4096,
            "use_cache": True,
            "streamer": streamer,
        }

        # generate() runs in a worker thread while this generator consumes
        # the streamer and relays text to the caller.
        worker = threading.Thread(
            target=_run_generate,
            args=(model, generation_kwargs),
        )
        worker.start()

        so_far = ""
        for piece in streamer:
            if not piece:
                continue
            so_far += piece
            yield so_far

        worker.join()
    except Exception as exc:  # noqa: BLE001
        import traceback

        traceback.print_exc()
        yield f"Error: {exc!s}"
|