# Last commit by ravi-vc: 780b81e "Update app.py" (verified)
import gradio as gr
import torch
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from PIL import Image, ImageEnhance
import pandas as pd
import io
import re
# Configuration
MODEL_NAME = "google/deplot"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor once at import time so every request reuses them.
print(f"Loading DePlot model on {DEVICE}...")
try:
    processor = Pix2StructProcessor.from_pretrained(MODEL_NAME)
    model = Pix2StructForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        # float16 on GPU halves weight memory; CPU stays in float32.
        # NOTE: inputs must be cast to model.dtype before generate() on GPU.
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    ).to(DEVICE)
    print("✅ DePlot model loaded successfully!")
except Exception as e:
    # Log a readable message, then re-raise: the app cannot run without the model.
    print(f"❌ Failed to load model: {e}")
    raise
def preprocess_image(image, min_width=512):
    """Enhance a chart image for better data extraction.

    Steps: flatten any transparency onto a white background, convert to RGB,
    boost contrast (x1.2) and sharpness (x1.3), and upscale images narrower
    than ``min_width`` while preserving the aspect ratio. The input image is
    not modified; a new image is returned.

    Args:
        image: A PIL image, or None.
        min_width: Minimum output width in pixels (default 512).

    Returns:
        The processed RGB PIL image, or None if ``image`` is None.
    """
    if image is None:
        return None

    # A plain convert('RGB') discards the alpha channel and exposes whatever
    # colour the transparent pixels happen to store. Composite onto white
    # instead, which is what a chart on a transparent canvas is meant to look like.
    has_alpha = image.mode in ('RGBA', 'LA', 'PA') or (
        image.mode == 'P' and 'transparency' in image.info
    )
    if has_alpha:
        rgba = image.convert('RGBA')
        background = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, rgba).convert('RGB')
    elif image.mode != 'RGB':
        image = image.convert('RGB')

    # Enhance contrast and sharpness.
    image = ImageEnhance.Contrast(image).enhance(1.2)
    image = ImageEnhance.Sharpness(image).enhance(1.3)

    # Upscale narrow images; the width > 0 guard avoids dividing by zero.
    if 0 < image.width < min_width:
        new_height = max(1, int(min_width * image.height / image.width))
        image = image.resize((min_width, new_height), Image.LANCZOS)
    return image
# DePlot writes its linearised table with the literal token "<0x0A>" (byte 0x0A,
# i.e. a newline) between rows rather than a real "\n".
_DEPLOT_NEWLINE_TOKEN = "<0x0A>"


def _split_pipe_row(line):
    """Split one pipe-delimited row into stripped cells.

    Only a leading/trailing empty cell (from markdown-style ``| a | b |``
    borders) is dropped; interior empty cells are kept so that every value
    stays aligned with its column.
    """
    cells = [cell.strip() for cell in line.split('|')]
    if cells and cells[0] == "":
        cells = cells[1:]
    if cells and cells[-1] == "":
        cells = cells[:-1]
    return cells


def _is_separator_row(line):
    """Return True for markdown rule rows such as ``|---|:---:|``."""
    return "-" in line and set(line) <= set("|:- \t")


def _parse_pipe_table(lines):
    """Build a DataFrame from the pipe-delimited lines in ``lines``.

    Returns ``(df, title)``. ``df`` is None when there is no header row
    followed by at least one data row. ``title`` is the text of a leading
    ``TITLE | ...`` row (DePlot's title line), or "".
    """
    rows = [
        _split_pipe_row(line)
        for line in lines
        if '|' in line and not _is_separator_row(line)
    ]
    rows = [row for row in rows if row]  # drop lines that were only pipes

    title = ""
    # DePlot's first row is "TITLE | <chart title>". It is metadata, not the
    # column header, so it is pulled out before choosing the header row.
    if rows and rows[0][0] == "TITLE" and len(rows[0]) <= 2:
        title = rows[0][1] if len(rows[0]) == 2 else ""
        rows = rows[1:]

    if len(rows) < 2:
        return None, title

    headers, data = rows[0], rows[1:]
    width = max(len(row) for row in rows)
    # Name any extra columns that rows have beyond the header.
    headers = headers + [f"Column_{i + 1}" for i in range(len(headers), width)]
    # Right-pad short rows so the frame is rectangular (inputs left untouched).
    data = [row + [""] * (width - len(row)) for row in data]
    return pd.DataFrame(data, columns=headers), title


def parse_deplot_output(raw_text):
    """Parse DePlot output into a structured table.

    Tries, in order:
      1. a pipe-separated table (DePlot's native format, rows separated by
         newlines or the "<0x0A>" token, optional "TITLE | ..." first row);
      2. "key: value" lines, giving a two-column Category/Value table;
      3. a single-cell fallback holding the (newline-normalised) text.

    Args:
        raw_text: Decoded model output (may be None or empty).

    Returns:
        ``(df, status_message)``. ``df`` is None for empty input or when
        table parsing raises.
    """
    if not raw_text or not raw_text.strip():
        return None, "No data extracted from the image."

    text = raw_text.replace(_DEPLOT_NEWLINE_TOKEN, "\n")
    lines = [line.strip() for line in text.split('\n') if line.strip()]
    if not lines:
        return None, "Empty output from model."

    # 1. Pipe-separated table.
    if '|' in text:
        try:
            df, title = _parse_pipe_table(lines)
        except (ValueError, TypeError) as e:
            return None, f"Error parsing table format: {str(e)}"
        if df is not None:
            msg = (
                f"✅ Successfully extracted table with {len(df)} rows "
                f"and {len(df.columns)} columns."
            )
            if title:
                msg += f" Chart title: {title}"
            return df, msg

    # 2. "Category: Value" lines. A list (not a dict) keeps repeated keys
    # instead of silently overwriting earlier values.
    pairs = []
    for line in lines:
        key, sep, value = line.partition(':')
        if sep:
            pairs.append((key.strip(), value.strip()))
    if pairs:
        df = pd.DataFrame(pairs, columns=['Category', 'Value'])
        return df, f"✅ Extracted {len(df)} key-value pairs."

    # 3. Nothing structured: show the text in a single-column table.
    df = pd.DataFrame([text.strip()], columns=['Extracted_Text'])
    return df, "⚠️ Could not parse into structured format. Showing raw extracted text."
def extract_table_from_chart(image, prompt_type="default"):
    """Extract the data table from a chart image using DePlot.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), or None.
        prompt_type: One of "default", "detailed", "summary"; unknown values
            fall back to "default".

    Returns:
        ``(df, status_message, raw_text)`` where ``df`` is the parsed table
        (or None), ``status_message`` is shown to the user, and ``raw_text``
        is the model output with the prompt text removed ("" on error).
    """
    if image is None:
        return None, "Please upload an image.", ""

    # NOTE(review): "default" matches the prompt used in the DePlot model card;
    # the other two wordings are custom and may be less reliable. Confirm.
    prompts = {
        "default": "Generate underlying data table of the figure below:",
        "detailed": "Extract all data points and create a comprehensive table from this chart:",
        "summary": "Summarize the key data from this chart in table format:",
    }
    prompt = prompts.get(prompt_type, prompts["default"])

    try:
        processed_image = preprocess_image(image)
        encoded = processor(images=processed_image, text=prompt, return_tensors="pt")
        # The processor yields float32 tensors. When the model was loaded in
        # float16 (CUDA), floating-point inputs must match model.dtype or the
        # forward pass fails with a dtype mismatch. Integer tensors only move device.
        inputs = {
            name: (
                tensor.to(DEVICE, dtype=model.dtype)
                if tensor.is_floating_point()
                else tensor.to(DEVICE)
            )
            for name, tensor in encoded.items()
        }

        with torch.no_grad():
            # Deterministic beam search. `temperature` is omitted because it
            # only applies when sampling (do_sample=True).
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
                num_beams=4,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                early_stopping=True,
            )

        generated_text = processor.decode(generated_ids[0], skip_special_tokens=True)
        # Remove the prompt text in case it appears in the decoded output.
        clean_text = generated_text.replace(prompt, "").strip()

        df, status_msg = parse_deplot_output(clean_text)
        return df, status_msg, clean_text
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        return None, f"❌ Error: {str(e)}", ""
    finally:
        # Release cached GPU memory after every request, whether it succeeded or failed.
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the chart data extractor.

    The left column holds the image upload, the extraction-mode radio and the
    extract button. The right column shows a status line, the parsed table and,
    in a collapsed accordion, the raw model output. Clicking the button and
    changing the image both run ``extract_table_from_chart`` with the currently
    selected extraction mode.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(
        title="DePlot: Chart Data Extraction",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        """,
    ) as interface:
        # Built from explicit lines so source indentation cannot turn the text
        # into a markdown code block.
        gr.Markdown(
            "# 📊 DePlot Chart Data Extractor\n\n"
            "Upload a chart image (bar chart, line chart, pie chart, etc.) "
            "and extract the underlying data table.\n\n"
            "**Supported formats:** PNG, JPG, JPEG, GIF, BMP"
        )

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="pil",
                    label="📁 Upload Chart Image",
                    height=400,
                )
                prompt_type = gr.Radio(
                    choices=["default", "detailed", "summary"],
                    value="default",
                    label="🎯 Extraction Mode",
                    info="Choose how detailed the extraction should be",
                )
                extract_btn = gr.Button("🚀 Extract Data Table", variant="primary", size="lg")

            with gr.Column(scale=2):
                status_output = gr.Textbox(
                    label="📋 Status",
                    interactive=False,
                    max_lines=2,
                )
                table_output = gr.Dataframe(
                    label="📊 Extracted Data Table",
                    interactive=False,
                    wrap=True,
                )
                with gr.Accordion("🔍 Raw Extracted Text", open=False):
                    raw_output = gr.Textbox(
                        label="Raw DePlot Output",
                        interactive=False,
                        max_lines=10,
                        show_copy_button=True,
                    )

        # Examples
        gr.Markdown("### 📸 Try these example charts:")
        gr.Examples(
            examples=[
                ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/deplot_demo.png"],
            ],
            inputs=image_input,
            label="Click to load example",
        )

        # Event handlers. Progress is shown by default, so `show_progress` is
        # left unset; its accepted type differs across Gradio versions.
        extract_btn.click(
            fn=extract_table_from_chart,
            inputs=[image_input, prompt_type],
            outputs=[table_output, status_output, raw_output],
        )
        # Auto-extract whenever the image changes (upload, example or clear),
        # honouring the selected mode. extract_table_from_chart already handles
        # a cleared (None) image.
        image_input.change(
            fn=extract_table_from_chart,
            inputs=[image_input, prompt_type],
            outputs=[table_output, status_output, raw_output],
        )
    return interface
# Launch the app
if __name__ == "__main__":
    interface = create_interface()
    # `show_tips` is not accepted by Blocks.launch() in Gradio 4+ (TypeError)
    # and was a no-op tip printer in 3.x, so it is omitted.
    interface.launch(
        server_name="0.0.0.0",  # listen on all interfaces (for Hugging Face Spaces)
        server_port=7860,  # standard port for HF Spaces
        share=False,  # don't create a public share link
        show_error=True,  # surface Python errors in the UI
    )