# ravi-vc's picture
# Update app.py
# c4bbe3f verified
# raw
# history blame
# 1.68 kB
import gradio as gr
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from PIL import Image
import torch
# Choose the compute device first: CUDA when torch reports one, CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and the model from the "google/deplot" checkpoint.
# `.to(device)` returns the module itself, so the model is placed on the
# chosen device in the same expression that loads it.
processor = Pix2StructProcessor.from_pretrained("google/deplot")
model = Pix2StructForConditionalGeneration.from_pretrained(
    "google/deplot"
).to(device)
# Earlier image-only variant (no header text passed to the processor):
# def extract_chart_data(img: Image.Image):
#     inputs = processor(images=img, return_tensors="pt").to(device)
#     outputs = model.generate(**inputs)
#     result = processor.decode(outputs[0], skip_special_tokens=True)
#     return result
# Prompt used when the caller supplies no instruction; this is the prompt
# shown in the google/deplot model card example.
_DEFAULT_HEADER_TEXT = "Generate underlying data table of the figure below:"

# Upper bound on generated tokens. Without an explicit limit, generate() falls
# back to the generation config's default length, which can cut a table short.
_MAX_NEW_TOKENS = 512


def extract_chart_data(img: Image.Image, header_text: str = "") -> str:
    """Run the DePlot model on a chart image and return the decoded text.

    Args:
        img: Chart image as a PIL image. May be ``None`` if the UI submits
            without an upload, in which case an error is raised.
        header_text: Instruction rendered as the header prompt for the model.
            Blank or missing text falls back to ``_DEFAULT_HEADER_TEXT``.

    Returns:
        The first generated sequence decoded to a string, with special tokens
        removed.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # gr.Error surfaces the message to the user instead of an opaque
        # failure from inside the processor.
        raise gr.Error("Please upload a chart image.")

    # Treat None / whitespace-only instructions as "not provided".
    prompt = (header_text or "").strip() or _DEFAULT_HEADER_TEXT
    print("Header text:", prompt)

    # Build the model inputs from image + prompt and move them next to the
    # model so both live on the same device.
    inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=_MAX_NEW_TOKENS)
    return processor.decode(outputs[0], skip_special_tokens=True)
# Gradio interface: the two input widgets are built first, in the same order
# as the parameters of extract_chart_data (image, then header text).
_chart_image = gr.Image(type="pil", label="Upload Chart Image")
_header_box = gr.Textbox(
    label="Instruction / Header Text",
    placeholder="Enter instructions like 'Convert this bar chart into a table of quarter and cost values'",
)

iface = gr.Interface(
    fn=extract_chart_data,
    inputs=[_chart_image, _header_box],
    outputs="text",
    title="DePlot Chart Data Extractor 2",
    description="Upload a chart image and get its extracted data/description.",
)

# Start the web app.
iface.launch()