vsamasworm committed on
Commit
fbb8705
·
1 Parent(s): ee8949b

fix NameError

Browse files
Files changed (2) hide show
  1. app.py +64 -0
  2. inference.py +1 -66
app.py CHANGED
@@ -10,6 +10,7 @@ from vision_tower import VGGT_OriAny_Ref
10
  from inference import *
11
  from app_utils import *
12
  from axis_renderer import BlendRenderer
 
13
 
14
  from huggingface_hub import hf_hub_download
15
  ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
@@ -32,6 +33,69 @@ print('Model loaded.')
32
 
33
  axis_renderer = BlendRenderer(RENDER_FILE)
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # ====== 工具函数:安全图像处理 ======
37
  def safe_image_input(image):
 
10
  from inference import *
11
  from app_utils import *
12
  from axis_renderer import BlendRenderer
13
+ import spaces
14
 
15
  from huggingface_hub import hf_hub_download
16
  ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
 
33
 
34
  axis_renderer = BlendRenderer(RENDER_FILE)
35
 
36
@spaces.GPU
@torch.no_grad()
def inf_single_batch(batch):
    """Run the module-level ``model`` on one batch and decode its predictions.

    Args:
        batch: 5-D tensor unpacked as ``(B, S, C, H, W)``; unpacking raises
            ``ValueError`` for any other rank. ``S`` is the number of frames
            per sample (reference frame first, target frame second).

    Returns:
        dict with the keys ``ref_az_pred``, ``ref_el_pred``, ``ref_ro_pred``,
        ``ref_alpha_pred``, ``rel_az_pred``, ``rel_el_pred`` and
        ``rel_ro_pred``.

        * ``S > 1``: every value is indexed per sample (length ``B``); the
          ``ref_*`` entries come from frame 0 and the ``rel_*`` entries from
          frame 1.
        * ``S == 1``: the ``ref_*`` entries are taken from the first row of
          the flattened batch only, and the ``rel_*`` entries are the float
          ``0.``.
    """
    # Only B and S are needed below; the 5-way unpack still enforces the rank.
    B, S, _, _, _ = batch.shape
    pose_enc = model(batch)  # (B, S, D)
    pose_enc = pose_enc.view(B * S, -1)

    # The encoding is sliced into three consecutive bin ranges:
    # [0, 360) azimuth, [360, 540) elevation, [540, 900) rotation.
    # The argmax bin index is shifted by -90 / -180 for elevation / rotation.
    angle_az_pred = torch.argmax(pose_enc[:, 0:360], dim=-1)
    angle_el_pred = torch.argmax(pose_enc[:, 360:360 + 180], dim=-1) - 90
    angle_ro_pred = torch.argmax(pose_enc[:, 360 + 180:360 + 180 + 360], dim=-1) - 180

    # ori_val
    # The sigmoid matches a model trained with BCE loss; for a model trained
    # with CE loss the raw azimuth logits would be passed instead.
    distribute = F.sigmoid(pose_enc[:, 0:360]).cpu().float().numpy()
    # NOTE(review): val_fit_alpha is defined elsewhere; it is assumed to return
    # one value per row of `distribute` in an array supporting reshape/indexing.
    alpha_pred = val_fit_alpha(distribute=distribute)

    # ref_val
    if S > 1:
        # Reshape each prediction once, then pick frame 0 (reference) and
        # frame 1 (target) for every sample in the batch.
        az_bs = angle_az_pred.reshape(B, S)
        el_bs = angle_el_pred.reshape(B, S)
        ro_bs = angle_ro_pred.reshape(B, S)

        ref_az_pred = az_bs[:, 0]
        ref_el_pred = el_bs[:, 0]
        ref_ro_pred = ro_bs[:, 0]
        ref_alpha_pred = alpha_pred.reshape(B, S)[:, 0]
        rel_az_pred = az_bs[:, 1]
        rel_el_pred = el_bs[:, 1]
        rel_ro_pred = ro_bs[:, 1]
    else:
        # Single-frame input: report the first row only, no relative pose.
        ref_az_pred = angle_az_pred[0]
        ref_el_pred = angle_el_pred[0]
        ref_ro_pred = angle_ro_pred[0]
        ref_alpha_pred = alpha_pred[0]
        rel_az_pred = 0.
        rel_el_pred = 0.
        rel_ro_pred = 0.

    ans_dict = {
        'ref_az_pred': ref_az_pred,
        'ref_el_pred': ref_el_pred,
        'ref_ro_pred': ref_ro_pred,
        'ref_alpha_pred': ref_alpha_pred,
        'rel_az_pred': rel_az_pred,
        'rel_el_pred': rel_el_pred,
        'rel_ro_pred': rel_ro_pred,
    }

    return ans_dict
86
+
87
# input PIL Image
@torch.no_grad()
def inf_single_case(image_ref, image_tgt):
    """Run inference for one reference image and an optional target image.

    Args:
        image_ref: reference image (a PIL image, per the comment above).
        image_tgt: target image, or ``None`` to run on the reference alone.

    Returns:
        The dict produced by ``inf_single_batch``; it is also printed.
    """
    # One frame when no target is given, otherwise reference then target.
    frames = [image_ref] if image_tgt is None else [image_ref, image_tgt]

    # Preprocess in "pad" mode, move to the GPU, then add a leading axis
    # before handing the tensor to inf_single_batch.
    batch = preprocess_images(frames, mode="pad").to('cuda').unsqueeze(0)

    result = inf_single_batch(batch=batch)
    print(result)
    return result
98
+
99
 
100
  # ====== 工具函数:安全图像处理 ======
101
  def safe_image_input(image):
inference.py CHANGED
@@ -9,7 +9,7 @@ from scipy.special import i0
9
  from scipy.optimize import curve_fit
10
  from scipy.integrate import trapezoid
11
  from functools import partial
12
- import spaces
13
 
14
  def von_mises_pdf_alpha_numpy(alpha, x, mu, kappa):
15
  normalization = 2 * np.pi
@@ -175,68 +175,3 @@ def preprocess_images(image_list, mode="crop"):
175
  images = images.unsqueeze(0)
176
 
177
  return images
178
-
179
- @torch.no_grad()
180
- def inf_single_batch(batch):
181
- global model
182
- device = model.get_device()
183
- batch_img_inputs = batch # (B, S, 3, H, W)
184
- # print(batch_img_inputs.shape)
185
- B, S, C, H, W = batch_img_inputs.shape
186
- pose_enc = model(batch_img_inputs) # (B, S, D) S = 1
187
-
188
- pose_enc = pose_enc.view(B*S, -1)
189
- angle_az_pred = torch.argmax(pose_enc[:, 0:360] , dim=-1)
190
- angle_el_pred = torch.argmax(pose_enc[:, 360:360+180] , dim=-1) - 90
191
- angle_ro_pred = torch.argmax(pose_enc[:, 360+180:360+180+360] , dim=-1) - 180
192
-
193
- # ori_val
194
- # trained with BCE loss
195
- distribute = F.sigmoid(pose_enc[:, 0:360]).cpu().float().numpy()
196
- # trained with CE loss
197
- # distribute = pose_enc[:, 0:360].cpu().float().numpy()
198
- alpha_pred = val_fit_alpha(distribute = distribute)
199
-
200
- # ref_val
201
- if S > 1:
202
- ref_az_pred = angle_az_pred.reshape(B,S)[:,0]
203
- ref_el_pred = angle_el_pred.reshape(B,S)[:,0]
204
- ref_ro_pred = angle_ro_pred.reshape(B,S)[:,0]
205
- ref_alpha_pred = alpha_pred.reshape(B,S)[:,0]
206
- rel_az_pred = angle_az_pred.reshape(B,S)[:,1]
207
- rel_el_pred = angle_el_pred.reshape(B,S)[:,1]
208
- rel_ro_pred = angle_ro_pred.reshape(B,S)[:,1]
209
- else:
210
- ref_az_pred = angle_az_pred[0]
211
- ref_el_pred = angle_el_pred[0]
212
- ref_ro_pred = angle_ro_pred[0]
213
- ref_alpha_pred = alpha_pred[0]
214
- rel_az_pred = 0.
215
- rel_el_pred = 0.
216
- rel_ro_pred = 0.
217
-
218
- ans_dict = {
219
- 'ref_az_pred': ref_az_pred,
220
- 'ref_el_pred': ref_el_pred,
221
- 'ref_ro_pred': ref_ro_pred,
222
- 'ref_alpha_pred' : ref_alpha_pred,
223
- 'rel_az_pred' : rel_az_pred,
224
- 'rel_el_pred' : rel_el_pred,
225
- 'rel_ro_pred' : rel_ro_pred,
226
- }
227
-
228
- return ans_dict
229
-
230
- # input PIL Image
231
- @spaces.GPU
232
- @torch.no_grad()
233
- def inf_single_case(image_ref, image_tgt):
234
- global model
235
- if image_tgt is None:
236
- image_list = [image_ref]
237
- else:
238
- image_list = [image_ref, image_tgt]
239
- image_tensors = preprocess_images(image_list, mode="pad").to('cuda')
240
- ans_dict = inf_single_batch(batch=image_tensors.unsqueeze(0))
241
- print(ans_dict)
242
- return ans_dict
 
9
  from scipy.optimize import curve_fit
10
  from scipy.integrate import trapezoid
11
  from functools import partial
12
+
13
 
14
  def von_mises_pdf_alpha_numpy(alpha, x, mu, kappa):
15
  normalization = 2 * np.pi
 
175
  images = images.unsqueeze(0)
176
 
177
  return images