ravi-vc committed on
Commit
780b81e
·
verified ·
1 Parent(s): f1b8462

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -178
app.py CHANGED
@@ -1,231 +1,275 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
4
- from PIL import Image
5
- import requests
6
  import io
7
  import re
8
- import pandas as pd
9
- import json
10
 
11
- # Load the DePlot model and processor
12
  MODEL_NAME = "google/deplot"
 
13
 
14
- def load_model():
15
- """Load the DePlot model and processor"""
16
- try:
17
- processor = Pix2StructProcessor.from_pretrained(MODEL_NAME)
18
- model = Pix2StructForConditionalGeneration.from_pretrained(MODEL_NAME)
19
- return processor, model
20
- except Exception as e:
21
- print(f"Error loading model: {e}")
22
- return None, None
 
 
 
23
 
24
- processor, model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- def extract_chart_data(image, question="Generate underlying data table of the figure below:"):
27
- """
28
- Extract data from chart image using DePlot model
 
29
 
30
- Args:
31
- image: PIL Image or file path
32
- question: Question to ask about the chart
33
 
34
- Returns:
35
- Extracted data as text and structured format
36
- """
37
- if processor is None or model is None:
38
- return "Error: Model not loaded properly", None
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  try:
41
- # Ensure image is PIL Image
42
- if isinstance(image, str):
43
- image = Image.open(image)
44
- elif hasattr(image, 'name'): # Gradio file object
45
- image = Image.open(image.name)
 
 
 
 
46
 
47
- # Convert to RGB if necessary
48
- if image.mode != 'RGB':
49
- image = image.convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # Process the image and question
52
- inputs = processor(images=image, text=question, return_tensors="pt")
 
 
 
 
53
 
54
- # Generate predictions
55
- predictions = model.generate(**inputs, max_new_tokens=512)
56
 
57
- # Decode the output
58
- extracted_text = processor.decode(predictions[0], skip_special_tokens=True)
 
 
 
 
59
 
60
- # Try to parse the extracted text into structured data
61
- structured_data = parse_extracted_data(extracted_text)
 
 
 
 
 
 
 
 
 
 
62
 
63
- return extracted_text, structured_data
64
-
65
- except Exception as e:
66
- return f"Error processing image: {str(e)}", None
67
-
68
- def parse_extracted_data(text):
69
- """
70
- Parse the extracted text to create structured data
71
- This is a basic parser - you might need to customize based on your needs
72
- """
73
- try:
74
- # Look for table-like patterns
75
- lines = text.strip().split('\n')
76
- data = []
77
 
78
- # Try to find header and data rows
79
- for line in lines:
80
- if '|' in line: # Table format with pipes
81
- row = [cell.strip() for cell in line.split('|') if cell.strip()]
82
- if row:
83
- data.append(row)
84
- elif '\t' in line: # Tab-separated
85
- row = [cell.strip() for cell in line.split('\t') if cell.strip()]
86
- if row:
87
- data.append(row)
88
- elif ',' in line and not line.startswith('The'): # CSV-like
89
- row = [cell.strip() for cell in line.split(',') if cell.strip()]
90
- if row:
91
- data.append(row)
92
-
93
- if data:
94
- # Create DataFrame
95
- if len(data) > 1:
96
- df = pd.DataFrame(data[1:], columns=data[0])
97
- else:
98
- df = pd.DataFrame(data)
99
- return df
100
 
101
- return None
102
-
103
  except Exception as e:
104
- print(f"Error parsing data: {e}")
105
- return None
106
-
107
- def process_chart(image, custom_question):
108
- """
109
- Main function to process chart and return results
110
- """
111
- if image is None:
112
- return "Please upload an image", None, None
113
-
114
- # Use custom question if provided, otherwise use default
115
- question = custom_question if custom_question.strip() else "Generate underlying data table of the figure below:"
116
-
117
- # Extract data
118
- raw_output, structured_data = extract_chart_data(image, question)
119
-
120
- # Prepare outputs
121
- if structured_data is not None and not structured_data.empty:
122
- # Convert DataFrame to HTML for display
123
- table_html = structured_data.to_html(index=False, classes='table table-striped')
124
- # Convert DataFrame to CSV string for download
125
- csv_output = structured_data.to_csv(index=False)
126
- else:
127
- table_html = "Could not parse data into structured format"
128
- csv_output = None
129
-
130
- return raw_output, table_html, csv_output
131
 
132
  # Create Gradio interface
133
  def create_interface():
134
- with gr.Blocks(title="DePlot Chart Data Extractor", theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
 
 
 
135
  gr.Markdown("""
136
  # πŸ“Š DePlot Chart Data Extractor
137
 
138
- Upload any chart or plot image to extract the underlying data, even without visible data labels!
139
- This tool uses Google's DePlot model to understand and extract data from various types of charts.
140
 
141
- **Supported chart types:** Bar charts, line graphs, scatter plots, pie charts, and more!
142
  """)
143
 
144
  with gr.Row():
145
  with gr.Column(scale=1):
146
- # Input section
147
  image_input = gr.Image(
148
- type="pil",
149
- label="Upload Chart Image",
150
  height=400
151
  )
152
 
153
- custom_question = gr.Textbox(
154
- label="Custom Question (optional)",
155
- placeholder="e.g., 'What are the values for each category?' or leave empty for default",
156
- lines=2
 
157
  )
158
 
159
- extract_btn = gr.Button("Extract Data", variant="primary", size="lg")
160
 
161
- with gr.Column(scale=1):
162
- # Output section
163
- with gr.Tab("Raw Output"):
164
- raw_output = gr.Textbox(
165
- label="Extracted Text",
166
- lines=10,
167
- show_copy_button=True
168
- )
169
-
170
- with gr.Tab("Structured Data"):
171
- structured_output = gr.HTML(
172
- label="Parsed Data Table"
173
- )
174
 
175
- # Download section
176
- csv_download = gr.File(
177
- label="Download CSV"
 
178
  )
179
 
180
- # Example images
181
- gr.Markdown("### πŸ“‹ Try these examples:")
182
-
183
- example_images = [
184
- ["examples/bar_chart.png", "Extract the data from this bar chart"],
185
- ["examples/line_graph.png", "What are the trend values over time?"],
186
- ["examples/pie_chart.png", "Give me the percentage breakdown"]
187
- ]
188
 
189
- # Note: You'll need to add example images to your space
 
 
 
 
 
 
 
 
190
 
191
  # Event handlers
192
- def process_and_download(image, question):
193
- raw, table, csv_data = process_chart(image, question)
194
-
195
- if csv_data:
196
- # Create temporary CSV file for download
197
- import tempfile
198
- import os
199
-
200
- # Create a temporary file
201
- temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False)
202
- temp_file.write(csv_data)
203
- temp_file.close()
204
-
205
- return raw, table, temp_file.name
206
- else:
207
- return raw, table, None
208
-
209
  extract_btn.click(
210
- fn=process_and_download,
211
- inputs=[image_input, custom_question],
212
- outputs=[raw_output, structured_output, csv_download]
 
213
  )
214
 
215
- gr.Markdown("""
216
- ### πŸ’‘ Tips for better results:
217
- - Use clear, high-resolution images
218
- - Ensure chart elements are visible and not too small
219
- - Try different custom questions for specific data you need
220
- - Works best with standard chart types (bar, line, scatter, pie)
221
-
222
- ### πŸ”§ Model Information:
223
- This space uses Google's DePlot model, which is specifically trained to extract data from plots and figures.
224
- """)
225
 
226
- return demo
227
 
228
- # Create and launch the interface
229
  if __name__ == "__main__":
230
- demo = create_interface()
231
- demo.launch()
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
4
+ from PIL import Image, ImageEnhance
5
+ import pandas as pd
6
  import io
7
  import re
 
 
8
 
9
# Configuration
MODEL_NAME = "google/deplot"
# Use the GPU when torch can see one; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor once at import time so every request reuses them.
print(f"Loading DePlot model on {DEVICE}...")
try:
    processor = Pix2StructProcessor.from_pretrained(MODEL_NAME)
    model = Pix2StructForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        # Half precision is requested only on CUDA; the CPU path stays float32.
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    ).to(DEVICE)
    print("✅ DePlot model loaded successfully!")
except Exception as e:
    # Report the failure, then re-raise: the app cannot serve without a model.
    print(f"❌ Failed to load model: {e}")
    raise
25
 
26
def preprocess_image(image):
    """Return an enhanced RGB version of ``image`` for data extraction.

    The image is converted to RGB when it is in another mode, its contrast
    is scaled by 1.2 and its sharpness by 1.3, and images narrower than
    512 px are upscaled to a width of 512 px with the height scaled
    proportionally (truncated to an integer) using Lanczos resampling.

    Args:
        image: A PIL image, or ``None``.

    Returns:
        The processed PIL image, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    rgb = image if image.mode == 'RGB' else image.convert('RGB')

    # Contrast first, then sharpness, each applied to the previous result.
    boosted = ImageEnhance.Contrast(rgb).enhance(1.2)
    sharpened = ImageEnhance.Sharpness(boosted).enhance(1.3)

    # Images already at least 512 px wide are returned at their own size.
    if sharpened.width >= 512:
        return sharpened

    target_width = 512
    target_height = int(512 * sharpened.height / sharpened.width)
    return sharpened.resize((target_width, target_height), Image.LANCZOS)
49
 
50
# DePlot writes row breaks as this literal token rather than as "\n".
_DEPLOT_NEWLINE_TOKEN = "<0x0A>"


def _split_pipe_row(line):
    """Split a pipe-delimited row into stripped cells, keeping interior blanks."""
    cells = [cell.strip() for cell in line.split('|')]
    start, end = 0, len(cells)
    # Only the outer delimiters ("| a | b |") create spurious empty cells.
    # An empty cell in the middle is a missing value and must keep its
    # position, otherwise later values shift into the wrong column.
    while start < end and not cells[start]:
        start += 1
    while end > start and not cells[end - 1]:
        end -= 1
    return cells[start:end]


def parse_deplot_output(raw_text):
    """Parse DePlot output into a structured table.

    Three formats are tried in order:

    1. A pipe-separated table: the first pipe row is the header and the
       remaining pipe rows are data. Short rows are padded with ``""`` and
       a short header is padded with ``Column_<n>`` names.
    2. ``key: value`` lines, returned as ``Category`` / ``Value`` columns.
       Repeated keys are kept as separate rows.
    3. Otherwise the raw text is returned in a single ``Extracted_Text`` cell.

    Both real newlines and DePlot's ``<0x0A>`` token are treated as row
    separators.

    Args:
        raw_text: Decoded model output; may be ``None`` or empty.

    Returns:
        A ``(DataFrame | None, status_message)`` tuple. The frame is ``None``
        when there is nothing to parse or table construction fails.
    """
    if not raw_text or not raw_text.strip():
        return None, "No data extracted from the image."

    normalized = raw_text.replace(_DEPLOT_NEWLINE_TOKEN, "\n")
    lines = [line.strip() for line in normalized.split('\n') if line.strip()]

    if not lines:
        return None, "Empty output from model."

    # Try to parse as pipe-separated table.
    if '|' in normalized:
        rows = [_split_pipe_row(line) for line in lines if '|' in line]
        # A line made only of delimiters carries no cells; ignore it rather
        # than letting it become an empty header.
        rows = [row for row in rows if row]

        if len(rows) > 1:
            headers, *table_data = rows
            max_cols = max(len(row) for row in rows)

            # Pad the header and every row to the widest row.
            headers = headers + [
                f"Column_{index}"
                for index in range(len(headers) + 1, max_cols + 1)
            ]
            table_data = [
                row + [""] * (max_cols - len(row)) for row in table_data
            ]

            try:
                df = pd.DataFrame(table_data, columns=headers)
            except (ValueError, TypeError) as e:
                return None, f"Error parsing table format: {str(e)}"
            return df, f"✅ Successfully extracted table with {len(df)} rows and {len(df.columns)} columns."

    # If not a usable pipe table, look for "Category: Value" lines.
    pairs = []
    for line in lines:
        if ':' in line:
            key, value = line.split(':', 1)
            pairs.append((key.strip(), value.strip()))

    if pairs:
        try:
            df = pd.DataFrame(pairs, columns=['Category', 'Value'])
        except (ValueError, TypeError) as e:
            return None, f"Error parsing key-value format: {str(e)}"
        return df, f"✅ Extracted {len(df)} key-value pairs."

    # If all parsing fails, return raw text in a single-column table.
    df = pd.DataFrame([raw_text], columns=['Extracted_Text'])
    return df, "⚠️ Could not parse into structured format. Showing raw extracted text."
119
+
120
# Prompt text sent to the model for each extraction mode offered in the UI.
_EXTRACTION_PROMPTS = {
    "default": "Generate underlying data table of the figure below:",
    "detailed": "Extract all data points and create a comprehensive table from this chart:",
    "summary": "Summarize the key data from this chart in table format:",
}


def extract_table_from_chart(image, prompt_type="default"):
    """Extract a data table from a chart image using DePlot.

    Args:
        image: PIL image of the chart, or ``None`` when nothing is uploaded.
        prompt_type: One of ``"default"``, ``"detailed"`` or ``"summary"``.
            Unknown values fall back to the default prompt.

    Returns:
        A ``(table, status_message, raw_text)`` tuple. ``table`` is whatever
        ``parse_deplot_output`` returns for the decoded text, or ``None`` when
        no image was given or an error occurred; ``raw_text`` is the decoded
        model output with the prompt removed, or ``""`` in those same cases.
    """
    if image is None:
        return None, "Please upload an image.", ""

    prompt = _EXTRACTION_PROMPTS.get(prompt_type, _EXTRACTION_PROMPTS["default"])

    try:
        # Preprocess image
        processed_image = preprocess_image(image)

        # Prepare inputs. The floating-point tensors are cast to the model's
        # dtype as well as moved to its device: the model is loaded in
        # float16 on CUDA, and float32 inputs would fail with a dtype mismatch.
        inputs = processor(
            images=processed_image,
            text=prompt,
            return_tensors="pt",
        ).to(DEVICE, model.dtype)

        # Deterministic beam search. No `temperature` is passed because it
        # only applies when sampling and is rejected or warned about otherwise.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
                num_beams=4,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                early_stopping=True,
            )

        # Decode output
        generated_text = processor.decode(generated_ids[0], skip_special_tokens=True)

        # Remove the prompt from output
        clean_text = generated_text.replace(prompt, "").strip()

        # Parse the output
        df, status_msg = parse_deplot_output(clean_text)
        return df, status_msg, clean_text

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"❌ Error: {str(e)}", ""

    finally:
        # Clear the GPU cache on both the success and the error path.
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  # Create Gradio interface
181
  def create_interface():
182
+ with gr.Blocks(
183
+ title="DePlot: Chart Data Extraction",
184
+ theme=gr.themes.Soft(),
185
+ css="""
186
+ .gradio-container {
187
+ max-width: 1200px !important;
188
+ }
189
+ """
190
+ ) as interface:
191
+
192
  gr.Markdown("""
193
  # πŸ“Š DePlot Chart Data Extractor
194
 
195
+ Upload a chart image (bar chart, line chart, pie chart, etc.) and extract the underlying data table.
 
196
 
197
+ **Supported formats:** PNG, JPG, JPEG, GIF, BMP
198
  """)
199
 
200
  with gr.Row():
201
  with gr.Column(scale=1):
 
202
  image_input = gr.Image(
203
+ type="pil",
204
+ label="πŸ“ Upload Chart Image",
205
  height=400
206
  )
207
 
208
+ prompt_type = gr.Radio(
209
+ choices=["default", "detailed", "summary"],
210
+ value="default",
211
+ label="🎯 Extraction Mode",
212
+ info="Choose how detailed the extraction should be"
213
  )
214
 
215
+ extract_btn = gr.Button("πŸš€ Extract Data Table", variant="primary", size="lg")
216
 
217
+ with gr.Column(scale=2):
218
+ status_output = gr.Textbox(
219
+ label="πŸ“‹ Status",
220
+ interactive=False,
221
+ max_lines=2
222
+ )
 
 
 
 
 
 
 
223
 
224
+ table_output = gr.Dataframe(
225
+ label="πŸ“Š Extracted Data Table",
226
+ interactive=False,
227
+ wrap=True
228
  )
229
 
230
+ with gr.Accordion("πŸ” Raw Extracted Text", open=False):
231
+ raw_output = gr.Textbox(
232
+ label="Raw DePlot Output",
233
+ interactive=False,
234
+ max_lines=10,
235
+ show_copy_button=True
236
+ )
 
237
 
238
+ # Examples
239
+ gr.Markdown("### πŸ“Έ Try these example charts:")
240
+ gr.Examples(
241
+ examples=[
242
+ ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/deplot_demo.png"],
243
+ ],
244
+ inputs=image_input,
245
+ label="Click to load example"
246
+ )
247
 
248
  # Event handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  extract_btn.click(
250
+ fn=extract_table_from_chart,
251
+ inputs=[image_input, prompt_type],
252
+ outputs=[table_output, status_output, raw_output],
253
+ show_progress=True
254
  )
255
 
256
+ # Auto-extract on image upload
257
+ image_input.change(
258
+ fn=lambda img: extract_table_from_chart(img, "default") if img else (None, "Please upload an image.", ""),
259
+ inputs=image_input,
260
+ outputs=[table_output, status_output, raw_output],
261
+ show_progress=True
262
+ )
 
 
 
263
 
264
+ return interface
265
 
266
# Launch the app
if __name__ == "__main__":
    interface = create_interface()
    # `show_tips` is intentionally not passed: Gradio 4 removed it from
    # launch(), and passing it there raises a TypeError.
    interface.launch(
        server_name="0.0.0.0",  # For Hugging Face Spaces
        server_port=7860,  # Standard port for HF Spaces
        share=False,  # Don't create public link
        show_error=True,
    )