import io
import re
import traceback

import gradio as gr
import pandas as pd
import torch
from PIL import Image, ImageEnhance
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
|
|
# --- Configuration ---
MODEL_NAME = "google/deplot"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# --- Model loading (runs once, at import time) ---
print(f"Loading DePlot model on {DEVICE}...")
try:
    processor = Pix2StructProcessor.from_pretrained(MODEL_NAME)
    model = Pix2StructForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        # Half precision only on GPU; stay in float32 on CPU.
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    ).to(DEVICE)
    print("✅ DePlot model loaded successfully!")
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    # The app is useless without the model, so fail fast with the original error.
    raise
|
|
def preprocess_image(image):
    """Normalize and lightly enhance a chart image before DePlot inference.

    Steps:
      1. Flatten to RGB. Images with transparency are composited onto a white
         background instead of having their alpha channel simply dropped.
      2. Boost contrast (x1.2) and sharpness (x1.3).
      3. Upscale images narrower than 512 px to 512 px wide, keeping the
         aspect ratio.

    Args:
        image: A PIL image, or None.

    Returns:
        The processed RGB PIL image, or None when ``image`` is None.
    """
    if image is None:
        return None

    has_alpha = image.mode in ("RGBA", "LA", "PA") or "transparency" in image.info
    if has_alpha:
        # A plain convert("RGB") discards alpha and keeps whatever colour the
        # transparent pixels happen to store (frequently black), which can make
        # dark axis labels and text disappear. Composite onto white instead.
        rgba = image.convert("RGBA")
        flattened = Image.new("RGB", rgba.size, (255, 255, 255))
        flattened.paste(rgba, mask=rgba.getchannel("A"))
        image = flattened
    elif image.mode != "RGB":
        image = image.convert("RGB")

    image = ImageEnhance.Contrast(image).enhance(1.2)
    image = ImageEnhance.Sharpness(image).enhance(1.3)

    # Very small charts are upscaled so text stays legible to the model.
    if image.width < 512:
        new_width = 512
        new_height = int(512 * image.height / image.width)
        image = image.resize((new_width, new_height), Image.LANCZOS)

    return image
|
|
# A markdown-style table separator cell such as "---", ":---", "---:" or ":---:".
_SEPARATOR_CELL = re.compile(r":?-{3,}:?")


def _split_pipe_row(line):
    """Split one pipe-delimited row into stripped cells.

    Border pipes (``| a | b |``) are ignored. Empty interior cells are kept so
    that later columns do not shift left.
    """
    row = line.strip()
    if row.startswith("|"):
        row = row[1:]
    if row.endswith("|"):
        row = row[:-1]
    return [cell.strip() for cell in row.split("|")]


def _pipe_rows_to_frame(lines):
    """Build a DataFrame from pipe-delimited lines, or return None if there is no data row.

    Lines without a pipe are ignored. Markdown separator rows and a leading
    DePlot ``TITLE | ...`` row are skipped. The first remaining row is the
    header. Short rows are padded with "", and blank or missing header names
    become ``Column_<n>`` (1-based).
    """
    rows = [_split_pipe_row(line) for line in lines if "|" in line]
    rows = [
        row for row in rows
        if not all(_SEPARATOR_CELL.fullmatch(cell) for cell in row)
    ]
    # DePlot prefixes its table with a "TITLE | <chart title>" row; it is not a header.
    if rows and rows[0][0] == "TITLE":
        rows = rows[1:]
    if len(rows) < 2:
        return None

    width = max(len(row) for row in rows)
    header, body = rows[0], rows[1:]
    columns = [name or f"Column_{i + 1}" for i, name in enumerate(header)]
    columns += [f"Column_{i + 1}" for i in range(len(columns), width)]
    data = [row + [""] * (width - len(row)) for row in body]
    return pd.DataFrame(data, columns=columns)


def parse_deplot_output(raw_text):
    """Parse DePlot's text output into a DataFrame plus a status message.

    DePlot emits a pipe-delimited table whose rows are separated by the literal
    token ``<0x0A>`` rather than a real newline. That token is normalized to
    "\\n" first. Three interpretations are then tried, in order:

      1. Pipe-delimited table: the first row is the header.
      2. ``key: value`` lines: two columns, ``Category`` and ``Value``.
      3. Fallback: one ``Extracted_Text`` cell holding the raw text.

    Args:
        raw_text: Decoded model output; may be None or empty.

    Returns:
        ``(df, status_message)``. ``df`` is None when nothing was extracted or
        when table parsing raised.
    """
    if not raw_text or not raw_text.strip():
        return None, "No data extracted from the image."

    text = raw_text.replace("<0x0A>", "\n")
    lines = [line.strip() for line in text.split("\n") if line.strip()]
    if not lines:
        return None, "Empty output from model."

    # 1) Pipe-delimited table.
    if "|" in text:
        try:
            table = _pipe_rows_to_frame(lines)
        except Exception as e:
            return None, f"Error parsing table format: {str(e)}"
        if table is not None:
            return table, (
                f"✅ Successfully extracted table with {len(table)} rows "
                f"and {len(table.columns)} columns."
            )

    # 2) "key: value" lines. Duplicate keys are kept as separate rows.
    pairs = []
    for line in lines:
        key, sep, value = line.partition(":")
        if sep:
            pairs.append((key.strip(), value.strip()))
    if pairs:
        df = pd.DataFrame(pairs, columns=["Category", "Value"])
        return df, f"✅ Extracted {len(df)} key-value pairs."

    # 3) Nothing structured: show the raw text in a single cell.
    df = pd.DataFrame([raw_text], columns=["Extracted_Text"])
    return df, "⚠️ Could not parse into structured format. Showing raw extracted text."
|
|
# Text prompts passed to the processor alongside the image. "default" is the
# prompt given in the google/deplot model card; unknown modes fall back to it.
_PROMPTS = {
    "default": "Generate underlying data table of the figure below:",
    "detailed": "Extract all data points and create a comprehensive table from this chart:",
    "summary": "Summarize the key data from this chart in table format:",
}


def extract_table_from_chart(image, prompt_type="default"):
    """Extract a data table from a chart image using DePlot.

    Args:
        image: A PIL image, or None.
        prompt_type: One of "default", "detailed" or "summary". Any other
            value uses the "default" prompt.

    Returns:
        ``(df, status_message, raw_text)``. On failure ``df`` is None,
        ``status_message`` describes the error, and ``raw_text`` is "".
    """
    if image is None:
        return None, "Please upload an image.", ""

    try:
        processed_image = preprocess_image(image)
        prompt = _PROMPTS.get(prompt_type, _PROMPTS["default"])

        encoded = processor(images=processed_image, text=prompt, return_tensors="pt")
        # The model may be in float16 (CUDA) while the processor emits float32
        # tensors; cast floating inputs to the model's dtype so they match the
        # weights, and only move integer tensors (ids/masks) to the device.
        inputs = {
            name: (
                tensor.to(DEVICE, dtype=model.dtype)
                if torch.is_floating_point(tensor)
                else tensor.to(DEVICE)
            )
            for name, tensor in encoded.items()
        }

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,  # deterministic beam search; temperature is unused
                num_beams=4,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                early_stopping=True,
            )

        generated_text = processor.decode(generated_ids[0], skip_special_tokens=True)
        # Defensive: drop the prompt if it is echoed back in the decoded text.
        clean_text = generated_text.replace(prompt, "").strip()

        df, status_msg = parse_deplot_output(clean_text)
        return df, status_msg, clean_text

    except Exception as e:
        # Show a short message in the UI, but keep the full traceback in the
        # server log so failures can be diagnosed.
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}", ""

    finally:
        # Release cached GPU memory after every request, success or failure.
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
|
|
# --- Gradio UI ---
def create_interface():
    """Build the Gradio Blocks app for chart data extraction.

    Extraction runs when the button is clicked and automatically whenever the
    image changes (including when it is cleared). Both paths call
    ``extract_table_from_chart`` with the currently selected extraction mode.

    Returns:
        The un-launched ``gr.Blocks`` instance.
    """
    with gr.Blocks(
        title="DePlot: Chart Data Extraction",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        """,
    ) as interface:
        gr.Markdown(
            """
            # 📊 DePlot Chart Data Extractor

            Upload a chart image (bar chart, line chart, pie chart, etc.) and extract the underlying data table.

            **Supported formats:** PNG, JPG, JPEG, GIF, BMP
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="pil",
                    label="📁 Upload Chart Image",
                    height=400,
                )
                prompt_type = gr.Radio(
                    choices=["default", "detailed", "summary"],
                    value="default",
                    label="🎯 Extraction Mode",
                    info="Choose how detailed the extraction should be",
                )
                extract_btn = gr.Button(
                    "🔍 Extract Data Table", variant="primary", size="lg"
                )

            with gr.Column(scale=2):
                status_output = gr.Textbox(
                    label="📌 Status",
                    interactive=False,
                    max_lines=2,
                )
                table_output = gr.Dataframe(
                    label="📋 Extracted Data Table",
                    interactive=False,
                    wrap=True,
                )
                with gr.Accordion("📝 Raw Extracted Text", open=False):
                    raw_output = gr.Textbox(
                        label="Raw DePlot Output",
                        interactive=False,
                        max_lines=10,
                        show_copy_button=True,
                    )

        gr.Markdown("### 📸 Try these example charts:")
        gr.Examples(
            examples=[
                ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/deplot_demo.png"],
            ],
            inputs=image_input,
            label="Click to load example",
        )

        extraction_inputs = [image_input, prompt_type]
        extraction_outputs = [table_output, status_output, raw_output]

        extract_btn.click(
            fn=extract_table_from_chart,
            inputs=extraction_inputs,
            outputs=extraction_outputs,
        )
        # Auto-extract on upload. The radio is passed as an input so the
        # selected mode applies here too; a cleared (None) image is handled
        # by extract_table_from_chart itself.
        image_input.change(
            fn=extract_table_from_chart,
            inputs=extraction_inputs,
            outputs=extraction_outputs,
        )

    return interface
|
|
# --- Entry point ---
if __name__ == "__main__":
    interface = create_interface()
    interface.launch(
        # Listen on all interfaces so the app is reachable from other hosts or
        # from outside a container; use "127.0.0.1" for local-only access.
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )