KangLiao committed on
Commit 06869b4 · 1 Parent(s): f18fdea
Files changed (1)
  1. app.py +19 -10
app.py CHANGED
@@ -47,6 +47,18 @@ checkpoint_path = "checkpoints/Puffin-Base.pth"
 checkpoint = torch.load(checkpoint_path)
 info = model.load_state_dict(checkpoint, strict=False)
 
+def extract_up_lat_figs(fig_dict):
+    fig_up, fig_lat = None, None
+    others = {}
+    for k, fig in fig_dict.items():
+        if ("up_field" in k) and (fig_up is None):
+            fig_up = fig
+        elif ("latitude_field" in k) and (fig_lat is None):
+            fig_lat = fig
+        else:
+            others[k] = fig
+    return fig_up, fig_lat, others
+
 
 @torch.inference_mode()
 @spaces.GPU(duration=120)
@@ -88,15 +100,11 @@ def camera_understanding(image_src, question, seed, progress=gr.Progress(track_t
     single_batch["latitude_field"] = cam[2:].unsqueeze(0)
 
     figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)
-    imgs = []
-    for k, fig in figs.items():
-        fig.canvas.draw()
-        img = np.array(fig.canvas.renderer.buffer_rgba())
-        imgs.append(img)
-        plt.close(fig)
-    merged_imgs = np.concatenate(imgs, axis=1)
+
+
+    fig_up, fig_lat, _ = extract_up_lat_figs(figs)
 
-    return text, merged_imgs
+    return text, fig_up, fig_lat
 
 
 @torch.inference_mode()
@@ -192,7 +200,8 @@ with gr.Blocks(css=css) as demo:
         understanding_button = gr.Button("Chat")
         understanding_output = gr.Textbox(label="Response")
 
-        camera_output = gr.Gallery(label="Camera Maps", columns=1, rows=1)
+        camera1 = gr.Gallery(label="Camera Maps", columns=1, rows=1)
+        camera2 = gr.Gallery(label="Camera Maps", columns=1, rows=1)
 
         with gr.Accordion("Advanced options", open=False):
             und_seed_input = gr.Number(label="Seed", precision=0, value=42)
@@ -215,7 +224,7 @@ with gr.Blocks(css=css) as demo:
     understanding_button.click(
         camera_understanding,
         inputs=[image_input, und_seed_input],
-        outputs=[understanding_output, camera_output]
+        outputs=[understanding_output, camera1, camera2]
    )
 
 demo.launch(share=True)
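For reference, a minimal standalone check of the new helper. The function body is copied verbatim from this commit; the dummy figure dict and its key names are placeholders standing in for the output of make_perspective_figures():

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def extract_up_lat_figs(fig_dict):
    # Verbatim from the commit: keep the first "up_field" figure and the
    # first "latitude_field" figure; everything else goes into `others`.
    fig_up, fig_lat = None, None
    others = {}
    for k, fig in fig_dict.items():
        if ("up_field" in k) and (fig_up is None):
            fig_up = fig
        elif ("latitude_field" in k) and (fig_lat is None):
            fig_lat = fig
        else:
            others[k] = fig
    return fig_up, fig_lat, others


# Placeholder figures: the key names only need to contain the substrings
# the helper matches on.
figs = {name: plt.figure() for name in ("up_field_0", "latitude_field_0", "image_0")}
fig_up, fig_lat, others = extract_up_lat_figs(figs)
assert fig_up is figs["up_field_0"]
assert fig_lat is figs["latitude_field_0"]
assert list(others) == ["image_0"]
```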
 
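One caveat worth noting: depending on the Gradio version, gr.Gallery expects a list of images rather than a bare matplotlib Figure, so the figures returned by camera_understanding may still need rasterizing, much as the deleted loop did. A minimal sketch of that conversion, assuming the hypothetical helper name fig_to_rgba:

```python
import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def fig_to_rgba(fig):
    """Rasterize a Figure to an RGBA array, mirroring the deleted loop."""
    fig.canvas.draw()
    img = np.array(fig.canvas.renderer.buffer_rgba())
    plt.close(fig)  # free the figure once rasterized
    return img


# Hypothetical usage inside camera_understanding:
#     return text, [fig_to_rgba(fig_up)], [fig_to_rgba(fig_lat)]
# so each gr.Gallery receives a one-item list of image arrays.
```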