ravi-vc commited on
Commit
fff0038
·
verified ·
1 Parent(s): 9a02ee8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -39
app.py CHANGED
@@ -1,49 +1,42 @@
1
  import gradio as gr
2
- from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
3
  from PIL import Image
4
- import torch
5
 
6
- # Load model & processor
7
- processor = Pix2StructProcessor.from_pretrained("google/deplot")
8
- model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
9
 
10
- # Move model to GPU if available
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- model.to(device)
 
 
13
 
14
- # def extract_chart_data(img: Image.Image):
15
- # inputs = processor(images=img, return_tensors="pt").to(device)
16
- # outputs = model.generate(**inputs)
17
- # result = processor.decode(outputs[0], skip_special_tokens=True)
18
- # return result
 
 
19
 
20
- def extract_chart_data(img: Image.Image, header_text:str):
21
- # Provide a header prompt for VQA
22
- # header_text = "Convert this bar chart into a table of quarter and cost values"
23
- print("Header text:", header_text)
24
-
25
- inputs = processor(images=img, text=header_text, return_tensors="pt").to(device)
26
- outputs = model.generate(**inputs)
27
- all_results = [processor.decode(seq, skip_special_tokens=True) for seq in outputs]
28
- print(all_results)
29
- result = processor.decode(outputs[0], skip_special_tokens=True)
30
- return result
31
 
 
32
 
33
- # Gradio interface
34
- iface = gr.Interface(
35
- fn=extract_chart_data,
36
- # inputs=gr.Image(type="pil"),
37
- inputs=[
38
- gr.Image(type="pil", label="Upload Chart Image"), # Image input
39
- gr.Textbox(
40
- label="Instruction / Header Text",
41
- placeholder="Enter instructions like 'Convert this bar chart into a table of quarter and cost values'"
42
- ) # Text input
43
- ],
44
- outputs="text",
45
- title="DePlot Chart Data Extractor 1",
46
- description="Upload a chart image and get its extracted data/description."
47
  )
48
 
49
- iface.launch()
 
 
1
  import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForSeq2SeqLM
3
  from PIL import Image
4
+ import json
5
 
6
+ # Load DePlot (vision-to-table)
7
+ deplot_model = AutoModelForSeq2SeqLM.from_pretrained("google/deplot")
8
+ deplot_processor = AutoProcessor.from_pretrained("google/deplot")
9
 
10
+ def extract_chart(image):
11
+ # Step 1: Run DePlot
12
+ inputs = deplot_processor(images=image, text="Generate table from chart.", return_tensors="pt")
13
+ outputs = deplot_model.generate(**inputs, max_new_tokens=512)
14
+ table = deplot_processor.decode(outputs[0], skip_special_tokens=True)
15
 
16
+ # Step 2: Create dummy structured JSON (later we plug in detection+OCR)
17
+ structured_json = {
18
+ "metadata": {"title": "Demo Chart", "chart_type": "bar", "confidence": 0.5},
19
+ "axes": {"x_axis": {"label": "X", "ticks": []}, "y_axis": {"label": "Y", "ticks": []}},
20
+ "series": [],
21
+ "legend": {"entries": []}
22
+ }
23
 
24
+ # Step 3: Merge (for now, just attach DePlot’s output)
25
+ merged_output = {
26
+ "structured_json": structured_json,
27
+ "deplot_table": table,
28
+ "fusion_notes": "Fusion layer not implemented yet, just showing both outputs."
29
+ }
 
 
 
 
 
30
 
31
+ return json.dumps(merged_output, indent=2)
32
 
33
+ demo = gr.Interface(
34
+ fn=extract_chart,
35
+ inputs=gr.Image(type="pil"),
36
+ outputs="json",
37
+ title="Chart-to-JSON Extractor (Prototype)",
38
+ description="Uploads a chart, extracts structured JSON (dummy) and DePlot table side-by-side."
 
 
 
 
 
 
 
 
39
  )
40
 
41
+ if __name__ == "__main__":
42
+ demo.launch()