YiYiXu HF Staff commited on
Commit
0b28f24
·
verified ·
1 Parent(s): 747669c

Update block.py

Browse files
Files changed (1) hide show
  1. block.py +54 -2
block.py CHANGED
@@ -1,12 +1,13 @@
1
  from typing import List
2
  import torch
3
 
4
- from diffusers.pipelines.modular_pipeline import PipelineState, PipelineBlock
5
  from diffusers.pipelines.modular_pipeline_utils import (
6
  InputParam,
7
  ComponentSpec,
8
  OutputParam,
9
  )
 
10
  from diffusers.image_processor import PipelineImageInput
11
  from image_gen_aux import DepthPreprocessor
12
 
@@ -28,11 +29,20 @@ class DepthProcessorBlock(PipelineBlock):
28
  InputParam(
29
  "image",
30
  PipelineImageInput,
31
- required=True,
32
  description="Image(s) to use to extract depth maps",
33
  )
34
  ]
35
 
 
 
 
 
 
 
 
 
 
 
36
  @property
37
  def intermediates_outputs(self) -> List[OutputParam]:
38
  return [
@@ -53,3 +63,45 @@ class DepthProcessorBlock(PipelineBlock):
53
 
54
  self.add_block_state(state, block_state)
55
  return pipeline, state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import List
2
  import torch
3
 
4
+ from diffusers.pipelines.modular_pipeline import PipelineState, PipelineBlock, SequentialPipelineBlocks, AutoPipelineBlocks
5
  from diffusers.pipelines.modular_pipeline_utils import (
6
  InputParam,
7
  ComponentSpec,
8
  OutputParam,
9
  )
10
+ from diffusers.utils import load_image
11
  from diffusers.image_processor import PipelineImageInput
12
  from image_gen_aux import DepthPreprocessor
13
 
 
29
  InputParam(
30
  "image",
31
  PipelineImageInput,
 
32
  description="Image(s) to use to extract depth maps",
33
  )
34
  ]
35
 
36
+ @property
37
+ def intermediates_inputs(self) -> List[InputParam]:
38
+ return [
39
+ InputParam(
40
+ "image",
41
+ PipelineImageInput,
42
+ description="Image(s) to use to extract depth maps, can be output from LoadURL block",
43
+ )
44
+ ]
45
+
46
  @property
47
  def intermediates_outputs(self) -> List[OutputParam]:
48
  return [
 
63
 
64
  self.add_block_state(state, block_state)
65
  return pipeline, state
66
+
67
+ class LoadURL(PipelineBlock):
68
+
69
+ @property
70
+ def inputs(self) -> List[InputParam]:
71
+ return [
72
+ InputParam(
73
+ "url",
74
+ str,
75
+ )
76
+ ]
77
+
78
+ @property
79
+ def intermediates_outputs(self) -> List[OutputParam]:
80
+ return [
81
+ OutputParam(
82
+ "image",
83
+ type_hint=PipelineImageInput,
84
+ description="Image(s) to use to extract depth maps",
85
+ ),
86
+ ]
87
+
88
+ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
89
+ block_state = self.get_block_state(state)
90
+ block_state.image = load_image(blck_state.url)
91
+ self.add_block_state(state, block_state)
92
+ return pipeline, state
93
+
94
+ class AutoLoadURL(AutoPipelineBlocks):
95
+ block_classes = [LoadURL]
96
+ block_name = ["url_to_image"]
97
+ block_trigger_inputs = ["url"]
98
+
99
+ @property
100
+ def description(self):
101
+ return "Run if `url` is provided."
102
+
103
+ class DepthInput(SequentialPipelineBlocks)
104
+ block_classes = [AutoLoadURL, DepthProcessorBlock]
105
+ block_names = ["load_url", "depth_processor"]
106
+
107
+