ravi-vc commited on
Commit
6af56a3
·
verified ·
1 Parent(s): fff0038

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -1,19 +1,24 @@
 
1
  import gradio as gr
2
- from transformers import AutoProcessor, AutoModelForSeq2SeqLM
3
  from PIL import Image
4
  import json
5
 
6
- # Load DePlot (vision-to-table)
7
- deplot_model = AutoModelForSeq2SeqLM.from_pretrained("google/deplot")
8
- deplot_processor = AutoProcessor.from_pretrained("google/deplot")
 
 
 
 
9
 
10
  def extract_chart(image):
11
  # Step 1: Run DePlot
12
- inputs = deplot_processor(images=image, text="Generate table from chart.", return_tensors="pt")
13
- outputs = deplot_model.generate(**inputs, max_new_tokens=512)
14
- table = deplot_processor.decode(outputs[0], skip_special_tokens=True)
15
 
16
- # Step 2: Create dummy structured JSON (later we plug in detection+OCR)
17
  structured_json = {
18
  "metadata": {"title": "Demo Chart", "chart_type": "bar", "confidence": 0.5},
19
  "axes": {"x_axis": {"label": "X", "ticks": []}, "y_axis": {"label": "Y", "ticks": []}},
@@ -21,7 +26,7 @@ def extract_chart(image):
21
  "legend": {"entries": []}
22
  }
23
 
24
- # Step 3: Merge (for now, just attach DePlot’s output)
25
  merged_output = {
26
  "structured_json": structured_json,
27
  "deplot_table": table,
 
1
+ import os
2
  import gradio as gr
3
+ from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
4
  from PIL import Image
5
  import json
6
 
7
+ # Fix threading error
8
+ os.environ["OMP_NUM_THREADS"] = "1"
9
+
10
+ # Load DePlot
11
+ model_id = "google/deplot"
12
+ processor = Pix2StructProcessor.from_pretrained(model_id)
13
+ model = Pix2StructForConditionalGeneration.from_pretrained(model_id)
14
 
15
  def extract_chart(image):
16
  # Step 1: Run DePlot
17
+ inputs = processor(images=image, text="Generate table from chart.", return_tensors="pt")
18
+ predictions = model.generate(**inputs, max_new_tokens=512)
19
+ table = processor.decode(predictions[0], skip_special_tokens=True)
20
 
21
+ # Step 2: Dummy structured JSON
22
  structured_json = {
23
  "metadata": {"title": "Demo Chart", "chart_type": "bar", "confidence": 0.5},
24
  "axes": {"x_axis": {"label": "X", "ticks": []}, "y_axis": {"label": "Y", "ticks": []}},
 
26
  "legend": {"entries": []}
27
  }
28
 
29
+ # Step 3: Merge outputs
30
  merged_output = {
31
  "structured_json": structured_json,
32
  "deplot_table": table,