KangLiao commited on
Commit
e07075c
·
1 Parent(s): a778fea
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  from PIL import Image
4
  import numpy as np
5
  import spaces # Import spaces for ZeroGPU compatibility
@@ -47,6 +48,14 @@ checkpoint_path = "checkpoints/Puffin-Base.pth"
47
  checkpoint = torch.load(checkpoint_path)
48
  info = model.load_state_dict(checkpoint, strict=False)
49
 
 
 
 
 
 
 
 
 
50
  def extract_up_lat_figs(fig_dict):
51
  fig_up, fig_lat = None, None
52
  others = {}
@@ -100,11 +109,15 @@ def camera_understanding(image_src, question, seed, progress=gr.Progress(track_t
100
  single_batch["latitude_field"] = cam[2:].unsqueeze(0)
101
 
102
  figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)
103
-
104
-
105
- fig_up, fig_lat, _ = extract_up_lat_figs(figs)
106
-
107
- return text, fig_up, fig_lat
 
 
 
 
108
 
109
 
110
  @torch.inference_mode()
 
1
  import gradio as gr
2
  import torch
3
+ import io
4
  from PIL import Image
5
  import numpy as np
6
  import spaces # Import spaces for ZeroGPU compatibility
 
48
  checkpoint = torch.load(checkpoint_path)
49
  info = model.load_state_dict(checkpoint, strict=False)
50
 
51
+ def fig_to_image(fig):
52
+ buf = io.BytesIO()
53
+ fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
54
+ buf.seek(0)
55
+ img = Image.open(buf).convert('RGB')
56
+ buf.close()
57
+ return img
58
+
59
  def extract_up_lat_figs(fig_dict):
60
  fig_up, fig_lat = None, None
61
  others = {}
 
109
  single_batch["latitude_field"] = cam[2:].unsqueeze(0)
110
 
111
  figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)
112
+ up_img = lat_img = None
113
+ for k, fig in figs.items():
114
+ if "up_field" in k:
115
+ up_img = fig_to_image(fig)
116
+ elif "latitude_field" in k:
117
+ lat_img = fig_to_image(fig)
118
+ plt.close(fig)
119
+
120
+ return text, up_img, lat_img
121
 
122
 
123
  @torch.inference_mode()