from typing import List, Optional, Tuple, Union import torch from diffusers import DiffusionPipeline, DDPMScheduler from diffusers.utils import BaseOutput from src.pipelines.point_navigation.components.model import TrajectoryDiffusionModel class TrajectoryPipelineOutput(BaseOutput): trajectories: torch.Tensor class PointNavigationPipeline(DiffusionPipeline): model: TrajectoryDiffusionModel scheduler: DDPMScheduler def __init__(self, model: TrajectoryDiffusionModel, scheduler: DDPMScheduler): super().__init__() self.register_modules(model=model, scheduler=scheduler) @torch.no_grad() def __call__( self, start: Union[List[float], Tuple[float, float]], target: Union[List[float], Tuple[float, float]], batch_size: int = 1, num_inference_steps: int = 1000, generator: Optional[torch.Generator] = None, ) -> TrajectoryPipelineOutput: device = self.device observation = torch.tensor( [[start[0], start[1], target[0], target[1]]] * batch_size, device=device, dtype=torch.float32 ) trajectory = torch.randn( (batch_size, 32, 2), device=device, generator=generator, ) self.scheduler.set_timesteps(num_inference_steps) for t in self.scheduler.timesteps: noise_pred = self.model( trajectory, torch.tensor([t] * batch_size, device=device), observation, ) trajectory = self.scheduler.step(noise_pred, t, trajectory).prev_sample return TrajectoryPipelineOutput(trajectories=trajectory)