ravi-vc committed on
Commit
5e40923
·
verified ·
1 Parent(s): 36bfbe1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +353 -0
app.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import (
4
+ BlipProcessor, BlipForConditionalGeneration,
5
+ TrOCRProcessor, VisionEncoderDecoderModel,
6
+ AutoProcessor, AutoModelForCausalLM
7
+ )
8
+ from PIL import Image
9
+ import easyocr
10
+ import matplotlib.pyplot as plt
11
+ import pandas as pd
12
+ import numpy as np
13
+ import cv2
14
+ import io
15
+ import base64
16
+
17
class ChartAnalyzer:
    """Analyze chart images with an ensemble of vision and OCR models.

    Pipeline: BLIP produces a caption, TrOCR and EasyOCR extract text,
    keyword matching over the caption infers the chart type, OpenCV
    contours give a rough structural summary, and (optionally) Florence-2
    adds region-level detection / dense captioning / OCR.
    """

    def __init__(self):
        # Load models eagerly so the first request does not pay the cost.
        self.load_models()

    def load_models(self):
        """Load all required models.

        Errors are printed instead of raised so the app can still start;
        ``self.florence_available`` records whether the optional
        Florence-2 model is usable.
        """
        # Set the flag before any loading so analyze_chart() can always
        # read it, even when loading fails partway through (the original
        # left it undefined in that case -> AttributeError later).
        self.florence_available = False
        try:
            # BLIP for image captioning and understanding
            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

            # TrOCR for text extraction
            self.trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
            self.trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")

            # EasyOCR for backup text extraction
            self.ocr_reader = easyocr.Reader(['en'])

            # Florence-2 for advanced understanding (optional). Its custom
            # modeling code lives in the checkpoint repo, so
            # trust_remote_code=True is required for loading to succeed.
            try:
                self.florence_processor = AutoProcessor.from_pretrained(
                    "microsoft/Florence-2-base", trust_remote_code=True
                )
                self.florence_model = AutoModelForCausalLM.from_pretrained(
                    "microsoft/Florence-2-base", trust_remote_code=True
                )
                self.florence_available = True
            except Exception:
                self.florence_available = False

        except Exception as e:
            print(f"Error loading models: {e}")

    def analyze_chart(self, image, analysis_type="comprehensive"):
        """Run the full analysis pipeline on *image*.

        Args:
            image: PIL image or numpy array, or None.
            analysis_type: "basic", "comprehensive" or "data_extraction".

        Returns:
            A Markdown-formatted report string, or an error message.
        """
        if image is None:
            return "Please upload an image first."

        results = {}

        try:
            # Normalize to an RGB PIL image. The original only converted
            # numpy inputs; RGBA/palette PIL uploads also need conversion
            # before being fed to the models.
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)
            image = image.convert('RGB')

            # Basic image understanding with BLIP
            results['description'] = self.get_image_description(image)

            # Extract text using multiple OCR methods
            results['extracted_text'] = self.extract_text_multi_method(image)

            # Chart type detection from the BLIP caption
            results['chart_type'] = self.detect_chart_type(image, results['description'])

            # Data extraction (if requested)
            if analysis_type in ["comprehensive", "data_extraction"]:
                results['data_points'] = self.extract_data_points(image, results['chart_type'])

            # Advanced analysis with Florence-2 (if available)
            if self.florence_available and analysis_type == "comprehensive":
                results['advanced_analysis'] = self.florence_analysis(image)

            return self.format_results(results)

        except Exception as e:
            return f"Error analyzing chart: {str(e)}"

    def get_image_description(self, image):
        """Return a BLIP caption for *image*, or a fallback message."""
        try:
            inputs = self.blip_processor(image, return_tensors="pt")
            out = self.blip_model.generate(**inputs, max_length=100)
            description = self.blip_processor.decode(out[0], skip_special_tokens=True)
            return description
        except Exception:  # narrowed from bare except: keep Ctrl-C working
            return "Unable to generate description"

    def extract_text_multi_method(self, image):
        """Extract text with TrOCR and EasyOCR.

        Returns a dict keyed by method name; a method that errors out maps
        to the string "Failed" so one OCR failure does not abort the other.
        """
        extracted_texts = {}

        # Method 1: TrOCR
        try:
            pixel_values = self.trocr_processor(image, return_tensors="pt").pixel_values
            generated_ids = self.trocr_model.generate(pixel_values)
            trocr_text = self.trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            extracted_texts['TrOCR'] = trocr_text
        except Exception:
            extracted_texts['TrOCR'] = "Failed"

        # Method 2: EasyOCR
        try:
            # EasyOCR expects a numpy array, not a PIL image
            image_np = np.array(image)
            ocr_results = self.ocr_reader.readtext(image_np)
            easyocr_text = ' '.join([result[1] for result in ocr_results])
            extracted_texts['EasyOCR'] = easyocr_text
        except Exception:
            extracted_texts['EasyOCR'] = "Failed"

        return extracted_texts

    def detect_chart_type(self, image, description):
        """Classify the chart type by keyword-matching the caption *description*."""
        description_lower = description.lower()

        chart_keywords = {
            'bar_chart': ['bar', 'column', 'histogram'],
            'line_chart': ['line', 'trend', 'time series'],
            'pie_chart': ['pie', 'circular', 'slice'],
            'scatter_plot': ['scatter', 'correlation', 'points'],
            'area_chart': ['area', 'filled'],
            'box_plot': ['box', 'whisker'],
            'heatmap': ['heat', 'color coded', 'matrix']
        }

        # First matching entry wins (dict preserves insertion order).
        for chart_type, keywords in chart_keywords.items():
            if any(keyword in description_lower for keyword in keywords):
                return chart_type.replace('_', ' ').title()

        return "Unknown Chart Type"

    def extract_data_points(self, image, chart_type):
        """Return a rough structural summary (contour count, dimensions).

        This is intentionally simplified; it does not recover actual data
        values from the chart.
        """
        try:
            # Convert to grayscale for analysis
            image_np = np.array(image.convert('L'))

            # Basic edge detection
            edges = cv2.Canny(image_np, 50, 150)

            # Find contours on the edge map
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            data_info = {
                'contours_found': len(contours),
                'image_dimensions': image_np.shape,
                'note': 'This is a simplified data extraction. Advanced algorithms needed for accurate data point extraction.'
            }

            return data_info

        except Exception as e:
            return f"Data extraction failed: {str(e)}"

    def florence_analysis(self, image):
        """Run Florence-2 task prompts; returns {prompt: raw generated text}."""
        if not self.florence_available:
            return "Florence-2 model not available"

        try:
            # Florence-2 task prompts
            prompts = [
                "<OD>",  # Object Detection
                "<DENSE_REGION_CAPTION>",  # Dense captioning
                "<OCR_WITH_REGION>"  # OCR with regions
            ]

            results = {}
            for prompt in prompts:
                inputs = self.florence_processor(text=prompt, images=image, return_tensors="pt")
                generated_ids = self.florence_model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=1024,
                    num_beams=3
                )
                # Special tokens are kept on purpose: Florence-2 encodes
                # region coordinates in them.
                generated_text = self.florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
                results[prompt] = generated_text

            return results
        except Exception:
            return "Florence-2 analysis failed"

    def format_results(self, results):
        """Render the *results* dict as a Markdown report string."""
        formatted = "# Chart Analysis Results\n\n"

        if 'description' in results:
            formatted += f"## Image Description\n{results['description']}\n\n"

        if 'chart_type' in results:
            formatted += f"## Chart Type\n{results['chart_type']}\n\n"

        if 'extracted_text' in results:
            formatted += "## Extracted Text\n"
            for method, text in results['extracted_text'].items():
                formatted += f"**{method}:** {text}\n\n"

        if 'data_points' in results:
            formatted += f"## Data Analysis\n{results['data_points']}\n\n"

        if 'advanced_analysis' in results:
            formatted += f"## Advanced Analysis\n{results['advanced_analysis']}\n\n"

        return formatted
211
+
212
# Initialize the analyzer once at import time.
# NOTE: this triggers all model downloads/loads as a module-level side effect.
analyzer = ChartAnalyzer()

# Create Gradio interface.
# NOTE(review): this 2-argument wrapper is rebound by the 4-argument
# analyze_uploaded_chart defined inside the gr.Blocks context below, so this
# version is effectively dead code — confirm before removing.
def analyze_uploaded_chart(image, analysis_type):
    """Thin wrapper so Gradio can call the module-level analyzer."""
    return analyzer.analyze_chart(image, analysis_type)
218
+
219
# Create the Gradio app.
# NOTE: component creation order inside this context defines the page layout,
# so the statements below are order-sensitive.
with gr.Blocks(title="Chart Analyzer & Data Extractor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸ“Š Chart Analyzer & Data Extractor")
    gr.Markdown("Upload a chart image to extract data and analyze its contents using multiple AI models including BLIP, TrOCR, and Florence-2.")

    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column(scale=1):
            gr.Markdown("## πŸ“ Upload Your Chart")

            # Multiple upload options: direct upload vs. fetch-by-URL.
            with gr.Tabs():
                with gr.Tab("πŸ“€ Upload Image"):
                    image_input = gr.Image(
                        type="pil",
                        label="Upload Chart Image",
                        height=400,
                        sources=["upload", "webcam", "clipboard"],
                        format="png"
                    )
                    gr.Markdown("**Supported formats:** PNG, JPG, JPEG, GIF, BMP")
                    gr.Markdown("**Max size:** 10MB")

                with gr.Tab("πŸ”— From URL"):
                    url_input = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/chart.png"
                    )
                    load_url_btn = gr.Button("Load from URL")

            # Analysis options
            gr.Markdown("## βš™οΈ Analysis Settings")
            analysis_type = gr.Dropdown(
                choices=["basic", "comprehensive", "data_extraction"],
                value="comprehensive",
                label="Analysis Type",
                info="Choose the depth of analysis"
            )

            with gr.Accordion("Advanced Options", open=False):
                # NOTE(review): this threshold is forwarded to the handler but
                # never applied to OCR filtering — confirm intended behavior.
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    label="OCR Confidence Threshold"
                )
                use_florence = gr.Checkbox(
                    label="Use Florence-2 (Advanced Analysis)",
                    value=True
                )

            analyze_btn = gr.Button("πŸ” Analyze Chart", variant="primary", size="lg")
            clear_btn = gr.Button("πŸ—‘οΈ Clear All", variant="secondary")

        # Right column: results.
        with gr.Column(scale=2):
            gr.Markdown("## πŸ“Š Analysis Results")
            output = gr.Markdown(
                value="Upload an image and click 'Analyze Chart' to see results here.",
                label="Results"
            )

            # Additional output components
            with gr.Accordion("Raw Data Export", open=False):
                json_output = gr.JSON(label="Structured Data")
                # Hidden placeholder: no handler currently produces a CSV file.
                csv_download = gr.File(label="Download CSV", visible=False)

    # Function to load an image from a URL.
    def load_image_from_url(url):
        """Fetch *url* and return (PIL image, status message); (None, error text) on failure."""
        try:
            import requests
            response = requests.get(url)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
            return image, "Image loaded successfully!"
        except Exception as e:
            return None, f"Error loading image: {str(e)}"

    # Enhanced analysis function.
    # NOTE: this rebinds the module-level analyze_uploaded_chart defined above;
    # the click handler below uses this 4-argument version.
    def analyze_uploaded_chart(image, analysis_type, confidence_threshold, use_florence):
        """Run the analyzer and return (markdown report, structured JSON dict, csv file or None)."""
        if image is None:
            return "Please upload an image first.", {}, None

        try:
            result = analyzer.analyze_chart(image, analysis_type)

            # Create structured metadata for the JSON output panel.
            structured_data = {
                "analysis_type": analysis_type,
                "confidence_threshold": confidence_threshold,
                "models_used": ["BLIP", "TrOCR", "EasyOCR"],
                "timestamp": pd.Timestamp.now().isoformat()
            }

            if use_florence and analyzer.florence_available:
                structured_data["models_used"].append("Florence-2")

            return result, structured_data, None

        except Exception as e:
            error_msg = f"Error analyzing chart: {str(e)}"
            return error_msg, {"error": error_msg}, None

    # Clear function: resets image, report, JSON and CSV outputs.
    def clear_all():
        return None, "Upload an image and click 'Analyze Chart' to see results here.", {}, None

    # Examples.
    # NOTE(review): these placeholder URLs are fed to a type="pil" image input;
    # verify that Gradio resolves remote URLs for examples in the deployed version.
    gr.Examples(
        examples=[
            ["https://via.placeholder.com/600x400/0066CC/FFFFFF?text=Sample+Bar+Chart", "comprehensive"],
            ["https://via.placeholder.com/600x400/FF6B35/FFFFFF?text=Sample+Line+Chart", "data_extraction"],
        ],
        inputs=[image_input, analysis_type],
        label="Try these examples:"
    )

    # Event handlers wiring buttons to the functions above.
    analyze_btn.click(
        fn=analyze_uploaded_chart,
        inputs=[image_input, analysis_type, confidence_threshold, use_florence],
        outputs=[output, json_output, csv_download]
    )

    load_url_btn.click(
        fn=load_image_from_url,
        inputs=[url_input],
        outputs=[image_input, output]
    )

    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, output, json_output, csv_download]
    )

# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()