hvent90 commited on
Commit
85436ff
·
verified ·
1 Parent(s): 588f817

Upload pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +55 -0
pipeline.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from diffusers import DiffusionPipeline, DDPMScheduler
5
+ from diffusers.utils import BaseOutput
6
+
7
+ from src.pipelines.point_navigation.components.model import TrajectoryDiffusionModel
8
+
9
+
10
+ class TrajectoryPipelineOutput(BaseOutput):
11
+ trajectories: torch.Tensor
12
+
13
+
14
+ class PointNavigationPipeline(DiffusionPipeline):
15
+ model: TrajectoryDiffusionModel
16
+ scheduler: DDPMScheduler
17
+
18
+ def __init__(self, model: TrajectoryDiffusionModel, scheduler: DDPMScheduler):
19
+ super().__init__()
20
+ self.register_modules(model=model, scheduler=scheduler)
21
+
22
+ @torch.no_grad()
23
+ def __call__(
24
+ self,
25
+ start: Union[List[float], Tuple[float, float]],
26
+ target: Union[List[float], Tuple[float, float]],
27
+ batch_size: int = 1,
28
+ num_inference_steps: int = 1000,
29
+ generator: Optional[torch.Generator] = None,
30
+ ) -> TrajectoryPipelineOutput:
31
+ device = self.device
32
+
33
+ observation = torch.tensor(
34
+ [[start[0], start[1], target[0], target[1]]] * batch_size,
35
+ device=device,
36
+ dtype=torch.float32
37
+ )
38
+
39
+ trajectory = torch.randn(
40
+ (batch_size, 32, 2),
41
+ device=device,
42
+ generator=generator,
43
+ )
44
+
45
+ self.scheduler.set_timesteps(num_inference_steps)
46
+
47
+ for t in self.scheduler.timesteps:
48
+ noise_pred = self.model(
49
+ trajectory,
50
+ torch.tensor([t] * batch_size, device=device),
51
+ observation,
52
+ )
53
+ trajectory = self.scheduler.step(noise_pred, t, trajectory).prev_sample
54
+
55
+ return TrajectoryPipelineOutput(trajectories=trajectory)