# DocAI / src/infer_chart2csv.py
# Author: Pengyuan Li
# Add ZeroGPU support for DocAI demo on HuggingFace Spaces (commit c37e95b)
"""
Chart2CSV extraction using Granite Vision
Extracts tabular data from chart images as CSV format
"""
import spaces
from PIL import Image
from typing import Optional
import tempfile
import os
# Global model cache (populated lazily by load_model)
_processor = None
_model = None


def load_model():
    """Lazy-load the Chart2CSV processor and model, caching them globally.

    Returns:
        (processor, model) on success, or (None, None) when transformers
        is unavailable or loading fails (callers then fall back to a stub).
    """
    global _processor, _model
    if _processor is not None and _model is not None:
        return _processor, _model
    try:
        from transformers import AutoProcessor, AutoModelForVision2Seq
        import torch

        model_id = "ibm-granite/granite-vision-3.3-2b-chart2csv-preview"
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        # Build into locals first so a failure partway through does not
        # leave the global cache half-initialized (e.g. processor loaded
        # but model load raised).
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            trust_remote_code=True,
            # fp16 halves GPU memory; CPU stays fp32 for compatibility.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        ).to(device)
        _processor, _model = processor, model
        print(f"✅ Loaded {model_id} on {device}")
        return _processor, _model
    except ImportError:
        print("⚠️ Transformers not available, using stub")
        return None, None
    except Exception as e:
        print(f"⚠️ Model load error: {e}")
        import traceback
        traceback.print_exc()
        return None, None
def _save_image_to_temp(image: Image.Image) -> str:
    """Save a PIL image to a temporary PNG file and return its path.

    The caller is responsible for deleting the file when done (extract_csv
    unlinks it in a finally block).

    Args:
        image: PIL Image to persist.

    Returns:
        Filesystem path of the written PNG file.
    """
    # mkstemp returns an open OS-level handle; close it immediately so we
    # don't leak a file descriptor, and so image.save can reopen the path
    # for writing on platforms (Windows) that disallow concurrent opens.
    fd, path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    image.save(path)
    return path
@spaces.GPU
def extract_csv(image: Image.Image) -> str:
    """
    Extract CSV from a chart image using Granite Vision Chart2CSV model.

    Args:
        image: PIL Image of chart/table

    Returns:
        CSV text; a hard-coded stub CSV when the model is unavailable;
        or an error string prefixed with "❌" if extraction fails.
    """
    processor, model = load_model()
    if processor is None or model is None:
        # Stub response so the demo UI still renders a table.
        return "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"
    try:
        import torch

        # The chat template references images by path/URL, so persist the
        # in-memory PIL image to a temp file first.
        image_path = _save_image_to_temp(image)
        try:
            # Conversation with the chart-extraction prompt.
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "url": image_path},  # Use file path
                        {"type": "text", "text": "Extract the data from this chart as CSV format. Return only the CSV data without explanation."},
                    ],
                }
            ]
            # Apply chat template and tokenize in one step.
            inputs = processor.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            # Move tensors to the model's actual device rather than
            # re-deriving it from cuda availability; guard .to() since the
            # processor output dict may contain non-tensor entries.
            device = model.device
            inputs = {
                k: (v.to(device) if hasattr(v, "to") else v)
                for k, v in inputs.items()
            }
            # Generate CSV
            with torch.inference_mode():
                outputs = model.generate(**inputs, max_new_tokens=1024)
            # Decode only the newly generated tokens. Decoding the full
            # sequence and splitting on "<|assistant|>" is unreliable:
            # skip_special_tokens=True strips that marker, in which case the
            # old fallback returned prompt + answer.
            prompt_len = inputs["input_ids"].shape[1]
            csv_text = processor.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            ).strip()
            # Defensive: strip any residual chat-template marker.
            if "<|assistant|>" in csv_text:
                csv_text = csv_text.split("<|assistant|>")[-1].strip()
            return csv_text
        finally:
            # Clean up temp file (best-effort; a stale temp file is harmless).
            if os.path.exists(image_path):
                try:
                    os.unlink(image_path)
                except Exception:
                    pass
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"❌ Error: {str(e)}"