ravi-vc's picture
Update app.py
5af7ed0 verified
raw
history blame
1.25 kB
import gradio as gr
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from PIL import Image
import torch
# Choose the compute device first: CUDA when a GPU is visible, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the DePlot processor and model a single time at import, so every request
# handled by the app reuses the same weights. The model is placed on `device`.
processor = Pix2StructProcessor.from_pretrained("google/deplot")
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot").to(device)
# def extract_chart_data(img: Image.Image):
# inputs = processor(images=img, return_tensors="pt").to(device)
# outputs = model.generate(**inputs)
# result = processor.decode(outputs[0], skip_special_tokens=True)
# return result
def extract_chart_data(img: Image.Image):
# Provide a header prompt for VQA
header_text = "Extract chart data"
inputs = processor(images=img, text=header_text, return_tensors="pt").to(device)
outputs = model.generate(**inputs)
result = processor.decode(outputs[0], skip_special_tokens=True)
return result
# Gradio UI: one chart image in, extracted table text out. The interface object
# is built at import time, but the server starts only when this file is run as a
# script, so importing the module has no network side effects.
iface = gr.Interface(
    fn=extract_chart_data,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="DePlot Chart Data Extractor 2",
    description="Upload a chart image and get its extracted data/description.",
)

if __name__ == "__main__":
    iface.launch()