Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from diffusers import DiffusionPipeline | |
| class TrajPipeline(DiffusionPipeline): | |
| def __init__(self, model, scheduler): | |
| super().__init__() | |
| self.register_modules(model=model, scheduler=scheduler) | |
| def __call__(self, init_pc, force, E, nu, mask, drag_point, floor_height, gravity, coeff, | |
| generator, | |
| device, | |
| y = None, | |
| batch_size: int = 1, | |
| num_inference_steps: int = 50, | |
| guidance_scale=1.0, | |
| n_frames=20 | |
| ): | |
| # Sample gaussian noise to begin loop | |
| sample = torch.randn((batch_size, n_frames, init_pc.shape[2], 3), generator=generator).to(device) | |
| self.model.to(device) | |
| init_pc = init_pc.to(device) | |
| force = force.to(device) | |
| E = E.to(device) | |
| nu = nu.to(device) | |
| mask = mask.to(device).to(dtype=sample.dtype) | |
| drag_point = drag_point.to(device) | |
| floor_height = floor_height.to(device) | |
| coeff = coeff.to(device) | |
| gravity = gravity.to(device) if gravity is not None else None | |
| y = y.to(device) if y is not None else None | |
| # set step values | |
| self.scheduler.set_timesteps(num_inference_steps, device=device) | |
| do_classifier_free_guidance = (guidance_scale > 1.0) | |
| null_emb = torch.tensor([1] * batch_size).to(sample.dtype) | |
| if do_classifier_free_guidance: | |
| init_pc = torch.cat([init_pc] * 2) | |
| force = torch.cat([force] * 2) | |
| E = torch.cat([E] * 2) | |
| nu = torch.cat([nu] * 2) | |
| mask = torch.cat([mask] * 2) | |
| drag_point = torch.cat([drag_point] * 2) | |
| floor_height = torch.cat([floor_height] * 2) | |
| null_emb = torch.cat([torch.tensor([0] * batch_size).to(sample.dtype), null_emb]) | |
| null_emb = null_emb[:, None, None].to(device) | |
| for t in self.progress_bar(self.scheduler.timesteps): | |
| t = torch.tensor([t] * batch_size, device=device) | |
| sample_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample | |
| t = torch.cat([t] * 2) if do_classifier_free_guidance else t | |
| # 1. predict noise model_output | |
| model_output = self.model(sample_input, t, init_pc, force, E, nu, mask, drag_point, floor_height=floor_height, gravity_label=gravity, coeff=coeff, y=y, null_emb=null_emb) | |
| if do_classifier_free_guidance: | |
| model_pred_uncond, model_pred_cond = model_output.chunk(2) | |
| model_output = model_pred_uncond + guidance_scale * (model_pred_cond - model_pred_uncond) | |
| # 2. predict previous mean of image x_t-1 and add variance depending on eta | |
| # eta corresponds to η in paper and should be between [0, 1] | |
| # do x_t -> x_t-1 | |
| sample = self.scheduler.step(model_output, t[0], sample).prev_sample | |
| return sample |