import argparse
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDPMScheduler, DDIMScheduler
from omegaconf import OmegaConf
from safetensors.torch import load_file
from tqdm import tqdm

from dataset.traj_dataset import TrajDataset
from eval import create_model
from model.mdm import MDM
from model.mdm_dit import MDM_DiT
from model.spacetime import MDM_ST
from options import TrainingConfig, TestingConfig
from pipeline_traj import TrajPipeline
from utils.visualization import save_pointcloud_video, save_pointcloud_json, save_threejs_html
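
# This script loads a trained trajectory diffusion model and estimates material
# parameters (Young's modulus E and Poisson ratio nu) from observed point-cloud
# motion, either by grid-searching E ("probe" mode) or by optimizing E directly
# through the denoising loss.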


def fibonacci_sphere(n):
    """Return n near-uniformly distributed points on the unit sphere, shape (n, 3)."""
    i = torch.arange(n, dtype=torch.float32)
    phi = 2 * torch.pi * i / ((1 + 5**0.5) / 2)  # golden-angle increments in azimuth
    z = 1 - 2 * (i + 0.5) / n                    # uniform in (-1, 1)
    r_xy = (1 - z**2).sqrt()
    x = r_xy * torch.cos(phi)
    y = r_xy * torch.sin(phi)
    return torch.stack((x, y, z), dim=1)
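
# Illustrative usage (fibonacci_sphere is not called elsewhere in this script):
#   pts = fibonacci_sphere(1024)                 # (1024, 3) points on the unit sphere
#   assert torch.allclose(pts.norm(dim=1), torch.ones(1024), atol=1e-5)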


class Inferrer:
    def __init__(self, args, device='cuda'):
        self.args = args
        self.device = device
        self.model = create_model(args).to(device)
        ckpt = load_file(args.resume, device='cpu')
        self.model.load_state_dict(ckpt, strict=False)
        self.model.eval().requires_grad_(False).to(device)
        self.scheduler = DDIMScheduler(num_train_timesteps=1000, prediction_type='sample', clip_sample=False)
        self.pipeline = TrajPipeline(model=self.model, scheduler=self.scheduler)
    def probe_params(self, init_pc, force, motion_obs, mask, drag_point, floor_height, coeff, y, vis_dir=None, fname=None):
        """Sweep E over a coarse grid at a fixed nu and score each candidate by MSE."""
        out = []
        n = 0.36  # fixed Poisson ratio while sweeping E; sweep n as well if desired
        for e in torch.arange(4.0, 7.1, 0.5):
            E = torch.tensor([e], device=self.device).reshape(1, 1)
            nu = torch.tensor([n], device=self.device).reshape(1, 1)
            motion_pred = self.pipeline(
                init_pc, force, E, nu, mask, drag_point, floor_height,
                gravity=None, coeff=coeff, y=y, device=self.device, batch_size=1,
                generator=torch.Generator().manual_seed(self.args.seed),
                n_frames=self.args.train_dataset.n_training_frames, num_inference_steps=25)
            loss = F.mse_loss(motion_pred, motion_obs.to(self.device))
            out.append([loss.item(), e.item(), n])
        out = np.array(out)
        print("Best E, nu: ", out[np.argmin(out[:, 0])])
        plt.plot(out[:, 1], out[:, 0], marker='o', linestyle='-', linewidth=2)
        plt.xlabel('E')
        plt.ylabel('Loss')
        plt.savefig(os.path.join(vis_dir, f'{fname}.png'))
        plt.close()
        return out
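
    # forward_model performs one denoising forward pass with optional classifier-free
    # guidance: inputs are duplicated into an (uncond, cond) batch, flagged via
    # null_emb, and the two predictions are blended after the call.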
    def forward_model(self, motion_noisy, t, init_pc, force, E, nu, mask, guidance_scale=1.0):
        bsz = motion_noisy.shape[0]
        null_emb = torch.tensor([1] * bsz).to(motion_noisy.dtype)
        if guidance_scale > 1.0:
            # Duplicate inputs so the unconditional and conditional passes share one forward call;
            # the first half (flag 0) is split out as the unconditional prediction below.
            motion_noisy = torch.cat([motion_noisy] * 2)
            init_pc = torch.cat([init_pc] * 2)
            force = torch.cat([force] * 2)
            E = torch.cat([E] * 2)
            nu = torch.cat([nu] * 2)
            t = torch.cat([t] * 2)
            mask = torch.cat([mask] * 2)
            null_emb = torch.cat([torch.tensor([0] * bsz).to(motion_noisy.dtype), null_emb])
        null_emb = null_emb[:, None, None].to(self.device, dtype=motion_noisy.dtype)
        # null_emb is passed as a keyword, matching the call in inference_model below
        model_output = self.model(motion_noisy, t, init_pc, force, E, nu, mask, null_emb=null_emb)
        if guidance_scale > 1.0:
            model_pred_uncond, model_pred_cond = model_output.chunk(2)
            model_output = model_pred_uncond + guidance_scale * (model_pred_cond - model_pred_uncond)
        return model_output
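
    # inference_model runs the full sampling loop: start from Gaussian noise and apply
    # num_inference_steps DDIM updates, each using one (optionally guided) model prediction.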
    def inference_model(self, init_pc, force, E, nu, mask, drag_point, floor_height, coeff,
                        generator,
                        device,
                        batch_size: int = 1,
                        num_inference_steps: int = 50,
                        guidance_scale=1.0,
                        n_frames=20,
                        y=None):
        # Sample Gaussian noise to begin the loop
        sample = torch.randn((batch_size, n_frames, init_pc.shape[2], 3), generator=generator).to(device)
        # Set step values
        self.scheduler.set_timesteps(num_inference_steps)
        do_classifier_free_guidance = (guidance_scale > 1.0)
        null_emb = torch.tensor([1] * batch_size).to(sample.dtype)
        if do_classifier_free_guidance:
            init_pc = torch.cat([init_pc] * 2)
            force = torch.cat([force] * 2)
            E = torch.cat([E] * 2)
            nu = torch.cat([nu] * 2)
            mask = torch.cat([mask] * 2)
            drag_point = torch.cat([drag_point] * 2)
            floor_height = torch.cat([floor_height] * 2)
            null_emb = torch.cat([torch.tensor([0] * batch_size).to(sample.dtype), null_emb])
        null_emb = null_emb[:, None, None].to(device)
        for t in self.scheduler.timesteps:
            t = torch.tensor([t] * batch_size, device=device)
            sample_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
            t = torch.cat([t] * 2) if do_classifier_free_guidance else t
            # 1. Predict the denoised sample (the scheduler uses prediction_type='sample')
            model_output = self.model(sample_input, t, init_pc, force, E, nu, mask, drag_point, floor_height, coeff, y=y, null_emb=null_emb)
            if do_classifier_free_guidance:
                model_pred_uncond, model_pred_cond = model_output.chunk(2)
                model_output = model_pred_uncond + guidance_scale * (model_pred_cond - model_pred_uncond)
            # 2. Take one DDIM step back toward the clean sample
            sample = self.scheduler.step(model_output, t[0], sample).prev_sample
        return sample
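
    # estimate_params solves the inverse problem. "Probe" mode grid-searches E to find
    # an initialization; afterwards E is a learnable nn.Parameter optimized with Adam
    # against the denoising loss (the model predicts the clean sample, so MSE against
    # motion_obs is exactly the training objective).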
    def estimate_params(self, model_name, motion_obs, init_pc, force, nu, mask, drag_point, floor_height, coeff, y, cfg=1.0, gravity=None, probe=False, num_steps=400):
        device = 'cuda'
        all_loss = []
        if probe:
            # Coarse grid search over E to initialize the optimization
            out = []
            for e in torch.arange(4.0, 7.1, 0.5):
                E = torch.tensor([e], device=self.device).reshape(1, 1)
                motion_pred = self.pipeline(
                    init_pc, force, E, nu, mask, drag_point, floor_height,
                    gravity=gravity, coeff=coeff, y=y, device=self.device, batch_size=1,
                    generator=torch.Generator().manual_seed(self.args.seed),
                    n_frames=self.args.train_dataset.n_training_frames, num_inference_steps=25)
                loss = F.mse_loss(motion_pred, motion_obs.to(self.device))
                out.append([loss.item(), E.item()])
            out = torch.tensor(out)
            print("Best E, nu: ", out[torch.argmin(out[:, 0])])
            E = nn.Parameter(out[torch.argmin(out[:, 0]), 1].reshape(1, 1).to(device))
            all_loss.append(out[torch.argmin(out[:, 0])])
        else:
            E = nn.Parameter(torch.tensor([4.5], device=device).reshape(1, 1))
        # nu, force, and drag_point could be optimized the same way; only E is estimated here
        optimizer = torch.optim.Adam([
            {'params': [E], 'lr': 1e-2},  # E is clamped to [4.0, 7.0] manually after each step
        ])
        self.model.requires_grad_(True)
        progress_bar = tqdm(total=num_steps)
        progress_bar.set_description("Optimizing")
        Es = []
        for step in range(num_steps):
            optimizer.zero_grad()
            # Denoising objective: noise the observation, predict the clean sample, compare
            noise = torch.randn_like(motion_obs, device=device)
            t = torch.randint(0, self.scheduler.config.num_train_timesteps, (motion_obs.shape[0],), device=device)
            motion_noisy = self.scheduler.add_noise(motion_obs, noise, t)
            model_output = self.model(motion_noisy, t, init_pc, force, E, nu, mask, drag_point, floor_height, gravity, coeff, y=y)
            loss = F.mse_loss(model_output, motion_obs)
            progress_bar.update(1)
            progress_bar.set_postfix({'loss': loss.item(), 'E': E.item(), 'nu': nu.item()})
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                E.clamp_(4.0, 7.0)
            if (step + 1) % 200 == 0:
                # Periodically evaluate the current E with a full sampling pass
                Es.append(E.item())
                motion_pred = self.pipeline(
                    init_pc, force.detach(), E.detach(), nu.detach(), mask, drag_point.detach(), floor_height,
                    gravity=gravity, coeff=coeff, y=y, device=self.device, batch_size=1,
                    generator=torch.Generator().manual_seed(self.args.seed),
                    n_frames=self.args.train_dataset.n_training_frames, num_inference_steps=25)
                loss = F.mse_loss(motion_pred, motion_obs)
                all_loss.append(torch.tensor([loss.item(), E.item()]))
        out = torch.stack(all_loss)
        print(out)
        # Keep the E with the lowest full-sampling loss
        E = out[torch.argmin(out[:, 0]), 1].to(device)
        Es.append(E.item())
        save_pointcloud_video(motion_pred.squeeze().cpu().numpy(), motion_obs.squeeze().cpu().numpy(),
                              os.path.join('./debug/v3', f'{model_name}_{E.item():03f}_{nu.item():02f}.gif'),
                              drag_mask=mask[:1, 0, :, 0].cpu().numpy().squeeze(), vis_flag='objaverse')
        return Es, nu, force, drag_point


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    args = parser.parse_args()
    schema = OmegaConf.structured(TestingConfig)
    cfg = OmegaConf.load(args.config)
    args = OmegaConf.merge(schema, cfg)
    val_dataset = TrajDataset('val', args.train_dataset)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.eval_batch_size,
                                                 shuffle=False, num_workers=args.dataloader_num_workers)
    inferrer = Inferrer(args)
    probe = True
    num_steps = 400
    for i, (batch, _) in enumerate(val_dataloader):
        device = torch.device('cuda')
        with torch.autocast("cuda", dtype=torch.bfloat16):
            model_name = batch['model'][0]
            motion_obs = batch['points_tgt'].to(device)
            init_pc = batch['points_src'].to(device)
            force = batch['force'].to(device)
            E = batch['E'].to(device)
            nu = batch['nu'].to(device)
            mask = batch['mask'][..., :1].to(device, dtype=force.dtype)
            drag_point = batch['drag_point'].to(device)
            floor_height = batch['floor_height'].to(device)
            coeff = batch['base_drag_coeff']
            y = batch['mat_type'].to(device) if 'mat_type' in batch else None
            gravity = batch['gravity'].to(device) if 'gravity' in batch else None
            print(model_name, floor_height)
            print('GT', E, nu, drag_point, force, y)
            est_E, est_nu, est_f, est_d = inferrer.estimate_params(
                model_name, motion_obs.to(device), init_pc, force, nu, mask, drag_point,
                floor_height, coeff, y=y, cfg=1.0, gravity=gravity, probe=probe, num_steps=num_steps)
            est_E = ','.join([f'{e:.3f}' for e in est_E]) if isinstance(est_E, list) else est_E.item()
            print(est_E)
            with open(os.path.join('./debug', f'output_probe{probe}_steps{num_steps}.txt'), 'a+') as f:
                f.write(f'{model_name},{E.item()},{est_E},{nu.item()},{est_nu.item()},{drag_point.cpu().numpy()},'
                        f'{est_d.cpu().numpy()},{force.cpu().numpy()},{est_f.cpu().numpy()}\n')