Upload pipeline.py with huggingface_hub
Browse files- pipeline.py +55 -0
pipeline.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from diffusers import DiffusionPipeline, DDPMScheduler
|
| 5 |
+
from diffusers.utils import BaseOutput
|
| 6 |
+
|
| 7 |
+
from src.pipelines.point_navigation.components.model import TrajectoryDiffusionModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TrajectoryPipelineOutput(BaseOutput):
    """Output container returned by ``PointNavigationPipeline.__call__``.

    Subclassing diffusers' ``BaseOutput`` gives dict-style and
    attribute-style access to the fields below.
    """

    # Denoised trajectories; produced by the pipeline as shape
    # (batch_size, num_waypoints, 2) — x/y coordinates per waypoint.
    trajectories: torch.Tensor
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class PointNavigationPipeline(DiffusionPipeline):
    """Diffusion-based point-navigation trajectory sampler.

    Starting from Gaussian noise, iteratively denoises a batch of 2-D
    trajectories with a trained ``TrajectoryDiffusionModel``, conditioned
    on a start and a target position, using a ``DDPMScheduler``.
    """

    model: TrajectoryDiffusionModel
    scheduler: DDPMScheduler

    def __init__(self, model: TrajectoryDiffusionModel, scheduler: DDPMScheduler):
        super().__init__()
        # register_modules hooks the components into DiffusionPipeline's
        # save/load machinery (save_pretrained / from_pretrained).
        self.register_modules(model=model, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        start: Union[List[float], Tuple[float, float]],
        target: Union[List[float], Tuple[float, float]],
        batch_size: int = 1,
        num_inference_steps: int = 1000,
        generator: Optional[torch.Generator] = None,
        horizon: int = 32,
    ) -> TrajectoryPipelineOutput:
        """Sample ``batch_size`` denoised trajectories from start to target.

        Args:
            start: (x, y) start position.
            target: (x, y) target position.
            batch_size: Number of trajectories sampled in parallel.
            num_inference_steps: Number of scheduler denoising steps.
            generator: Optional RNG for reproducible sampling.
            horizon: Number of waypoints per trajectory. Defaults to 32,
                which was previously hard-coded, so existing callers are
                unaffected.

        Returns:
            ``TrajectoryPipelineOutput`` whose ``trajectories`` field has
            shape ``(batch_size, horizon, 2)``.
        """
        device = self.device

        # Conditioning vector [start_x, start_y, target_x, target_y],
        # replicated across the batch.
        observation = torch.tensor(
            [[start[0], start[1], target[0], target[1]]] * batch_size,
            device=device,
            dtype=torch.float32,
        )

        # Begin from pure Gaussian noise of shape (batch, horizon, 2).
        trajectory = torch.randn(
            (batch_size, horizon, 2),
            device=device,
            generator=generator,
        )

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            # scheduler.timesteps yields 0-dim tensors; torch.tensor([t] * n)
            # on a list of tensors is deprecated (and an error in recent
            # PyTorch) and loses the integer dtype. torch.full broadcasts
            # the scalar timestep to the batch explicitly.
            timesteps = torch.full(
                (batch_size,), int(t), device=device, dtype=torch.long
            )
            noise_pred = self.model(trajectory, timesteps, observation)
            # step() returns the sample denoised by one timestep.
            trajectory = self.scheduler.step(noise_pred, t, trajectory).prev_sample

        return TrajectoryPipelineOutput(trajectories=trajectory)
|