File size: 3,911 Bytes
c37e95b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Chart2CSV extraction using Granite Vision
Extracts tabular data from chart images as CSV format
"""

import spaces
from PIL import Image
from typing import Optional
import tempfile
import os


# Process-wide cache: the processor and model are loaded at most once and then
# reused by every extraction call.
_processor = None
_model = None


def load_model():
    """
    Load the Chart2CSV processor and model on first use and cache them.

    Returns:
        A ``(processor, model)`` tuple. If both are already cached they are
        returned immediately. If ``transformers`` cannot be imported, or any
        other error occurs while loading, ``(None, None)`` is returned (the
        traceback is printed for the latter), and a later call attempts the
        load again because the cache check requires both entries.
    """
    global _processor, _model

    already_loaded = _processor is not None and _model is not None
    if already_loaded:
        return _processor, _model

    try:
        from transformers import AutoProcessor, AutoModelForVision2Seq
        import torch

        repo = "ibm-granite/granite-vision-3.3-2b-chart2csv-preview"

        _processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)

        # Half precision only when a GPU is present; CPU inference stays fp32.
        on_gpu = torch.cuda.is_available()
        target = "cuda" if on_gpu else "cpu"
        precision = torch.float16 if on_gpu else torch.float32

        loaded = AutoModelForVision2Seq.from_pretrained(
            repo,
            trust_remote_code=True,
            torch_dtype=precision,
        )
        _model = loaded.to(target)

        print(f"✅ Loaded {repo} on {target}")
        return _processor, _model

    except ImportError:
        print("⚠️ Transformers not available, using stub")
        return None, None
    except Exception as err:
        print(f"⚠️ Model load error: {err}")
        import traceback
        traceback.print_exc()
        return None, None


def _save_image_to_temp(image: Image.Image) -> str:
    """
    Write *image* to a new temporary PNG file and return its path.

    The file is created with ``delete=False``, so the caller owns it and is
    responsible for removing it.

    The ``NamedTemporaryFile`` handle is closed before saving. This avoids
    leaking the open descriptor. It also lets ``image.save`` reopen the path
    on platforms (e.g. Windows) that forbid a second open of a still-open
    temporary file.

    Raises:
        Whatever ``image.save`` raises. The empty temporary file is removed
        first so a failed save does not leave it behind.
    """
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        path = temp_file.name

    try:
        image.save(path)
    except Exception:
        try:
            os.unlink(path)
        except OSError:
            pass  # best-effort cleanup; surface the original save error
        raise

    return path


@spaces.GPU
def extract_csv(image: Image.Image) -> str:
    """
    Extract CSV from a chart image using the Granite Vision Chart2CSV model.

    The image is written to a temporary PNG that the chat template references
    by path. The file is deleted once generation finishes.

    Args:
        image: PIL Image of chart/table

    Returns:
        The CSV text generated by the model, with surrounding whitespace
        stripped. If the model could not be loaded, a fixed placeholder CSV is
        returned. If extraction fails, the traceback is printed and a string
        starting with "❌ Error:" is returned instead of raising.
    """
    processor, model = load_model()

    if processor is None or model is None:
        # Stub response so callers still get CSV-shaped output without a model.
        return "col1,col2,col3\nvalue1,value2,value3\nvalue4,value5,value6"

    try:
        import torch

        # Save image to temp file
        image_path = _save_image_to_temp(image)

        try:
            # Prepare conversation with chart extraction prompt
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "url": image_path},  # Use file path
                        {"type": "text", "text": "Extract the data from this chart as CSV format. Return only the CSV data without explanation."},
                    ],
                }
            ]

            # Apply chat template and process
            inputs = processor.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            )

            # Send inputs to wherever load_model() actually placed the model,
            # instead of re-deriving the device. Non-tensor entries pass through.
            device = model.device
            inputs = {
                key: (value.to(device) if hasattr(value, "to") else value)
                for key, value in inputs.items()
            }

            # Generate CSV
            with torch.inference_mode():
                outputs = model.generate(**inputs, max_new_tokens=1024)

            # generate() returns the prompt followed by the completion. Decode
            # only the newly generated tokens, so the prompt never leaks into
            # the CSV. The old approach split on "<|assistant|>", a marker that
            # skip_special_tokens=True may already have removed.
            prompt_length = inputs["input_ids"].shape[-1]
            generated_ids = outputs[0][prompt_length:]
            csv_text = processor.decode(generated_ids, skip_special_tokens=True)

            # Defensive: if a role marker still survives, keep only what follows.
            if "<|assistant|>" in csv_text:
                csv_text = csv_text.split("<|assistant|>")[-1]

            return csv_text.strip()

        finally:
            # Best-effort removal of the temp image; a missing file is fine.
            try:
                os.unlink(image_path)
            except OSError:
                pass

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"❌ Error: {str(e)}"