File size: 1,675 Bytes
dfbe417
 
 
 
 
 
 
 
 
 
 
 
 
1714f2b
 
 
 
 
 
59bf645
1714f2b
59bf645
81dac08
1714f2b
 
dfbe417
c4bbe3f
dfbe417
 
 
1714f2b
dfbe417
 
 
59bf645
 
 
 
 
 
 
 
dfbe417
c4bbe3f
dfbe417
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import logging

import gradio as gr
import torch
from PIL import Image
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

# Hugging Face checkpoint shared by the processor and the model so the two
# can never drift apart.
_DEPLOT_CHECKPOINT = "google/deplot"

# Load the pre/post-processing pipeline and the generation model.
processor = Pix2StructProcessor.from_pretrained(_DEPLOT_CHECKPOINT)
model = Pix2StructForConditionalGeneration.from_pretrained(_DEPLOT_CHECKPOINT)

# Prefer the CUDA device when torch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model.to(device)

# Prompt used by the DePlot model card; substituted when the user leaves the
# instruction box blank, because an empty header gives the VQA-mode model
# nothing to condition on.
_DEFAULT_DEPLOT_PROMPT = "Generate underlying data table of the figure below:"


def extract_chart_data(
    img: Image.Image,
    header_text: str,
    *,
    max_new_tokens: int = 512,
) -> str:
    """Run DePlot on a chart image and return the decoded text output.

    Args:
        img: Chart image from the ``gr.Image(type="pil")`` input. Gradio
            passes ``None`` when the user submits without an image.
        header_text: Instruction passed to the processor as ``text``. If it
            is ``None``, empty, or whitespace-only, the standard DePlot prompt
            is used instead.
        max_new_tokens: Upper bound on generated tokens. Without an explicit
            limit ``generate()`` uses the checkpoint's generation config,
            which may cut long tables short; 512 matches the model card.

    Returns:
        The first generated sequence decoded with special tokens removed.

    Raises:
        gr.Error: If no image was supplied (shown to the user by Gradio).
    """
    if img is None:
        raise gr.Error("Please upload a chart image.")

    # Keep the user's text verbatim; only fall back when it carries no content.
    prompt = (
        header_text
        if header_text and header_text.strip()
        else _DEFAULT_DEPLOT_PROMPT
    )
    logging.getLogger(__name__).debug("Header text: %s", prompt)

    inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.decode(outputs[0], skip_special_tokens=True)


# Gradio UI: a chart image plus a free-text instruction in, plain text out.
iface = gr.Interface(
    fn=extract_chart_data,
    inputs=[
        gr.Image(type="pil", label="Upload Chart Image"),  # Image input
        gr.Textbox(
            label="Instruction / Header Text",
            placeholder="Enter instructions like 'Convert this bar chart into a table of quarter and cost values'"
        ),  # Text input
    ],
    outputs="text",
    title="DePlot Chart Data Extractor 2",
    description="Upload a chart image and get its extracted data/description."
)

# Start the server only when run as a script, so that importing this module
# (e.g. from tests or another app) does not launch a web server as a side
# effect.
if __name__ == "__main__":
    iface.launch()