"""
Chart2CSV extraction using Granite Vision.

Extracts tabular data from chart images as CSV text.
"""
import spaces
from PIL import Image
from typing import Optional
import tempfile
import os

# Global model cache — populated lazily on first call to load_model()
_processor = None
_model = None


def load_model():
    """Lazy-load the Chart2CSV model and processor.

    Returns:
        Tuple ``(processor, model)``. Both are ``None`` when transformers
        is unavailable or the model fails to load, so callers can fall
        back to a stub response.
    """
    global _processor, _model
    # Fast path: both already cached from a previous call
    if _processor is not None and _model is not None:
        return _processor, _model
    try:
        from transformers import AutoProcessor, AutoModelForVision2Seq
        import torch

        model_id = "ibm-granite/granite-vision-3.3-2b-chart2csv-preview"
        _processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            trust_remote_code=True,
            # fp16 only on GPU; CPU inference stays fp32 for correctness/speed
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(device)
        print(f"✅ Loaded {model_id} on {device}")
        return _processor, _model
    except ImportError:
        # transformers not installed — caller falls back to stub output
        print("⚠️ Transformers not available, using stub")
        return None, None
    except Exception as e:
        print(f"⚠️ Model load error: {e}")
        import traceback
        traceback.print_exc()
        return None, None


def _save_image_to_temp(image: Image.Image) -> str:
    """Save a PIL image to a temporary PNG file and return its path.

    The handle is closed before returning so the path can be reopened by
    other code (required on Windows, where an open NamedTemporaryFile is
    locked). The caller is responsible for deleting the file.
    """
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        # Write through the open handle (the original saved via the path
        # while the handle was still open — a handle leak, and a failure
        # on Windows where the open file cannot be reopened by name).
        image.save(temp_file, format="PNG")
    finally:
        temp_file.close()
    return temp_file.name


@spaces.GPU
def extract_csv(image: Image.Image) -> str:
    """
    Extract CSV from a chart image using Granite Vision Chart2CSV model. 

    Args:
        image: PIL Image of chart/table

    Returns:
        CSV text, a stub CSV when the model is unavailable, or an
        "❌ Error: ..." string on inference failure.
    """
    processor, model = load_model()
    if processor is None or model is None:
        # Stub response
        return "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"
    try:
        import torch

        # Save image to temp file (processor consumes an image URL/path)
        image_path = _save_image_to_temp(image)
        try:
            # Prepare conversation with chart extraction prompt
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "url": image_path},  # Use file path
                        {"type": "text", "text": "Extract the data from this chart as CSV format. Return only the CSV data without explanation."},
                    ],
                }
            ]

            # Apply chat template and tokenize in one pass
            inputs = processor.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            )

            # Move tensors to the model's own device rather than re-deriving
            # it from cuda.is_available() (keeps inputs consistent with where
            # the model actually lives). Guard non-tensor entries, which the
            # processor may include alongside tensors.
            device = model.device
            inputs = {
                k: (v.to(device) if hasattr(v, "to") else v)
                for k, v in inputs.items()
            }

            # Generate CSV
            with torch.inference_mode():
                outputs = model.generate(**inputs, max_new_tokens=1024)

            # Decode response
            full_response = processor.decode(outputs[0], skip_special_tokens=True)

            # Extract just the CSV data (strip everything up to the final
            # assistant turn of the chat template, if present)
            if "<|assistant|>" in full_response:
                csv_text = full_response.split("<|assistant|>")[-1].strip()
            else:
                csv_text = full_response

            return csv_text
        finally:
            # Clean up temp file; best-effort — never mask the real result
            if os.path.exists(image_path):
                try:
                    os.unlink(image_path)
                except Exception:
                    pass
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"❌ Error: {str(e)}"