import io
import re

import gradio as gr
import pandas as pd
import torch
from PIL import Image, ImageEnhance
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Configuration
MODEL_NAME = "google/deplot"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# DePlot's decoded text separates table rows with this literal byte token
# rather than a real newline character.
_DEPLOT_NEWLINE = "<0x0A>"

# Load model and processor once at import time; the app cannot run without them.
print(f"Loading DePlot model on {DEVICE}...")
try:
    processor = Pix2StructProcessor.from_pretrained(MODEL_NAME)
    model = Pix2StructForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        # Half precision only on GPU; CPU kernels expect float32.
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    ).to(DEVICE)
    print("✅ DePlot model loaded successfully!")
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    raise


def preprocess_image(image):
    """Enhance a chart image before it is handed to the model.

    The image is converted to RGB, its contrast (x1.2) and sharpness (x1.3)
    are boosted, and images narrower than 512 px are upscaled to 512 px wide
    with Lanczos resampling, keeping the aspect ratio.

    Args:
        image: A PIL image, or None.

    Returns:
        The processed PIL image, or None when ``image`` is None.
    """
    if image is None:
        return None

    # Convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Enhance contrast and sharpness
    image = ImageEnhance.Contrast(image).enhance(1.2)
    image = ImageEnhance.Sharpness(image).enhance(1.3)

    # Resize if too small (minimum 512px width for better results)
    if image.width < 512:
        new_width = 512
        new_height = int(512 * image.height / image.width)
        image = image.resize((new_width, new_height), Image.LANCZOS)

    return image


def _split_pipe_row(line):
    """Split one pipe-delimited line into stripped cells.

    Only empty cells at the very start/end (produced by outer ``|``
    delimiters) are dropped; empty interior cells are kept so that the
    remaining values stay in their columns.
    """
    cells = [cell.strip() for cell in line.split("|")]
    while cells and not cells[0]:
        cells.pop(0)
    while cells and not cells[-1]:
        cells.pop()
    return cells


def parse_deplot_output(raw_text):
    """Parse DePlot output into a structured table.

    Three strategies are tried in order: a pipe-separated table (first row is
    the header), ``key: value`` lines, and finally a single-cell table holding
    the raw text.

    Args:
        raw_text: Text decoded from the model. Rows may be separated by real
            newlines or by the literal ``<0x0A>`` token.

    Returns:
        A ``(DataFrame or None, status message)`` tuple. The DataFrame is
        None only when there is nothing to parse or table construction fails.
    """
    if not raw_text or not raw_text.strip():
        return None, "No data extracted from the image."

    # Turn the model's newline token into real line breaks before splitting;
    # otherwise the whole table is seen as one line.
    normalized = raw_text.replace(_DEPLOT_NEWLINE, "\n")
    lines = [line.strip() for line in normalized.split("\n") if line.strip()]
    if not lines:
        return None, "Empty output from model."

    # Try to parse as pipe-separated table
    if "|" in normalized:
        rows = [_split_pipe_row(line) for line in lines if "|" in line]
        rows = [row for row in rows if row]

        # DePlot prefixes the table with a "TITLE | <chart title>" row; it is
        # not the column header, so skip it when a header and at least one
        # data row follow.
        if len(rows) > 2 and rows[0][0] == "TITLE":
            rows = rows[1:]

        if len(rows) >= 2:
            header, *body = rows
            width = max(len(row) for row in rows)
            # Pad headers and rows so every row has the same number of cells.
            header = header + [
                f"Column_{i}" for i in range(len(header) + 1, width + 1)
            ]
            body = [row + [""] * (width - len(row)) for row in body]
            try:
                df = pd.DataFrame(body, columns=header)
            except ValueError as e:
                return None, f"Error parsing table format: {str(e)}"
            return (
                df,
                f"✅ Successfully extracted table with {len(df)} rows "
                f"and {len(df.columns)} columns.",
            )

    # If not pipe-separated, try to parse as "Category: Value" pairs.
    # A repeated key keeps its last value.
    pairs = {}
    for line in lines:
        key, sep, value = line.partition(":")
        if sep:
            pairs[key.strip()] = value.strip()

    if pairs:
        try:
            df = pd.DataFrame(list(pairs.items()), columns=["Category", "Value"])
        except ValueError as e:
            return None, f"Error parsing key-value format: {str(e)}"
        return df, f"✅ Extracted {len(df)} key-value pairs."

    # If all parsing fails, return raw text in a single-column table
    df = pd.DataFrame([raw_text], columns=["Extracted_Text"])
    return df, "⚠️ Could not parse into structured format. Showing raw extracted text."
def extract_table_from_chart(image, prompt_type="default"):
    """Extract the data table from a chart image using DePlot.

    Args:
        image: PIL image of the chart, or None.
        prompt_type: One of "default", "detailed" or "summary"; unknown
            values fall back to the default prompt.

    Returns:
        A ``(table, status message, raw text)`` tuple. ``table`` is whatever
        ``parse_deplot_output`` produced, or None when no image was given or
        an error occurred; in those cases the raw text is "".
    """
    if image is None:
        return None, "Please upload an image.", ""

    # Define prompts
    prompts = {
        "default": "Generate underlying data table of the figure below:",
        "detailed": "Extract all data points and create a comprehensive table from this chart:",
        "summary": "Summarize the key data from this chart in table format:",
    }
    prompt = prompts.get(prompt_type, prompts["default"])

    try:
        processed_image = preprocess_image(image)

        # Move inputs to the model's device AND dtype: the model is loaded in
        # float16 on CUDA, and float32 inputs would not match its weights.
        inputs = processor(
            images=processed_image,
            text=prompt,
            return_tensors="pt",
        ).to(DEVICE, model.dtype)

        # Deterministic beam search; no sampling parameters are passed
        # because they are unused when do_sample=False.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
                num_beams=4,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                early_stopping=True,
            )

        # Decode output
        generated_text = processor.decode(generated_ids[0], skip_special_tokens=True)

        # Remove the prompt from output
        clean_text = generated_text.replace(prompt, "").strip()

        # Parse the output
        df, status_msg = parse_deplot_output(clean_text)
        return df, status_msg, clean_text
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"❌ Error: {str(e)}", ""
    finally:
        # Clear GPU cache on both the success and the error path.
        if DEVICE == "cuda":
            torch.cuda.empty_cache()


# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI and wire up its event handlers.

    Returns:
        The constructed ``gr.Blocks`` interface (not yet launched).
    """
    with gr.Blocks(
        title="DePlot: Chart Data Extraction",
        theme=gr.themes.Soft(),
        css=".gradio-container { max-width: 1200px !important; }",
    ) as interface:
        gr.Markdown(
            "# 📊 DePlot Chart Data Extractor\n\n"
            "Upload a chart image (bar chart, line chart, pie chart, etc.) "
            "and extract the underlying data table.\n\n"
            "**Supported formats:** PNG, JPG, JPEG, GIF, BMP"
        )

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="pil",
                    label="📁 Upload Chart Image",
                    height=400,
                )
                prompt_type = gr.Radio(
                    choices=["default", "detailed", "summary"],
                    value="default",
                    label="🎯 Extraction Mode",
                    info="Choose how detailed the extraction should be",
                )
                extract_btn = gr.Button(
                    "🚀 Extract Data Table", variant="primary", size="lg"
                )

            with gr.Column(scale=2):
                status_output = gr.Textbox(
                    label="📋 Status",
                    interactive=False,
                    max_lines=2,
                )
                table_output = gr.Dataframe(
                    label="📊 Extracted Data Table",
                    interactive=False,
                    wrap=True,
                )
                with gr.Accordion("🔍 Raw Extracted Text", open=False):
                    raw_output = gr.Textbox(
                        label="Raw DePlot Output",
                        interactive=False,
                        max_lines=10,
                        show_copy_button=True,
                    )

        # Examples
        gr.Markdown("### 📸 Try these example charts:")
        gr.Examples(
            examples=[
                ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/deplot_demo.png"],
            ],
            inputs=image_input,
            label="Click to load example",
        )

        # Event handlers
        extract_btn.click(
            fn=extract_table_from_chart,
            inputs=[image_input, prompt_type],
            outputs=[table_output, status_output, raw_output],
            show_progress=True,
        )

        # Auto-extract on image upload, always with the default prompt.
        # A cleared image arrives as None, which extract_table_from_chart
        # already answers with the "Please upload an image." status.
        image_input.change(
            fn=lambda img: extract_table_from_chart(img, "default"),
            inputs=image_input,
            outputs=[table_output, status_output, raw_output],
            show_progress=True,
        )

    return interface


# Launch the app
if __name__ == "__main__":
    interface = create_interface()
    # show_tips is intentionally not passed: newer Gradio releases reject it
    # as an unknown launch() keyword.
    interface.launch(
        server_name="0.0.0.0",  # For Hugging Face Spaces
        server_port=7860,  # Standard port for HF Spaces
        share=False,  # Don't create public link
        show_error=True,
    )