File size: 3,123 Bytes
d9221ed 206f874 d9221ed 0b28f24 747669c d9221ed f580baa d9221ed 0b28f24 d9221ed 80dc779 d9221ed 462eb62 80dc779 d9221ed 0b28f24 66b6384 0b28f24 2029d30 0b28f24 2a2b8e1 0b28f24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
from typing import Any, List, Tuple

import torch
from diffusers.image_processor import PipelineImageInput
from diffusers.modular_pipelines import (
    AutoPipelineBlocks,
    PipelineBlock,
    PipelineState,
    SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor


class DepthProcessorBlock(PipelineBlock):
    """Pipeline block that replaces ``image`` with its depth map.

    The block reads ``image`` from the block state, runs the pipeline's
    ``depth_processor`` component on it, moves the result to the pipeline's
    execution device and stores it back under the same ``image`` name.
    """

    @property
    def description(self) -> str:
        """Short human-readable summary of what this block does."""
        return "Extract depth map(s) from `image` using the `depth_processor` component."

    @property
    def expected_components(self):
        """Declare the single ``depth_processor`` component this block uses."""
        return [
            ComponentSpec(
                name="depth_processor",
                type_hint=DepthPreprocessor,
                subfolder="",
                repo="depth-anything/Depth-Anything-V2-Large-hf",
            )
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Declare ``image`` as a user-facing input."""
        return [
            InputParam(
                "image",
                PipelineImageInput,
                description="Image(s) to use to extract depth maps",
            )
        ]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        """Declare ``image`` as an intermediate input as well.

        ``image`` is intentionally listed both here and in ``inputs`` so the
        block can consume an image produced by an earlier block (see the
        parameter description) in addition to one supplied by the caller.
        """
        return [
            InputParam(
                "image",
                PipelineImageInput,
                description="Image(s) to use to extract depth maps, can be output from LoadURL block",
            )
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        """Declare the depth map tensor written back under ``image``."""
        return [
            OutputParam(
                "image",
                type_hint=torch.Tensor,
                description="Depth Map(s) of input Image(s)",
            ),
        ]

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Compute the depth map for ``image`` and store it in ``state``.

        Args:
            pipeline: Object providing the ``depth_processor`` component and
                the ``_execution_device`` attribute.
            state: Pipeline state holding this block's inputs and outputs.

        Returns:
            The ``(pipeline, state)`` pair, with ``image`` in ``state``
            replaced by the depth map.

        Raises:
            ValueError: If no ``image`` is available in the block state.
        """
        block_state = self.get_block_state(state)
        device = pipeline._execution_device

        image = block_state.image
        if image is None:
            # Fail early with an actionable message instead of letting the
            # depth processor choke on ``None``.
            raise ValueError(
                "`image` must be provided, either directly or by an earlier block such as `LoadURL`."
            )

        # return_type="pt" presumably requests a torch tensor, consistent with
        # the declared torch.Tensor output and the ``.to(device)`` call below.
        depth_map = pipeline.depth_processor(image, return_type="pt")
        block_state.image = depth_map.to(device)

        self.add_block_state(state, block_state)
        return pipeline, state
class LoadURL(PipelineBlock):
    """Pipeline block that loads the image referenced by ``url``.

    The loaded image is written to the block state under ``image``.
    """

    @property
    def description(self) -> str:
        """Short human-readable summary of what this block does."""
        return "Load the image located at `url` and expose it as `image`."

    @property
    def inputs(self) -> List[InputParam]:
        """Declare ``url`` as the user-facing input of this block."""
        return [
            InputParam(
                "url",
                str,
                description="URL of the image to load",
            )
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        """Declare the loaded ``image`` produced by this block."""
        return [
            OutputParam(
                "image",
                type_hint=PipelineImageInput,
                description="Image(s) to use to extract depth maps",
            ),
        ]

    def __call__(self, pipeline, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Load the image at ``url`` and store it in ``state`` as ``image``.

        Args:
            pipeline: Passed through unchanged; not used by this block.
            state: Pipeline state holding this block's inputs and outputs.

        Returns:
            The ``(pipeline, state)`` pair, with ``image`` set in ``state``.
        """
        block_state = self.get_block_state(state)
        block_state.image = load_image(block_state.url)
        self.add_block_state(state, block_state)
        return pipeline, state
class AutoLoadURL(AutoPipelineBlocks):
    """Auto block exposing ``LoadURL`` as ``url_to_image``, keyed on the ``url`` input."""

    # Each list has one entry, all referring to the same sub-block:
    # its trigger input, its name and its class.
    block_trigger_inputs = ["url"]
    block_names = ["url_to_image"]
    block_classes = [LoadURL]

    @property
    def description(self) -> str:
        """Return the one-line summary shown for this block."""
        summary = "Run if `url` is provided."
        return summary
class DepthInput(SequentialPipelineBlocks):
    """Sequential block made of ``AutoLoadURL`` followed by ``DepthProcessorBlock``.

    The sub-blocks are registered under the names ``load_url`` and
    ``depth_processor`` respectively.
    """

    # Names and classes are listed in the same order; entry i of one list
    # corresponds to entry i of the other.
    block_names = ["load_url", "depth_processor"]
    block_classes = [AutoLoadURL, DepthProcessorBlock]
|