"""Gradio app that extracts the underlying data table from a chart image using Google's DePlot model."""

import logging
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

logger = logging.getLogger(__name__)

MODEL_NAME = "google/deplot"
# Prompt DePlot was trained with; used when the user leaves the instruction box empty.
DEFAULT_HEADER_TEXT = "Generate underlying data table of the figure below:"
# transformers' default generation length (max_length=20) truncates tables almost
# immediately, so give generation an explicit, generous budget.
MAX_NEW_TOKENS = 512
# DePlot emits this literal token as its table-row separator.
DEPLOT_ROW_SEPARATOR = "<0x0A>"

# Load model & processor once at import time so every request reuses them.
processor = Pix2StructProcessor.from_pretrained(MODEL_NAME)
model = Pix2StructForConditionalGeneration.from_pretrained(MODEL_NAME)

# Move model to GPU if available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def extract_chart_data(img: Optional[Image.Image], header_text: Optional[str]) -> str:
    """Run DePlot on a chart image and return its linearized data table.

    Args:
        img: Chart image from the Gradio image input; ``None`` when nothing was uploaded.
        header_text: Instruction rendered as the VQA header. Blank or ``None`` falls
            back to ``DEFAULT_HEADER_TEXT``.

    Returns:
        The decoded model output, one table row per line.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
    """
    if img is None:
        raise gr.Error("Please upload a chart image.")

    # DePlot is a VQA-style Pix2Struct model: the processor requires a header,
    # and an empty one gives the model no task description.
    header = (header_text or "").strip() or DEFAULT_HEADER_TEXT
    logger.debug("Header text: %s", header)

    inputs = processor(images=img, text=header, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    result = processor.decode(outputs[0], skip_special_tokens=True)

    # Turn DePlot's "<0x0A>" row separators into real newlines for readable output.
    return "\n".join(row.strip() for row in result.split(DEPLOT_ROW_SEPARATOR))


# Gradio interface
iface = gr.Interface(
    fn=extract_chart_data,
    inputs=[
        gr.Image(type="pil", label="Upload Chart Image"),  # Image input
        gr.Textbox(
            label="Instruction / Header Text",
            placeholder="Enter instructions like 'Convert this bar chart into a table of quarter and cost values'",
        ),  # Text input
    ],
    outputs="text",
    title="DePlot Chart Data Extractor 2",
    description="Upload a chart image and get its extracted data/description.",
)

if __name__ == "__main__":
    iface.launch()