File size: 1,675 Bytes
dfbe417 1714f2b 59bf645 1714f2b 59bf645 81dac08 1714f2b dfbe417 c4bbe3f dfbe417 1714f2b dfbe417 59bf645 dfbe417 c4bbe3f dfbe417 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | import gradio as gr
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from PIL import Image
import torch
# Select the compute device: CUDA when torch reports it available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the DePlot processor and model from the "google/deplot" checkpoint,
# then move the model onto the selected device.
processor = Pix2StructProcessor.from_pretrained("google/deplot")
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
model.to(device)
# def extract_chart_data(img: Image.Image):
# inputs = processor(images=img, return_tensors="pt").to(device)
# outputs = model.generate(**inputs)
# result = processor.decode(outputs[0], skip_special_tokens=True)
# return result
# Prompt used when the caller supplies no instruction text.
_DEFAULT_HEADER_TEXT = "Generate underlying data table of the figure below:"

# Upper bound on generated tokens. Matches the value used in the Transformers
# DePlot usage example; without an explicit limit, generate() falls back to the
# configured default length, which can cut a data table off mid-row.
_MAX_NEW_TOKENS = 512


def extract_chart_data(
    img: Image.Image, header_text: str = _DEFAULT_HEADER_TEXT
) -> str:
    """Run DePlot on a chart image and return the decoded text.

    Args:
        img: The chart image to analyse (PIL image, as delivered by the
            ``gr.Image(type="pil")`` input).
        header_text: Instruction rendered alongside the image for the model.
            When empty or whitespace-only, ``_DEFAULT_HEADER_TEXT`` is used.

    Returns:
        The first generated sequence decoded with special tokens skipped.

    Raises:
        gr.Error: If ``img`` is ``None`` (presumably what Gradio passes when
            the form is submitted without an uploaded image), so the UI shows
            a readable message instead of a processor traceback.
    """
    if img is None:
        raise gr.Error("Please upload a chart image before submitting.")

    # The processor is called with a text header on every request, so fall
    # back to the default prompt rather than sending a blank instruction.
    if not header_text or not header_text.strip():
        header_text = _DEFAULT_HEADER_TEXT
    print("Header text:", header_text)

    inputs = processor(images=img, text=header_text, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=_MAX_NEW_TOKENS)
    return processor.decode(outputs[0], skip_special_tokens=True)
# Gradio UI: an image upload plus an instruction textbox are passed to
# extract_chart_data, and the string it returns is shown as plain text.
chart_image_input = gr.Image(type="pil", label="Upload Chart Image")
header_text_input = gr.Textbox(
    label="Instruction / Header Text",
    placeholder="Enter instructions like 'Convert this bar chart into a table of quarter and cost values'",
)

iface = gr.Interface(
    fn=extract_chart_data,
    inputs=[chart_image_input, header_text_input],
    outputs="text",
    title="DePlot Chart Data Extractor 2",
    description="Upload a chart image and get its extracted data/description.",
)

# Start the web app.
iface.launch()
|