2toINF committed on
Commit
6d302b7
·
verified ·
1 Parent(s): 43de723

Update modeling_xvla.py

Browse files
Files changed (1) hide show
  1. modeling_xvla.py +3 -4
modeling_xvla.py CHANGED
@@ -196,15 +196,14 @@ class XVLA(PreTrainedModel):
196
  enc = self.forward_vlm(input_ids, image_input, image_mask)
197
 
198
  B = input_ids.shape[0]
199
- device = input_ids.device
200
  D = self.action_space.dim_action
201
 
202
- x1 = torch.randn(B, self.num_actions, D, device=device)
203
  action = torch.zeros_like(x1)
204
 
205
  steps = max(1, int(steps))
206
  for i in range(steps, 0, -1):
207
- t = torch.full((B,), i / steps, device=device)
208
  x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
209
  proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
210
  action = self.transformer(
@@ -277,7 +276,7 @@ class XVLA(PreTrainedModel):
277
 
278
  # Inference
279
  steps = int(payload.get("steps", 10))
280
- action = self.generate_actions(**inputs, steps=steps).squeeze(0).cpu().numpy()
281
  return JSONResponse({"action": action.tolist()})
282
 
283
  except Exception:
 
196
  enc = self.forward_vlm(input_ids, image_input, image_mask)
197
 
198
  B = input_ids.shape[0]
 
199
  D = self.action_space.dim_action
200
 
201
+ x1 = torch.randn(B, self.num_actions, D, device=proprio.device, dtype=proprio.dtype)
202
  action = torch.zeros_like(x1)
203
 
204
  steps = max(1, int(steps))
205
  for i in range(steps, 0, -1):
206
+ t = torch.full((B,), i / steps, device=proprio.device, dtype=proprio.dtype)
207
  x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
208
  proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
209
  action = self.transformer(
 
276
 
277
  # Inference
278
  steps = int(payload.get("steps", 10))
279
+ action = self.generate_actions(**inputs, steps=steps).squeeze(0).float().cpu().numpy()
280
  return JSONResponse({"action": action.tolist()})
281
 
282
  except Exception: