ravi-vc commited on
Commit
1714f2b
·
verified ·
1 Parent(s): dfbe417

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -11,12 +11,22 @@ model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model.to(device)
13
 
 
 
 
 
 
 
14
  def extract_chart_data(img: Image.Image):
15
- inputs = processor(images=img, return_tensors="pt").to(device)
 
 
 
16
  outputs = model.generate(**inputs)
17
  result = processor.decode(outputs[0], skip_special_tokens=True)
18
  return result
19
 
 
20
  # Gradio interface
21
  iface = gr.Interface(
22
  fn=extract_chart_data,
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model.to(device)
13
 
14
+ # def extract_chart_data(img: Image.Image):
15
+ # inputs = processor(images=img, return_tensors="pt").to(device)
16
+ # outputs = model.generate(**inputs)
17
+ # result = processor.decode(outputs[0], skip_special_tokens=True)
18
+ # return result
19
+
20
  def extract_chart_data(img: Image.Image):
21
+ # Provide a header prompt for VQA
22
+ header_text = "Extract chart data"
23
+
24
+ inputs = processor(images=img, text=header_text, return_tensors="pt").to(device)
25
  outputs = model.generate(**inputs)
26
  result = processor.decode(outputs[0], skip_special_tokens=True)
27
  return result
28
 
29
+
30
  # Gradio interface
31
  iface = gr.Interface(
32
  fn=extract_chart_data,