| | """ |
| | Chart2CSV extraction using Granite Vision |
| | Extracts tabular data from chart images as CSV format |
| | """ |
| |
|
| | import spaces |
| | from PIL import Image |
| | from typing import Optional |
| | import tempfile |
| | import os |
| |
|
| |
|
| | |
# Module-level cache for the lazily loaded Chart2CSV pipeline.
# Both stay None until load_model() succeeds once; after that the same
# processor/model pair is reused so the checkpoint is only fetched once.
_processor = None
_model = None
| |
|
| |
|
def load_model():
    """Return the cached (processor, model) pair, loading it on first use.

    Falls back to (None, None) when transformers is missing or loading
    fails, so callers can degrade to stub output instead of crashing.
    """
    global _processor, _model

    # Fast path: both halves were already loaded by an earlier call.
    if _model is not None and _processor is not None:
        return _processor, _model

    try:
        import torch
        from transformers import AutoModelForVision2Seq, AutoProcessor

        model_id = "ibm-granite/granite-vision-3.3-2b-chart2csv-preview"

        _processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        # Half precision only on GPU; CPU inference stays in float32.
        _model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        ).to(device)

        print(f"✅ Loaded {model_id} on {device}")
        return _processor, _model

    except ImportError:
        # transformers (or torch) not installed — caller will use the stub.
        print("⚠️ Transformers not available, using stub")
        return None, None
    except Exception as e:
        # Any other load failure (download, OOM, …): log and degrade.
        print(f"⚠️ Model load error: {e}")
        import traceback
        traceback.print_exc()
        return None, None
| |
|
| |
|
| | def _save_image_to_temp(image: Image.Image) -> str: |
| | """Save PIL image to temporary file and return path.""" |
| | temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) |
| | image.save(temp_file.name) |
| | return temp_file.name |
| |
|
| |
|
@spaces.GPU
def extract_csv(image: Image.Image) -> str:
    """
    Extract CSV from a chart image using Granite Vision Chart2CSV model.

    Args:
        image: PIL Image of chart/table

    Returns:
        CSV text (a stub table when the model is unavailable, or an
        error string if generation fails)
    """
    processor, model = load_model()

    # Model unavailable (e.g. transformers missing) — return stub data.
    if processor is None or model is None:
        return "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"

    try:
        import torch

        image_path = _save_image_to_temp(image)
        try:
            prompt = "Extract the data from this chart as CSV format. Return only the CSV data without explanation."
            conversation = [{
                "role": "user",
                "content": [
                    {"type": "image", "url": image_path},
                    {"type": "text", "text": prompt},
                ],
            }]

            model_inputs = processor.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )

            # Move every input tensor onto the same device as the model.
            target = "cuda" if torch.cuda.is_available() else "cpu"
            model_inputs = {name: tensor.to(target) for name, tensor in model_inputs.items()}

            with torch.inference_mode():
                generated = model.generate(**model_inputs, max_new_tokens=1024)

            decoded = processor.decode(generated[0], skip_special_tokens=True)

            # Keep only the assistant turn when the chat-template marker
            # survives decoding; otherwise return the raw decode.
            marker = "<|assistant|>"
            if marker in decoded:
                return decoded.split(marker)[-1].strip()
            return decoded

        finally:
            # Best-effort cleanup of the temporary image file.
            if os.path.exists(image_path):
                try:
                    os.unlink(image_path)
                except Exception:
                    pass

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"❌ Error: {str(e)}"
| |
|