zimhe committed
Commit a521a3f · 1 Parent(s): 5790b69

add scripts and examples
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
examples/examples.json ADDED
@@ -0,0 +1,65 @@
+ {
+     "Mountain Fall": {
+         "img": "https://huggingface.co/zimhe/SpatialDiffusion/resolve/main/examples/mount-fall.jpg",
+         "global": "large wide-angle landscape view of rocky mountains in fall, a lake and forest at the mountain foot, blue sky at dusk with scattered golden clouds",
+         "front": "view of towering mountains and rivers, with a dim, clear, foggy sky",
+         "back": "low mountains, scattered and layered, bright golden cloudy sky",
+         "left": "layered mountain peaks with thin cyan, green, and creamy colors, bright cloudy sky",
+         "right": "green water blending with the sky, bright cloudy sky, soft light",
+         "top": "bright beige sky color, thick scattered clouds, faded to background",
+         "bottom": "earthy ground with dry grass and scattered stones"
+     },
+     "Qianli": {
+         "img": "https://huggingface.co/zimhe/SpatialDiffusion/resolve/main/examples/qianli.png",
+         "global": "Chinese traditional landscape painting, aged beige color scheme, thin and subtle warm colors, mountains with warm-colored bases and peaks tinted thin cyan and green, bright cloudy sky with soft light, highly detailed, masterpiece.",
+         "front": "view of towering mountains and rivers, with a dim, clear, foggy sky",
+         "back": "low mountains, scattered and layered, bright golden cloudy sky",
+         "left": "layered mountain peaks with thin cyan, green, and creamy colors, bright cloudy sky",
+         "right": "green water blending with the sky, bright cloudy sky, soft light",
+         "top": "bright beige sky color, thick scattered clouds, faded to background",
+         "bottom": "earthy ground with dry grass and scattered stones"
+     },
+     "Office": {
+         "img": "examples/office.png",
+         "global": "A modernist open office interior by Louis Kahn, low view angle, simple clean structures, spacious and expansive, natural light, concrete and warm wood materials, view of the city, high quality, 8k resolution, photorealistic, ultra-detailed, masterpiece.",
+         "front": "office interior with delicate workplace desks",
+         "back": "office area with many desks and chairs",
+         "left": "large glass partition with dark metal frames",
+         "right": "public office area with beautiful curvy structures",
+         "top": "clean white ceiling, linear continuous lights",
+         "bottom": "clean empty floor viewed from above, clean gray gloss material"
+     },
+     "Forest": {
+         "img": "https://huggingface.co/zimhe/SpatialDiffusion/resolve/main/examples/forest.jpg",
+         "global": "within a forest with beautiful weaving trees, a tiger inside the bush, small flowing water, very thick tree trunks.",
+         "front": "view of a thriving forest, with branches covering the sky",
+         "back": "a small pathway leading to a small stream, with trees on both sides",
+         "left": "short bushes and trees",
+         "right": "stone-made pathways, with green leaves and branches, bushes surrounding the bottom",
+         "top": "crossing branches and leaves, with a few rays of light peeking through",
+         "bottom": "earthy ground with grass, scattered stones and twigs, and a few leaves on the ground"
+     },
+     "Las-Meninas": {
+         "img": "https://huggingface.co/zimhe/SpatialDiffusion/resolve/main/examples/Las-meninas.png",
+         "global": "A 17th-century Spanish royal studio interior with baroque decor, soft natural light, a painter at work, a young princess in a white dress, maids of honor, and a calm dog. Classic, elegant, and atmospheric.",
+         "front": "The princess in the center with her maids, the painter beside a large canvas, and a dog lying on the floor. Bright, central focus.",
+         "back": "Back of the studio with walls full of paintings and dim shadows. Quiet and deep space.",
+         "left": "Side view of the large canvas and the painter, with soft light from the opposite side.",
+         "right": "Ladies-in-waiting from the side, detailed dresses, dog in profile, and portraits on the dark wall.",
+         "top": "Ceiling with wooden beams and chandeliers, lit softly.",
+         "bottom": "Floor with shadows, the princess's gown, and the resting dog."
+     },
+     "Van-Gogh": {
+         "img": "https://huggingface.co/zimhe/SpatialDiffusion/resolve/main/examples/vangoh.png",
+         "global": "A cozy bedroom in the style of Vincent van Gogh. The scene features rustic wooden furniture, a bed with a red pillow and yellow blanket",
+         "front": "The Bedroom in Arles by Vincent van Gogh",
+         "back": "wall with a door and window, with a view of the outside",
+         "left": "wooden chairs with yellow cushions, a small wooden table with a water pitcher and a mirror above it",
+         "right": "walls painted a vibrant blue, with paintings and decorations hanging slightly tilted",
+         "top": "a white ceiling with an old light fixture, a window with green shutters letting in warm daylight",
+         "bottom": "wooden floor with a textured, painterly look"
+     }
+ }
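Each entry above pairs an `img` with one `global` prompt plus six per-face prompts keyed `front`/`back`/`left`/`right`/`top`/`bottom` — the same split the pipeline below consumes as `global_prompt` and `per_face_prompts`. A minimal sketch of loading one entry; the splitting logic here is illustrative, not a committed helper:

```python
import json

# Load the example prompt sets committed in this repo.
with open("examples/examples.json", encoding="utf-8") as f:
    examples = json.load(f)

entry = examples["Mountain Fall"]
global_prompt = entry["global"]
# Everything except the image reference and the global prompt is a per-face prompt.
per_face_prompts = {k: v for k, v in entry.items() if k not in ("img", "global")}

print(sorted(per_face_prompts))  # ['back', 'bottom', 'front', 'left', 'right', 'top']
```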
scripts/__pycache__/analysis.cpython-312.pyc ADDED
Binary file (3.94 kB)
scripts/__pycache__/cubemap_dataset.cpython-310.pyc ADDED
Binary file (2.53 kB)
scripts/__pycache__/cubemap_dataset.cpython-312.pyc ADDED
Binary file (7.77 kB)
scripts/__pycache__/cubemap_diffusion_pipeline.cpython-312.pyc ADDED
Binary file (26.1 kB)
scripts/__pycache__/cubemap_unet.cpython-310.pyc ADDED
Binary file (3.33 kB)
scripts/__pycache__/cubemap_unet.cpython-312.pyc ADDED
Binary file (3.67 kB)
scripts/__pycache__/cubemap_unet_attention.cpython-310.pyc ADDED
Binary file (2.17 kB)
scripts/__pycache__/cubemap_unet_attention.cpython-312.pyc ADDED
Binary file (4.77 kB)
scripts/__pycache__/cubemap_unet_attention_processor.cpython-312.pyc ADDED
Binary file (3.99 kB)
scripts/__pycache__/cubemap_unet_attention_simple.cpython-312.pyc ADDED
Binary file (3.87 kB)
scripts/__pycache__/cubemap_vae.cpython-310.pyc ADDED
Binary file (6.38 kB)
scripts/__pycache__/cubemap_vae.cpython-312.pyc ADDED
Binary file (8.71 kB)
scripts/__pycache__/preprocess.cpython-310.pyc ADDED
Binary file (2.7 kB)
scripts/__pycache__/train_cubemap_args.cpython-310.pyc ADDED
Binary file (8.46 kB)
scripts/__pycache__/train_cubemap_args.cpython-312.pyc ADDED
Binary file (12 kB)
scripts/__pycache__/train_cubemap_unet.cpython-312.pyc ADDED
Binary file (31.7 kB)
scripts/__pycache__/utils.cpython-310.pyc ADDED
Binary file (553 Bytes)
scripts/__pycache__/utils.cpython-312.pyc ADDED
Binary file (26.9 kB)
scripts/cubemap_diffusion_pipeline.py ADDED
@@ -0,0 +1,550 @@
+ import torch
+ from typing import Any, Callable, Dict, List, Optional, Union
+ from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline, StableDiffusionPipelineOutput
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+ from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import retrieve_timesteps
+ from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
+ from .utils import FACES, generate_cubemap_uv
+ from diffusers import AutoPipelineForInpainting
+
+ DEFAULT_VIEW_PROMPT = {
+     "front": "view to the front, looking forward",
+     "back": "view to the back, looking backward",
+     "left": "view to the left side, looking left",
+     "right": "view to the right side, looking right",
+     "top": "view to above, looking upward, ceiling or sky",
+     "bottom": "view to below, looking downward, floor or ground",
+ }
+
+
+ class CubemapDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
+     def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, image_encoder=None, requires_safety_checker=True):
+         super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, image_encoder, requires_safety_checker)
+
+     def prepare_cubemap_mask_condition(self, width: int, height: int, batch_size=6) -> torch.Tensor:
+         # Initialize a mask of ones ("repaint this face") with shape (batch_size, 1, height, width)
+         mask = torch.ones(size=(batch_size, 1, height, width), dtype=torch.float32)
+         # Zero out the first face's mask so the conditioning image is preserved
+         mask[0] = 0
+         return mask
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         global_prompt: str = None,
+         per_face_prompts: Optional[Dict[str, str]] = None,
+         image: PipelineImageInput = None,
+         mask_image: PipelineImageInput = None,
+         masked_image_latents: torch.Tensor = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         padding_mask_crop: Optional[int] = None,
+         strength: float = 1.0,
+         num_inference_steps: int = 50,
+         timesteps: List[int] = None,
+         sigmas: List[float] = None,
+         guidance_scale: float = 7.5,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         clip_skip: int = None,
+         callback_on_step_end: Optional[
+             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+         ] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         **kwargs,
+     ):
+         r"""
+         The call function to the pipeline for generation.
+
+         Args:
+             global_prompt (`str`, *optional*):
+                 A prompt shared by all six cubemap faces. For any face without an entry in
+                 `per_face_prompts`, it is combined with the view hint from `DEFAULT_VIEW_PROMPT`.
+             per_face_prompts (`Dict[str, str]`, *optional*):
+                 Per-face prompts keyed by face name (`front`, `back`, `left`, `right`, `top`, `bottom`).
+                 At least one of `global_prompt` and `per_face_prompts` must be provided.
+             image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                 `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
+                 be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
+                 tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list of tensors, the
+                 expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
+                 expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image latents as `image`, but
+                 if passing latents directly they are not encoded again.
+             mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                 `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+                 are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+                 single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+                 color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`, `(B,
+                 H, W)`, `(1, H, W)`, or `(H, W)`, and for a numpy array `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+                 1)`, or `(H, W)`.
+             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                 The width in pixels of the generated image.
+             padding_mask_crop (`int`, *optional*, defaults to `None`):
+                 The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+                 image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+                 with the same aspect ratio as the image that contains all masked areas, and then expand that area based
+                 on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+                 resizing to the original image size for inpainting. This is useful when the masked area is small while
+                 the image is large and contains information irrelevant for inpainting, such as background.
+             strength (`float`, *optional*, defaults to 1.0):
+                 Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+                 starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+                 on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+                 process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+                 essentially ignores `image`.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference. This parameter is modulated by `strength`.
+             timesteps (`List[int]`, *optional*):
+                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                 passed will be used. Must be in descending order.
+             sigmas (`List[float]`, *optional*):
+                 Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                 their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                 will be used.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 A higher guidance scale value encourages the model to generate images closely linked to the text
+                 `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                 pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+                 to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                 generation deterministic.
+             latents (`torch.Tensor`, *optional*):
+                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor is generated by sampling using the supplied random `generator`.
+             prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                 provided, text embeddings are generated from the `prompt` input argument.
+             negative_prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+             ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                 Pre-generated image embeddings for IP-Adapter. It should be a list with length equal to the number of
+                 IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                 contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                 provided, embeddings are computed from the `ip_adapter_image` input argument.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             cross_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             clip_skip (`int`, *optional*):
+                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                 the output of the pre-final layer will be used for computing the prompt embeddings.
+             callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                 A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                 each denoising step during inference with the following arguments: `callback_on_step_end(self:
+                 DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                 list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+             callback_on_step_end_tensor_inputs (`List`, *optional*):
+                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                 `._callback_tensor_inputs` attribute of your pipeline class.
+         Examples:
+
+         ```py
+         >>> import PIL
+         >>> import requests
+         >>> import torch
+         >>> from io import BytesIO
+
+         >>> from diffusers import StableDiffusionInpaintPipeline
+
+
+         >>> def download_image(url):
+         ...     response = requests.get(url)
+         ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+         >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+         >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+         >>> init_image = download_image(img_url).resize((512, 512))
+         >>> mask_image = download_image(mask_url).resize((512, 512))
+
+         >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
+         ...     "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+         ... )
+         >>> pipe = pipe.to("cuda")
+
+         >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+         >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+         ```
+
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                 If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                 otherwise a `tuple` is returned where the first element is a list with the generated images and the
+                 second element is a list of `bool`s indicating whether the corresponding generated image contains
+                 "not-safe-for-work" (nsfw) content.
+         """
+
+         callback = kwargs.pop("callback", None)
+         callback_steps = kwargs.pop("callback_steps", None)
+
+         if callback is not None:
+             deprecate(
+                 "callback",
+                 "1.0.0",
+                 "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+             )
+         if callback_steps is not None:
+             deprecate(
+                 "callback_steps",
+                 "1.0.0",
+                 "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+             )
+
+         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+         # # 1. Check inputs
+         # self.check_inputs(
+         #     prompt,
+         #     image,
+         #     mask_image,
+         #     height,
+         #     width,
+         #     strength,
+         #     callback_steps,
+         #     output_type,
+         #     negative_prompt,
+         #     prompt_embeds,
+         #     negative_prompt_embeds,
+         #     ip_adapter_image,
+         #     ip_adapter_image_embeds,
+         #     callback_on_step_end_tensor_inputs,
+         #     padding_mask_crop,
+         # )
+
+         self._guidance_scale = guidance_scale
+         self._clip_skip = clip_skip
+         self._cross_attention_kwargs = cross_attention_kwargs
+         self._interrupt = False
+
+         batch_size = 6  # one prompt/latent per cubemap face
+
+         if global_prompt is None and per_face_prompts is None:
+             raise ValueError(
+                 "Provide either `global_prompt` or `per_face_prompts`, or both. They can't both be empty."
+             )
+
+         prompt = []
+
+         # 2. Define call parameters
+         if global_prompt is not None and per_face_prompts is None:
+             prompt = [f"{global_prompt},{DEFAULT_VIEW_PROMPT[f]}" for f in FACES]
+         elif global_prompt is not None and per_face_prompts is not None:
+             for f in FACES:
+                 if f in per_face_prompts.keys():
+                     prompt_text = f"{global_prompt},{per_face_prompts[f]}"
+                     prompt.append(prompt_text)
+                 else:
+                     prompt.append(f"{global_prompt},{DEFAULT_VIEW_PROMPT[f]}")
+         else:
+             # Only per-face prompts were given; fall back to the default view hint for missing faces
+             prompt = [per_face_prompts.get(f, DEFAULT_VIEW_PROMPT[f]) for f in FACES]
+
+         negative_prompt_list = None
+         if negative_prompt is not None:
+             if isinstance(negative_prompt, list):
+                 if len(negative_prompt) == batch_size:
+                     negative_prompt_list = negative_prompt
+                 else:
+                     raise ValueError(
+                         f"Length of `negative_prompt` list ({len(negative_prompt)}) does not match `batch_size` ({batch_size})."
+                     )
+             else:
+                 negative_prompt_list = [negative_prompt for _ in range(batch_size)]
+
+         device = self._execution_device
+
+         # 3. Encode input prompt
+         text_encoder_lora_scale = (
+             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+         )
+         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             self.do_classifier_free_guidance,
+             negative_prompt=negative_prompt_list,
+             lora_scale=text_encoder_lora_scale,
+             clip_skip=self.clip_skip,
+         )
+
+         # For classifier-free guidance, we need to do two forward passes.
+         # Here we concatenate the unconditional and text embeddings into a single batch
+         # to avoid doing two forward passes
+         if self.do_classifier_free_guidance:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+             image_embeds = self.prepare_ip_adapter_image_embeds(
+                 ip_adapter_image,
+                 ip_adapter_image_embeds,
+                 device,
+                 batch_size * num_images_per_prompt,
+                 self.do_classifier_free_guidance,
+             )
+
+         # 4. Set timesteps
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler, num_inference_steps, device, timesteps, sigmas
+         )
+         timesteps, num_inference_steps = self.get_timesteps(
+             num_inference_steps=num_inference_steps, strength=strength, device=device
+         )
+         # check that the number of inference steps is not < 1 - as this doesn't make sense
+         if num_inference_steps < 1:
+             raise ValueError(
+                 f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
+                 f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+             )
+         # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+         # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+         is_strength_max = strength == 1.0
+
+         # 5. Preprocess mask and image
+
+         # if padding_mask_crop is not None:
+         #     crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+         #     resize_mode = "fill"
+         # else:
+         #     crops_coords = None
+         #     resize_mode = "default"
+
+         original_image = image
+         cond_image = self.image_processor.preprocess(image, height=height, width=width)
+         cond_image = cond_image.to(dtype=torch.float32)
+
+         empty_face = torch.zeros_like(cond_image, dtype=torch.float32)
+         init_image = [cond_image] + [empty_face for _ in range(5)]
+
+         # combine the condition image with empty faces to make a tensor of shape (6, 3, H, W)
+         # (cat, not stack: each preprocessed face already carries a batch dimension of 1)
+         init_image = torch.cat(init_image, dim=0)
+
+         # 6. Prepare latent variables
+         num_channels_latents = self.vae.config.latent_channels
+         num_channels_unet = self.unet.config.in_channels
+         return_image_latents = num_channels_unet == 4
+
+         latents_outputs = self.prepare_latents(
+             batch_size,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+             image=init_image,  # lets prepare_latents return image latents when the UNet takes 4 channels
+             timestep=latent_timestep,
+             is_strength_max=is_strength_max,
+             return_noise=True,
+             return_image_latents=return_image_latents,
+         )
+
+         if return_image_latents:
+             latents, noise, image_latents = latents_outputs
+         else:
+             latents, noise = latents_outputs
+
+         mask_condition = self.prepare_cubemap_mask_condition(width, height)
+         masked_image = init_image
+
+         mask, masked_image_latents = self.prepare_mask_latents(
+             mask_condition,
+             masked_image,
+             batch_size,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             self.do_classifier_free_guidance,
+         )
+
+         uv_maps = generate_cubemap_uv(height // self.vae_scale_factor, width // self.vae_scale_factor)
+
+         uv_channels = torch.stack([uv_maps[face] for face in FACES]).to(dtype=prompt_embeds.dtype, device=device)
+         uv_channel_input = torch.cat([uv_channels] * 2) if self.do_classifier_free_guidance else uv_channels
+
+         # # 8. Check that sizes of mask, masked image and latents match
+         # if num_channels_unet == 9:
+         #     # default case for runwayml/stable-diffusion-inpainting
+         #     num_channels_mask = mask.shape[1]
+         #     num_channels_masked_image = masked_image_latents.shape[1]
+         #     if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+         #         raise ValueError(
+         #             f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+         #             f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+         #             f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+         #             f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+         #             " `pipeline.unet` or your `mask_image` or `image` input."
+         #         )
+         # elif num_channels_unet != 4:
+         #     raise ValueError(
+         #         f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+         #     )
+
+         # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+         # 9.1 Add image embeds for IP-Adapter
+         added_cond_kwargs = (
+             {"image_embeds": image_embeds}
+             if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+             else None
+         )
+
+         # 9.2 Optionally get Guidance Scale Embedding
+         timestep_cond = None
+         if self.unet.config.time_cond_proj_dim is not None:
+             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+             timestep_cond = self.get_guidance_scale_embedding(
+                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+             ).to(device=device, dtype=latents.dtype)
+
+         # 10. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         self._num_timesteps = len(timesteps)
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+
+                 # expand the latents if we are doing classifier-free guidance
+                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                 # concat latents, mask, masked_image_latents in the channel dimension
+                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                 # print("Unet Channels:", num_channels_unet)
+
+                 if num_channels_unet == 11:
+                     latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents, uv_channel_input], dim=1)
+                     # print("Latent Input Shape:", latent_model_input.shape)
+
+                 # predict the noise residual
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     timestep_cond=timestep_cond,
+                     cross_attention_kwargs=self.cross_attention_kwargs,
+                     added_cond_kwargs=added_cond_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 # perform guidance
+                 if self.do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                 if num_channels_unet == 4:
+                     init_latents_proper = image_latents
+                     if self.do_classifier_free_guidance:
+                         init_mask, _ = mask.chunk(2)
+                     else:
+                         init_mask = mask
+
+                     if i < len(timesteps) - 1:
+                         noise_timestep = timesteps[i + 1]
+                         init_latents_proper = self.scheduler.add_noise(
+                             init_latents_proper, noise, torch.tensor([noise_timestep])
+                         )
+
+                     latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                     latents = callback_outputs.pop("latents", latents)
+                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                     mask = callback_outputs.pop("mask", mask)
+                     masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         step_idx = i // getattr(self.scheduler, "order", 1)
+                         callback(step_idx, t, latents)
+
+         if output_type != "latent":
+             condition_kwargs = {}
+             if isinstance(self.vae, AsymmetricAutoencoderKL):
+                 init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
+                 init_image_condition = init_image.clone()
+                 init_image = self._encode_vae_image(init_image, generator=generator)
+                 mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
+                 condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
+             image = self.vae.decode(
+                 latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
+             )[0]
+             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+         else:
+             image = latents
+             has_nsfw_concept = None
+
+         if has_nsfw_concept is None:
+             do_denormalize = [True] * image.shape[0]
+         else:
+             do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+         # put the (postprocessed) conditioning image back in as face 0
+         input_img = self.image_processor.postprocess(cond_image, output_type=output_type, do_denormalize=do_denormalize)[0]
+         image[0] = input_img
+
+         # if padding_mask_crop is not None:
+         #     image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (image, has_nsfw_concept)
+
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
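A hypothetical end-to-end sketch of calling the pipeline above. The base checkpoint, the component wiring, and the condition image path are assumptions for illustration; what the pipeline itself requires is a UNet with 11 input channels (4 latent + 1 mask + 4 masked-image latent + 2 UV, matching the `num_channels_unet == 11` branch), one conditioning image for face 0, and the `global_prompt`/`per_face_prompts` pair:

```python
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

from scripts.cubemap_diffusion_pipeline import CubemapDiffusionInpaintPipeline
from scripts.cubemap_unet import CubemapUNet

# Assumption: borrow components from a stock inpainting checkpoint and expand
# its UNet to 11 input channels; in practice the repo's own fine-tuned cubemap
# weights would be loaded here instead.
base = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
unet = CubemapUNet(base.unet, in_channels=11)

pipe = CubemapDiffusionInpaintPipeline(
    vae=base.vae,
    text_encoder=base.text_encoder,
    tokenizer=base.tokenizer,
    unet=unet,
    scheduler=base.scheduler,
    safety_checker=base.safety_checker,
    feature_extractor=base.feature_extractor,
).to("cuda")

cond = Image.open("examples/office.png").convert("RGB")  # conditioning image for face 0
result = pipe(
    global_prompt="a modernist open office interior",
    per_face_prompts={"top": "clean white ceiling, linear continuous lights"},
    image=cond,
    num_inference_steps=30,
)
faces = result.images  # six faces in FACES order; face 0 is the condition image
```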
scripts/cubemap_unet.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ import torch.nn as nn
+ from diffusers import UNet2DConditionModel
+ from diffusers.models.attention import BasicTransformerBlock
+ from .cubemap_unet_attention_processor import InflatedAttentionProcessor
+
+ class CubemapUNet(UNet2DConditionModel):
+
+     def __init__(self, pretrained_unet, in_channels=11):
+         """
+         1. Build the UNet from `pretrained_unet`'s config (original channel count)
+         2. Load the full set of pretrained UNet weights
+         3. Expand `conv_in` to the new input channel count
+         """
+
+         # Copy all parameters from `pretrained_unet.config` (a dict copy, so the source model's config is untouched);
+         # `in_channels` is updated later via `register_to_config`
+         unet_config = {**pretrained_unet.config}
+
+         super().__init__(**unet_config)  # pass the full config through
+
+         # Step 2: load all of `pretrained_unet`'s weights
+         print("Loading pretrained UNet weights...")
+         self.load_state_dict(pretrained_unet.state_dict(), strict=True)
+         print("✅ Pretrained UNet weights loaded!")
+
+         # Step 3: expand `conv_in` to `in_channels` (e.g. 11) input channels
+         self._expand_conv_in(pretrained_unet, in_channels)
+         self.register_to_config(in_channels=in_channels)
+
+         cubemap_attn_processor = InflatedAttentionProcessor()
+
+         self._modify_attn_processor(cubemap_attn_processor)
+
+     def _expand_conv_in(self, pretrained_unet, new_in_channels):
+         """Expand `conv_in` to accept the new number of input channels."""
+         old_conv_in = pretrained_unet.conv_in
+         old_weight = old_conv_in.weight  # [out_channels, 4, kernel, kernel]
+
+         # Create the new `conv_in`
+         new_conv_in = nn.Conv2d(
+             new_in_channels,
+             old_conv_in.out_channels,
+             kernel_size=old_conv_in.kernel_size,
+             stride=old_conv_in.stride,
+             padding=old_conv_in.padding
+         )
+
+         old_channels = old_weight.shape[1]
+
+         # Copy the weights of the original (e.g. 4) input channels
+         new_weight = torch.zeros((new_conv_in.out_channels, new_in_channels, *old_conv_in.kernel_size))
+         new_weight[:, :old_channels, :, :] = old_weight
+
+         # Randomly initialize the newly added channels with small values
+         new_weight[:, old_channels:, :, :] = torch.randn_like(new_weight[:, old_channels:, :, :]) * 0.01
+
+         new_conv_in.weight = nn.Parameter(new_weight)
+         if old_conv_in.bias is not None:
+             new_conv_in.bias = nn.Parameter(old_conv_in.bias.clone())
+
+         self.conv_in = new_conv_in
+         print(f"✅ `conv_in` expanded! New input channels: {new_in_channels}")
+
+     def _modify_attn_processor(self, processor):
+
+         for name, module in self.named_modules():
+             if isinstance(module, BasicTransformerBlock):
+                 if hasattr(module, 'attn1'):
+                     module.attn1.set_processor(processor)
+
+                 if hasattr(module, 'attn2'):
+                     module.attn2.set_processor(processor)
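A small sanity check of the `conv_in` expansion, using a tiny randomly initialized donor UNet so nothing needs to be downloaded. It assumes the donor's config round-trips through the constructor the way the class splats `{**pretrained_unet.config}`, and it needs flash-attn installed, since the processor module above imports it:

```python
import torch
from diffusers import UNet2DConditionModel
from scripts.cubemap_unet import CubemapUNet

# A tiny 4-channel donor UNet standing in for a pretrained SD UNet.
donor = UNet2DConditionModel(
    sample_size=16, in_channels=4, out_channels=4,
    block_out_channels=(32, 64), layers_per_block=1,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
unet = CubemapUNet(donor, in_channels=11)

assert unet.conv_in.in_channels == 11
# The first 4 input channels keep the donor's weights; the 7 extra ones start near zero.
assert torch.allclose(unet.conv_in.weight[:, :4], donor.conv_in.weight)
print(unet.conv_in.weight[:, 4:].abs().mean())  # ~0.008 with the 0.01-scaled init
```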
scripts/cubemap_unet_attention_processor.py ADDED
@@ -0,0 +1,99 @@
+ from typing import Optional
+ import torch
+ import torch.nn as nn
+ from diffusers.models.attention import BasicTransformerBlock
+ from diffusers.models.attention_processor import Attention, AttnProcessor
+ from flash_attn.flash_attn_interface import flash_attn_func
+
+ class InflatedAttentionProcessor(AttnProcessor):
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         temb: Optional[torch.Tensor] = None,
+         num_views: int = 6,  # e.g. a cubemap has 6 views
+         *args,
+         **kwargs,
+     ) -> torch.Tensor:
+         """
+         Inflated attention as in the CubeDiff paper:
+         - reshape the input from `(B*F, N, C)` to `(B, F*N, C)`
+         - run self-attention over the joint `F*N` token dimension
+         """
+         residual = hidden_states
+
+         # 1️⃣ Preprocessing
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         BXF, N, C = hidden_states.shape  # original attention input
+         # 2️⃣ Reshape `(B*F, N, C) → (B, F*N, C)`
+         F = num_views
+         B = BXF // F
+
+         # 3️⃣ Standard attention computation
+         attention_mask = attn.prepare_attention_mask(attention_mask, hidden_states.shape[1], B)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         is_self_attn = False
+
+         if encoder_hidden_states is None:
+             hidden_states = hidden_states.view(B, F, N, C)
+             hidden_states = hidden_states.reshape(B, F * N, C)
+             encoder_hidden_states = hidden_states
+             is_self_attn = True
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query, out_dim=4).permute(0, 2, 1, 3)
+         key = attn.head_to_batch_dim(key, out_dim=4).permute(0, 2, 1, 3)
+         value = attn.head_to_batch_dim(value, out_dim=4).permute(0, 2, 1, 3)
+
+         hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+         B, L, H, D = hidden_states.shape
+         hidden_states = hidden_states.view(B, L, H * D)
+
+         # query = attn.head_to_batch_dim(query)
+         # key = attn.head_to_batch_dim(key)
+         # value = attn.head_to_batch_dim(value)
+
+         # attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         # hidden_states = torch.bmm(attention_probs, value)
+         # hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # 4️⃣ Output projection & dropout
+         hidden_states = attn.to_out[0](hidden_states)
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if is_self_attn:
+             # 5️⃣ Restore shape `(B, F*N, C) → (B*F, N, C)`
+             hidden_states = hidden_states.view(B, F, N, C)
+             hidden_states = hidden_states.reshape(BXF, N, C)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
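The heart of the processor is the reshape bookkeeping that lets all six faces attend to each other during self-attention. A self-contained sketch of just that step, independent of flash-attn; the shape values are arbitrary test numbers:

```python
import torch

B, F, N, C = 2, 6, 64, 320  # cubemaps, faces, tokens per face, channels
hidden = torch.randn(B * F, N, C)  # the layout the UNet hands to attn1

# Inflate: fold the face axis into the token axis, so self-attention
# runs jointly over the F * N tokens of each cubemap.
inflated = hidden.view(B, F, N, C).reshape(B, F * N, C)

# ... attention over F * N tokens would happen here ...

# Deflate: restore the per-face layout expected by the rest of the block.
restored = inflated.view(B, F, N, C).reshape(B * F, N, C)
assert torch.equal(hidden, restored)
```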
scripts/cubemap_vae.py ADDED
@@ -0,0 +1,177 @@
+ import torch
+ from diffusers.models import AutoencoderKL
+ from torch import nn
+ from torchvision.transforms import ToPILImage
+ from torch import Tensor
+
+ class SynchronizedGroupNorm(nn.Module):
+     def __init__(self, original_group_norm: nn.GroupNorm, num_views: int = 6):
+         super().__init__()
+         self.num_views = num_views
+
+         # Inherit the original grouping parameters
+         self.num_groups = original_group_norm.num_groups
+         self.num_channels = original_group_norm.num_channels
+         self.eps = original_group_norm.eps
+
+         # Restructure the parameters per channel group
+         self.group_size = self.num_channels // self.num_groups
+
+         # Inherit the original parameters (keeping one affine transform per group)
+         self.weight = nn.Parameter(original_group_norm.weight.detach().view(self.num_groups, self.group_size))
+         self.bias = nn.Parameter(original_group_norm.bias.detach().view(self.num_groups, self.group_size))
+
+     def forward(self, x: torch.Tensor):
+         """Supports both 3D (B, C, D) and 4D (B, C, H, W) inputs."""
+
+         # print(f"Input shape: {x.shape}")  # Debugging
+
+         # Read off the input shape
+         BxT, C = x.shape[:2]  # only the first two dimensions
+         B = BxT // self.num_views  # recover the true batch dimension
+
+         # Handle 3D (B, C, D) input
+         if x.dim() == 3:
+             D = x.shape[2]
+             x = x.view(B, self.num_views, self.num_groups, self.group_size, D)
+
+             # GroupNorm statistics, shared across the views
+             mean = x.mean(dim=(1, 3, 4), keepdim=True)
+             var = x.var(dim=(1, 3, 4), keepdim=True, unbiased=False)
+             x = (x - mean) / torch.sqrt(var + self.eps)
+             # Broadcast weight and bias to the grouped layout
+             weight = self.weight.view(1, self.num_groups, self.group_size, 1)
+             bias = self.bias.view(1, self.num_groups, self.group_size, 1)
+             x = x * weight + bias
+
+             # Restore the original shape
+             return x.view(BxT, C, D)
+
+         # Handle 4D (B, C, H, W) input
+         elif x.dim() == 4:
+             H, W = x.shape[2:]
+             x = x.view(B, self.num_views, self.num_groups, self.group_size, H, W)
+
+             # GroupNorm statistics, shared across the views
+             mean = x.mean(dim=(1, 3, 4, 5), keepdim=True)
+             var = x.var(dim=(1, 3, 4, 5), keepdim=True, unbiased=False)
+             x = (x - mean) / torch.sqrt(var + self.eps)
+             # Broadcast weight and bias to the grouped layout
+             weight = self.weight.view(1, self.num_groups, self.group_size, 1, 1)
+             bias = self.bias.view(1, self.num_groups, self.group_size, 1, 1)
+             x = x * weight + bias
+
+             # Restore the original shape
+             return x.view(BxT, C, H, W)
+
+         else:
+             raise ValueError(f"Unsupported input shape: {x.shape}, expected 3D (B, C, D) or 4D (B, C, H, W).")
+
+
+ class CubemapVAE(AutoencoderKL):
+     def __init__(self, pretrained_vae, num_views=6, in_channels=3, image_size=512):
+         super().__init__(  # inherits from AutoencoderKL
+             act_fn="silu",
+             block_out_channels=[128, 256, 512, 512],
+             down_block_types=[
+                 "DownEncoderBlock2D",
+                 "DownEncoderBlock2D",
+                 "DownEncoderBlock2D",
+                 "DownEncoderBlock2D"
+             ],
+             up_block_types=[
+                 "UpDecoderBlock2D",
+                 "UpDecoderBlock2D",
+                 "UpDecoderBlock2D",
+                 "UpDecoderBlock2D"
+             ],
+             latent_channels=pretrained_vae.config.latent_channels,
+             in_channels=in_channels,
+             out_channels=in_channels
+         )
+         self.num_views = num_views
+         self.in_channels = in_channels
+
+         # --- Swap in the pretrained modules and adapt them for cubemaps ---
+         # Cubemap-specific encoder/decoder variants were considered (commented below);
+         # here the pretrained encoder/decoder are reused directly.
+         # self.encoder = CubemapEncoder(pretrained_encoder=pretrained_vae.encoder, num_views=num_views, in_channels=in_channels)
+         # self.decoder = CubemapDecoder(pretrained_decoder=pretrained_vae.decoder, num_views=num_views, out_channels=in_channels, in_channels=4)
+         self.encoder = pretrained_vae.encoder
+         self.decoder = pretrained_vae.decoder
+         self.quant_conv = pretrained_vae.quant_conv
+         self.post_quant_conv = pretrained_vae.post_quant_conv
+         # Replace every GroupNorm with the view-synchronized variant
+         replace_group_norm_with_sgn(self, num_views=num_views)
+
+     def encode(self, images, return_dict: bool = True):
+         batch_size, num_views, num_channels, height, width = images.shape
+         images = images.view(batch_size * num_views, num_channels, height, width)
+         return super().encode(images, return_dict=return_dict)
+
+     def decode(self, latents, return_dict=True, **kwargs):
+         """
+         Custom VAE decoding:
+         - drop the UV channels (keep only the first 4 latent channels)
+         - fall back to the original VAE decoding path
+         """
+
+         print("Decoder received latent shape:", latents.shape)
+         # Keep only the first 4 latent channels, dropping any extra UV channels
+         if latents.shape[1] > 4:
+             latents = latents[:, :4, :, :]
+
+         return super().decode(latents, return_dict=return_dict, **kwargs)
+
+     def decode_to_tensor(self, latents):
+         decoded = self.decode(latents).sample  # (B*6, 3, H, W)
+
+         B = latents.shape[0] // 6
+         images = torch.split(decoded, B, dim=0)  # split back per face group
+
+         return images  # tuple of 6 tensors
+
+     def decode_to_pil_images(self, latents: Tensor):
+         images = self.decode_to_tensor(latents)  # the 6 face images
+         to_pil = ToPILImage()
+
+         return [to_pil(img[0].cpu().detach()) for img in images]  # convert to PIL
+
+
+ def replace_group_norm_with_sgn(model, num_views):
+     """Walk the model and replace every nn.GroupNorm with a SynchronizedGroupNorm."""
+     replacements = []  # collect the module names to replace first
+     for name, module in model.named_modules():
+         if isinstance(module, nn.GroupNorm):
+             replacements.append(name)
+
+     for name in replacements:
+         parent_module, attr_name = get_parent_module(model, name)
+         setattr(parent_module, attr_name, SynchronizedGroupNorm(getattr(parent_module, attr_name), num_views))
+
+ def get_parent_module(model, module_name):
+     """Return the parent module of `module_name` and the attribute name within it."""
+     names = module_name.split(".")
+     parent_module = model
+     for name in names[:-1]:  # walk down to the second-to-last level
+         parent_module = getattr(parent_module, name)
+     return parent_module, names[-1]
+
+
+ def flatten_face_names(face_names):
+     flat_face_names = []
+     for item in face_names:
+         if isinstance(item, str):  # already a plain string
+             flat_face_names.append(item)
+         elif isinstance(item, (list, tuple)):  # a list/tuple: flatten its strings
+             flat_face_names.extend(item)
+         else:
+             raise ValueError(f"Unexpected type in face_names: {type(item)}")
+     return flat_face_names
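A quick check that the synchronized norm is a drop-in replacement for `nn.GroupNorm` under the `(B * num_views, C, H, W)` layout used above; the shapes here are arbitrary test values:

```python
import torch
from torch import nn
from scripts.cubemap_vae import SynchronizedGroupNorm

gn = nn.GroupNorm(num_groups=32, num_channels=128)
sgn = SynchronizedGroupNorm(gn, num_views=6)

x = torch.randn(2 * 6, 128, 16, 16)  # 2 cubemaps, 6 faces each
y = sgn(x)
assert y.shape == x.shape
# Per (batch, group), mean/var are computed jointly across all 6 views,
# so adjacent faces are normalized with consistent statistics.
```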
scripts/utils.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import Tensor
4
+ import torchvision.transforms.functional as TF
5
+ import torch.nn.functional as F
6
+ import cv2
7
+ import py360convert
8
+ import argparse
9
+ import os
10
+ import numpy as np
11
+ from numpy.typing import NDArray
12
+ from PIL import Image
13
+
14
+ # 6 视角定义(前、后、左、右、上、下)
15
+ FACES = ["front", "back", "left", "right", "top", "bottom"]
16
+
17
+ FACE_CAPTION_MAP={
18
+ "front": "caption_front",
19
+ "back": "caption_back",
20
+ "left": "caption_left",
21
+ "right": "caption_right",
22
+ "top": "caption_top",
23
+ "bottom": "caption_bottom"
24
+ }
25
+
26
+ FACE_KEYS_MAP = {
27
+ "front": "F",
28
+ "back": "B",
29
+ "left": "L",
30
+ "right": "R",
31
+ "top": "U",
32
+ "bottom": "D"
33
+ }
34
+
35
+
36
+ def load_cubemap_dict(cubemap_path_dict:dict):
37
+ """
38
+ 从字典中加载 Cubemap 图像
39
+ """
40
+ cubemap_dict = {}
41
+ for face, path in cubemap_path_dict.items():
42
+ image = cv2.imread(path)
43
+ if image is None:
44
+ print(f"❌ 读取失败: {path}")
45
+ continue
46
+ cubemap_dict[FACE_KEYS_MAP[face]] = image
47
+ return cubemap_dict
48
+
49
+
50
+ def convert_to_cubemap(image, size=512):
51
+ """
52
+ 使用 py360convert 将 equirectangular 全景图转换为 6 视角 Cubemap
53
+ """
54
+ cubemap_dict = py360convert.e2c(image, face_w=size, mode="bilinear", cube_format="dict")
55
+ return cubemap_dict
56
+
57
+ def to_cubemap_dict(images:list[NDArray]):
58
+ cubemap_dict={}
59
+ for i, face in enumerate(FACES):
60
+ key=FACE_KEYS_MAP[face]
61
+ cubemap_dict[key]=images[i]
62
+
63
+ return cubemap_dict
64
+
65
+ def convert_to_equirectangular(cubemap_dict, width=1024,height=512):
66
+ """
67
+ 使用 py360convert 将 6 视角 Cubemap 转换为 equirectangular 全景图
68
+ """
69
+ equirectangular_image = py360convert.c2e(cubemap_dict, w=width,h=height, mode="bilinear",cube_format="dict")
70
+
71
+ if equirectangular_image.dtype == np.float32:
72
+ equirectangular_image = np.clip(equirectangular_image * 255, 0, 255).astype(np.uint8)
73
+
74
+ return Image.fromarray(equirectangular_image)
75
+
76
+
77
+ def process_image_e2c(input_path, output_dir, size=512):
78
+ """
79
+ 读取 equirectangular 全景图,转换为 Cubemap 并保存 6 张单独的图像
80
+ """
81
+ image = cv2.imread(input_path)
82
+ if image is None:
83
+ print(f"❌ 读取失败: {input_path}")
84
+ return
85
+
86
+ os.makedirs(output_dir, exist_ok=True)
87
+
88
+ # 生成 Cubemap
89
+ cubemap_images = convert_to_cubemap(image, size)
90
+
91
+ print(cubemap_images.keys())
92
+
93
+ # 保存 6 张图像
94
+ for face in FACES:
95
+ output_path = os.path.join(output_dir, f"{face}.png")
96
+ face_key = FACE_KEYS_MAP[face]
97
+ cv2.imwrite(output_path, cubemap_images[face_key])
98
+ print(f"✅ {face} 视角已保存: {output_path}")
99
+
100
+ def process_image_c2e(cubemap_path_dict, output_path, width,height):
101
+ """
102
+ 读取 6 视角 Cubemap,转换为 equirectangular 全景图并保存
103
+ """
104
+ cubemap_dict=load_cubemap_dict(cubemap_path_dict)
105
+ # 生成 equirectangular 全景图
106
+ equirectangular_image = convert_to_equirectangular(cubemap_dict, width,height)
107
+ cv2.imwrite(output_path, equirectangular_image)
108
+ print(f"✅ 全景图已保存: {output_path}")
109
+
110
+
111
+
112
+ def perspective_transform_patch(patch: torch.Tensor, delta):
113
+ """
114
+ 对输入的 patch 使用 torchvision.transforms.functional.perspective 进行透视变换。
115
+
116
+ 参数:
117
+ patch: Tensor,形状 (C, H, W),图像 patch
118
+ offset: float,表示左右方向的偏移量(单位:像素),用于定义目标透视变换的 endpoints
119
+ 例如:正值表示上边向右平移,下边向左平移;负值则相反。
120
+
121
+ 返回:
122
+ transformed: 透视变换后的 patch,Tensor,形状与 patch 相同
123
+ """
124
+ C, H, W = patch.shape
125
+
126
+ # 定义原始四个角的坐标(顺序为:上左, 上右, 下右, 下左)
127
+ startpoints = [
128
+ [0, 0], # top-left
129
+ [W, 0], # top-right
130
+ [W, H], # bottom-right
131
+ [0, H] # bottom-left
132
+ ]
133
+
134
+ endpoints=[[sp_i + d_i for sp_i, d_i in zip(sp, d)] for sp, d in zip(startpoints, delta)]
135
+ # 注意:F.perspective 接受的 startpoints 和 endpoints 应为 List[List[float]]
136
+ # 透视变换支持直接传入 tensor,但这里直接使用 list 即可。
137
+ return TF.perspective(patch, startpoints, endpoints, interpolation=TF.InterpolationMode.BILINEAR)
138
+
139
+
140
+ def stretch_edge_patch(patch,pad_width,edge_key):
141
+ C,H,W=patch.shape
142
+ H_new=H + 2*pad_width
143
+ W_new=W + 2*pad_width
144
+ if edge_key=="top":
145
+ top_edge=TF.resize(patch,(pad_width,W_new))
146
+ delta_top = [
147
+ [0, 0], # top-left
148
+ [0, 0], # top-right
149
+ [-pad_width , 0], # bottom-right
150
+ [pad_width , 0] # bottom-left
151
+ ]
152
+ return perspective_transform_patch(top_edge, delta_top)
153
+
154
+ elif edge_key=="bottom":
155
+ bottom_edge=TF.resize(patch,(pad_width,W_new))
156
+ delta_bottom=[
157
+ [pad_width, 0], # top-left
158
+ [-pad_width, 0], # top-right
159
+ [0, 0], # bottom-right
160
+ [0 , 0] # bottom-left
161
+ ]
162
+ return perspective_transform_patch(bottom_edge,delta_bottom)
163
+ elif edge_key=="left":
164
+ left_edge=TF.resize(patch,(H_new,pad_width))
165
+ delta_left=[
166
+ [0, 0], # top-left 变为 (offset, 0)
167
+ [0, pad_width], # top-right 变为 (W+offset, 0)
168
+ [0, -pad_width], # bottom-right 变为 (W-offset, H)
169
+ [0 , 0] # bottom-left 变为 (-offset, H)
170
+ ]
171
+ return perspective_transform_patch(left_edge,delta_left)
172
+ elif edge_key=="right":
173
+ right_edge=TF.resize(patch,(H_new,pad_width))
174
+ delta_right=[
175
+ [0, pad_width], # top-left 变为 (offset, 0)
176
+ [0, 0], # top-right 变为 (W+offset, 0)
177
+ [0, 0], # bottom-right 变为 (W-offset, H)
178
+ [0 , -pad_width] # bottom-left 变为 (-offset, H)
179
+ ]
180
+ return perspective_transform_patch(right_edge,delta_right)
181
+
182
+ # -------------------------------
183
+ # 定义各面拼接函数
184
+ # -------------------------------
185
+
186
+ def pad_front(face:Tensor, faces, pad_width):
187
+ """
188
+ 对前视图进行边缘拼接:
189
+ 上侧:拼接上视图的下边缘 (取 top[:, -w:, :])
190
+ 下侧:拼接下视图的上边缘 (取 bottom[:, :w, :])
191
+ 左侧:拼接左视图的右边缘 (取 left[:, :, -w:])
192
+ 右侧:拼接右视图的左边缘 (取 right[:, :, :w])
193
+ """
194
+ C, H, W = face.shape
195
+ H_new=H + 2*pad_width
196
+ W_new=W + 2*pad_width
197
+ padded = torch.zeros((C, H_new, W_new), dtype=face.dtype, device=face.device)
198
+ # 中间放置前视图
199
+ padded[:, pad_width:pad_width+H, pad_width:pad_width+W] = face
200
+ # 上边缘
201
+ top_edge=faces['top'][:, -pad_width:, :]
202
+ padded[:, 0:pad_width, 0:W+2*pad_width] += stretch_edge_patch(top_edge,pad_width,"top")
203
+
204
+ # 下边缘
205
+ bottom_edge=faces['bottom'][:, :pad_width, :]
206
+ padded[:, H+pad_width:H+2*pad_width, :] += stretch_edge_patch(bottom_edge,pad_width,"bottom")
207
+
208
+ # 左边缘
209
+ left_edge=faces['left'][:, :, -pad_width:]
210
+ padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge,pad_width,"left")
211
+
212
+ # 右边缘
213
+ right_edge=faces['right'][:, :, :pad_width]
214
+ padded[:, :, W+pad_width:W+2*pad_width] += stretch_edge_patch(right_edge,pad_width,"right")
215
+ return padded
216
+
217
+ def pad_right(face, faces, pad_width):
+     """
+     Stitch neighbouring edges onto the right view:
+       left:   the right edge of the front view   (front[:, :, -w:])
+       right:  the left edge of the back view     (back[:, :, :w])
+       top:    the right edge of the top view     (top[:, :, -w:])
+       bottom: the right edge of the bottom view  (bottom[:, :, -w:])
+     """
+     C, H, W = face.shape
+     H_new = H + 2 * pad_width
+     W_new = W + 2 * pad_width
+     padded = torch.zeros((C, H_new, W_new), dtype=face.dtype, device=face.device)
+     padded[:, pad_width:pad_width + H, pad_width:pad_width + W] = face
+
+     left_edge = faces['front'][:, :, -pad_width:]
+     padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge, pad_width, "left")
+
+     right_edge = faces['back'][:, :, :pad_width]
+     padded[:, :, W + pad_width:W + 2 * pad_width] += stretch_edge_patch(right_edge, pad_width, "right")
+
+     # Top: the right edge of the top view, rotated 90° clockwise.
+     # The raw edge has shape (C, H, w); after rotation it becomes (C, w, H).
+     top_edge = torch.rot90(faces['top'][:, :, -pad_width:], k=3, dims=(1, 2))
+     padded[:, 0:pad_width, :] += stretch_edge_patch(top_edge, pad_width, "top")
+
+     # Bottom: the right edge of the bottom view, rotated 90° counter-clockwise.
+     # The raw edge has shape (C, H, w); after rotation it becomes (C, w, H).
+     bottom_edge = torch.rot90(faces['bottom'][:, :, -pad_width:], k=1, dims=(1, 2))
+     padded[:, H + pad_width:H + 2 * pad_width, :] += stretch_edge_patch(bottom_edge, pad_width, "bottom")
+     return padded
+
+ def pad_back(face, faces, pad_width):
+     """
+     Stitch neighbouring edges onto the back view:
+       left:   the right edge of the right view    (right[:, :, -w:])
+       right:  the left edge of the left view      (left[:, :, :w])
+       top:    the top edge of the top view        (top[:, :w, :])
+       bottom: the bottom edge of the bottom view  (bottom[:, -w:, :])
+     """
+     C, H, W = face.shape
+     H_new = H + 2 * pad_width
+     W_new = W + 2 * pad_width
+     padded = torch.zeros((C, H_new, W_new), dtype=face.dtype, device=face.device)
+     padded[:, pad_width:pad_width + H, pad_width:pad_width + W] = face
+
+     left_edge = faces['right'][:, :, -pad_width:]
+     padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge, pad_width, "left")
+
+     right_edge = faces['left'][:, :, :pad_width]
+     padded[:, :, W + pad_width:W + 2 * pad_width] += stretch_edge_patch(right_edge, pad_width, "right")
+
+     # Top: the top edge of the top view, rotated 180°
+     # (a 180° rotation is torch.rot90(..., k=2, dims=(1, 2)))
+     top_edge = torch.rot90(faces['top'][:, :pad_width, :], k=2, dims=(1, 2))
+     padded[:, 0:pad_width, :] += stretch_edge_patch(top_edge, pad_width, "top")
+
+     # Bottom: the bottom edge of the bottom view, rotated 180°
+     bottom_edge = torch.rot90(faces['bottom'][:, -pad_width:, :], k=2, dims=(1, 2))
+     padded[:, H + pad_width:H + 2 * pad_width, :] += stretch_edge_patch(bottom_edge, pad_width, "bottom")
+     return padded
+
+ def pad_left(face, faces, pad_width):
+     """
+     Stitch neighbouring edges onto the left view:
+       left:   the right edge of the back view   (back[:, :, -w:])
+       right:  the left edge of the front view   (front[:, :, :w])
+       top:    the left edge of the top view     (top[:, :, :w])
+       bottom: the left edge of the bottom view  (bottom[:, :, :w])
+     """
+     C, H, W = face.shape
+     padded = torch.zeros((C, H + 2 * pad_width, W + 2 * pad_width), dtype=face.dtype, device=face.device)
+     padded[:, pad_width:pad_width + H, pad_width:pad_width + W] = face
+
+     left_edge = faces['back'][:, :, -pad_width:]
+     padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge, pad_width, "left")
+
+     right_edge = faces['front'][:, :, :pad_width]
+     padded[:, :, W + pad_width:W + 2 * pad_width] += stretch_edge_patch(right_edge, pad_width, "right")
+
+     top_edge = torch.rot90(faces['top'][:, :, :pad_width], k=1, dims=(1, 2))
+     padded[:, 0:pad_width, :] += stretch_edge_patch(top_edge, pad_width, "top")
+
+     bottom_edge = torch.rot90(faces['bottom'][:, :, :pad_width], k=3, dims=(1, 2))
+     padded[:, H + pad_width:H + 2 * pad_width, :] += stretch_edge_patch(bottom_edge, pad_width, "bottom")
+     return padded
+
+ def pad_top(face, faces, pad_width):
+     """
+     Stitch neighbouring edges onto the top view:
+       bottom: the top edge of the front view  (front[:, :w, :])
+       left:   the top edge of the left view   (left[:, :w, :])
+       right:  the top edge of the right view  (right[:, :w, :])
+       top:    the top edge of the back view   (back[:, :w, :])
+     """
+     C, H, W = face.shape
+     padded = torch.zeros((C, H + 2 * pad_width, W + 2 * pad_width), dtype=face.dtype, device=face.device)
+     padded[:, pad_width:pad_width + H, pad_width:pad_width + W] = face
+
+     bottom_edge = faces['front'][:, :pad_width, :]
+     padded[:, H + pad_width:H + 2 * pad_width, :] += stretch_edge_patch(bottom_edge, pad_width, "bottom")
+
+     left_edge = torch.rot90(faces['left'][:, :pad_width, :], k=3, dims=(1, 2))
+     padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge, pad_width, "left")
+
+     right_edge = torch.rot90(faces['right'][:, :pad_width, :], k=1, dims=(1, 2))
+     padded[:, :, W + pad_width:W + 2 * pad_width] += stretch_edge_patch(right_edge, pad_width, "right")
+
+     top_edge = torch.rot90(faces['back'][:, :pad_width, :], k=2, dims=(1, 2))
+     padded[:, 0:pad_width, :] += stretch_edge_patch(top_edge, pad_width, "top")
+     return padded
+
+ def pad_bottom(face, faces, pad_width):
+     """
+     Stitch neighbouring edges onto the bottom view:
+       top:    the bottom edge of the front view  (front[:, -w:, :])
+       left:   the bottom edge of the left view   (left[:, -w:, :])
+       right:  the bottom edge of the right view  (right[:, -w:, :])
+       bottom: the bottom edge of the back view   (back[:, -w:, :])
+     """
+     C, H, W = face.shape
+     padded = torch.zeros((C, H + 2 * pad_width, W + 2 * pad_width), dtype=face.dtype, device=face.device)
+     padded[:, pad_width:pad_width + H, pad_width:pad_width + W] = face
+
+     top_edge = faces['front'][:, -pad_width:, :]
+     padded[:, 0:pad_width, :] += stretch_edge_patch(top_edge, pad_width, "top")
+
+     left_edge = torch.rot90(faces['left'][:, -pad_width:, :], k=1, dims=(1, 2))
+     padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge, pad_width, "left")
+
+     right_edge = torch.rot90(faces['right'][:, -pad_width:, :], k=3, dims=(1, 2))
+     padded[:, :, W + pad_width:W + 2 * pad_width] += stretch_edge_patch(right_edge, pad_width, "right")
+
+     bottom_edge = torch.rot90(faces['back'][:, -pad_width:, :], k=2, dims=(1, 2))
+     padded[:, H + pad_width:H + 2 * pad_width, :] += stretch_edge_patch(bottom_edge, pad_width, "bottom")
+     return padded
+
+ pad_funcs = {
+     "front": pad_front,
+     "right": pad_right,
+     "back": pad_back,
+     "left": pad_left,
+     "top": pad_top,
+     "bottom": pad_bottom,
+ }
+
+ def pad_face(faces: dict, width: int, face_name: str) -> Tensor:
+     """
+     Dispatch to the stitching function that matches face_name.
+     """
+     if face_name not in pad_funcs:
+         raise ValueError(f"Invalid face name: {face_name}. Must be one of {list(pad_funcs.keys())}.")
+     return pad_funcs[face_name](faces[face_name], faces, width)
+
+
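A usage sketch for the dispatcher, assuming six random (3, 512, 512) tensors stand in for decoded cubemap faces and an illustrative pad_width of 32:

import torch

# Hypothetical example: stitch context from the neighbouring faces onto each face.
faces = {name: torch.rand(3, 512, 512)
         for name in ["front", "right", "back", "left", "top", "bottom"]}
pad_width = 32
padded = {name: pad_face(faces, pad_width, name) for name in faces}
# Every padded face gains pad_width pixels on all four sides.
assert padded["front"].shape == (3, 512 + 2 * pad_width, 512 + 2 * pad_width)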
+ def prepare_mask(image, facename):
+     """
+     Build the mask that corresponds to a given face.
+     The front view is the known input, so its mask is all zeros;
+     every other face is to be generated and gets a mask of all ones.
+     The mask has shape (1, H, W): the image's spatial size, with a single channel.
+
+     Args:
+         image (torch.Tensor): image tensor of shape (C, H, W) or (N, C, H, W)
+         facename (str): name of the face the image belongs to, e.g. "front", "back"
+
+     Returns:
+         torch.Tensor: the mask, shape (1, H, W)
+     """
+     # For (C, H, W): H = image.shape[1], W = image.shape[2]
+     # For (N, C, H, W): the spatial size sits in the last two dimensions
+     if image.ndim == 3:
+         H, W = image.shape[1], image.shape[2]
+     elif image.ndim == 4:
+         H, W = image.shape[2], image.shape[3]
+     else:
+         raise ValueError("Unsupported image shape")
+
+     mask_shape = (1, H, W)
+     if facename == "front":
+         return torch.zeros(mask_shape, dtype=image.dtype, device=image.device)
+     else:
+         return torch.ones(mask_shape, dtype=image.dtype, device=image.device)
+
+
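A sketch of the mask semantics, assuming random stand-in faces: the known front view gets an all-zero mask, every other face an all-one mask:

import torch

front = torch.rand(3, 512, 512)
back = torch.rand(3, 512, 512)
assert prepare_mask(front, "front").sum() == 0   # known face: keep
assert prepare_mask(back, "back").min() == 1     # unknown face: generate
assert prepare_mask(front, "front").shape == (1, 512, 512)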
+ def generate_cubemap_uv(H, W):
+     """Generate the normalized 3D coordinates (x, y, z) of every texel on each cube face and map them to UV."""
+     H = int(H)
+     W = int(W)
+
+     # Build a grid over [-1, 1] (the x, y coordinates on the cube face)
+     u_range = torch.linspace(-1, 1, W).view(1, -1).expand(H, -1)  # HxW
+     v_range = torch.linspace(-1, 1, H).view(-1, 1).expand(-1, W)  # HxW
+
+     # Normalized (x, y, z) coordinates for the six faces
+     faces = {
+         "front": (u_range, v_range, torch.ones_like(u_range)),    # (x, y, 1)
+         "back": (-u_range, v_range, -torch.ones_like(u_range)),   # (-x, y, -1)
+         "left": (-torch.ones_like(u_range), v_range, u_range),    # (-1, y, x)
+         "right": (torch.ones_like(u_range), v_range, -u_range),   # (1, y, -x)
+         "top": (u_range, -torch.ones_like(u_range), v_range),     # (x, -1, y)
+         "bottom": (u_range, torch.ones_like(u_range), -v_range),  # (x, 1, -y)
+     }
+
+     # Compute the UV map of each face
+     uv_faces = {}
+     for face, (x, y, z) in faces.items():
+         u = torch.atan2(x, z) / (2 * torch.pi) + 0.5
+         v = torch.atan2(y, torch.sqrt(x ** 2 + z ** 2)) / (2 * torch.pi) + 0.5
+         uv_faces[face] = torch.stack([u, v], dim=0)  # shape: (2, H, W)
+
+     return uv_faces  # UV coordinates per face
+
+
+ def generate_cubemap_uv_padding(H, W, padding_pixels=0):
+     """Same as generate_cubemap_uv, but with a configurable padding around each face."""
+     H = int(H)
+     W = int(W)
+
+     # Padding as a fraction of the face width
+     padding_ratio = padding_pixels / W  # e.g. 50 / 512 ≈ 0.0977
+
+     # Size of the enlarged grid
+     H_new = H + 2 * padding_pixels
+     W_new = W + 2 * padding_pixels
+
+     # Build the enlarged grid over [-1 - padding_ratio, 1 + padding_ratio]
+     u_range = torch.linspace(-1 - padding_ratio, 1 + padding_ratio, W_new).view(1, -1).expand(H_new, -1)
+     v_range = torch.linspace(-1 - padding_ratio, 1 + padding_ratio, H_new).view(-1, 1).expand(-1, W_new)
+
+     # Normalized (x, y, z) coordinates for the six faces
+     faces = {
+         "front": (u_range, v_range, torch.ones_like(u_range)),
+         "back": (-u_range, v_range, -torch.ones_like(u_range)),
+         "left": (-torch.ones_like(u_range), v_range, u_range),
+         "right": (torch.ones_like(u_range), v_range, -u_range),
+         "top": (u_range, -torch.ones_like(u_range), v_range),
+         "bottom": (u_range, torch.ones_like(u_range), -v_range),
+     }
+
+     # Compute the UV map of each face
+     uv_faces = {}
+     for face, (x, y, z) in faces.items():
+         u = torch.atan2(x, z) / (2 * torch.pi) + 0.5
+         v = torch.atan2(y, torch.sqrt(x ** 2 + z ** 2)) / (2 * torch.pi) + 0.5
+         uv = torch.stack([u, v], dim=0)  # shape: (2, H_new, W_new)
+         # Bilinearly resize the UV map back to (2, H, W)
+         uv_resized = F.interpolate(uv.unsqueeze(0), size=(H, W), mode='bilinear', align_corners=True).squeeze(0)
+         uv_faces[face] = uv_resized
+
+     return uv_faces
+
+
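A quick sanity sketch for both UV generators, assuming an illustrative 64x64 face size and 8 pixels of padding; each returns six (2, H, W) maps with values in [0, 1]:

import torch

uv_plain = generate_cubemap_uv(64, 64)
uv_padded = generate_cubemap_uv_padding(64, 64, padding_pixels=8)
for face in ["front", "back", "left", "right", "top", "bottom"]:
    assert uv_plain[face].shape == (2, 64, 64)
    assert uv_padded[face].shape == (2, 64, 64)  # resized back to (H, W)
    assert 0.0 <= uv_plain[face].min() and uv_plain[face].max() <= 1.0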
+ def merge_uv_with_latent(latent, uv_maps, dim=1):
+     # Resize uv_maps to the latent's spatial size
+     # (bilinear interpolation with align_corners=False)
+     uv_maps_resized = F.interpolate(uv_maps, size=latent.shape[-2:], mode="bilinear", align_corners=False)
+
+     # Concatenate along the channel dimension, i.e. dim=1
+     latent_with_uv = torch.cat([latent, uv_maps_resized], dim=dim)
+     return latent_with_uv
+
+
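A sketch of the merge, assuming a stand-in (1, 4, 64, 64) latent and the front-face UV map from generate_cubemap_uv:

import torch

latent = torch.rand(1, 4, 64, 64)             # (N, C, H, W)
uv = generate_cubemap_uv(512, 512)["front"]   # (2, 512, 512)
uv_batch = uv.unsqueeze(0)                    # (1, 2, 512, 512)
merged = merge_uv_with_latent(latent, uv_batch)
assert merged.shape == (1, 6, 64, 64)         # 4 latent + 2 UV channels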
+ def resize_and_crop(image: np.ndarray, padding: int) -> np.ndarray:
+     """
+     Resize the input image to (H + 2 * padding, W + 2 * padding),
+     then crop `padding` pixels off each side to restore the original (H, W).
+
+     Args:
+         image (np.ndarray): input image of shape (H, W, C)
+         padding (int): border width to add and then crop away
+
+     Returns:
+         np.ndarray: the processed image, still of shape (H, W, C)
+     """
+     if not isinstance(image, np.ndarray):
+         raise ValueError("The input image must be a numpy array")
+
+     H, W = image.shape[:2]  # original size
+
+     # Step 1: resize to (H + 2 * padding, W + 2 * padding)
+     resized_image = cv2.resize(image, (W + 2 * padding, H + 2 * padding), interpolation=cv2.INTER_LINEAR)
+
+     # Step 2: crop `padding` pixels from the outer edges to get back to (H, W)
+     cropped_image = resized_image[padding:H + padding, padding:W + padding]
+
+     return cropped_image  # numpy.ndarray
+
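A round-trip sketch, assuming a random 512x512 uint8 image: the output keeps the input resolution, while the content is zoomed slightly because the border added by the resize is cropped away.

import numpy as np
import cv2

img = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
out = resize_and_crop(img, padding=16)
assert out.shape == img.shape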
+ def cubemap_unfold(cubemaps, H: int = 512, W: int = 512, channels: int = 3, transparent: bool = False) -> Image.Image:
+     # Unfold the six faces into a 3x4 cross layout.
+     # Overall canvas: 3 rows of H pixels, 4 columns of W pixels.
+     canvas_H = 3 * H
+     canvas_W = 4 * W
+
+     num_channels = channels + 1 if transparent else channels
+     # Allocate the canvas with the right shape; a grayscale canvas keeps its
+     # singleton channel until the end so the indexing below stays uniform.
+     canvas = np.zeros(shape=(canvas_H, canvas_W, num_channels), dtype=cubemaps[0].dtype)
+
+     face_imgs = {face: cubemaps[i] for i, face in enumerate(FACES)}
+
+     alpha_layer = num_channels - 1
+
+     # Layout (0-indexed):
+     #   row 0: top at (0, 1)
+     #   row 1: left, front, right, back in columns 0-3
+     #   row 2: bottom at (2, 1)
+
+     # Place top in row 0, column 1
+     row, col = 0, 1
+     if channels == 1:
+         canvas[row*H:(row+1)*H, col*W:(col+1)*W, 0] = face_imgs['top']
+     else:
+         canvas[row*H:(row+1)*H, col*W:(col+1)*W, :channels] = face_imgs['top']
+     if transparent:
+         canvas[row*H:(row+1)*H, col*W:(col+1)*W, alpha_layer] = 255  # set alpha to opaque
+
+     # Place left, front, right, back across row 1, columns 0-3
+     row = 1
+     for i, face in enumerate(['left', 'front', 'right', 'back']):
+         col = i
+         if channels == 1:
+             canvas[row*H:(row+1)*H, col*W:(col+1)*W, 0] = face_imgs[face]
+         else:
+             canvas[row*H:(row+1)*H, col*W:(col+1)*W, :channels] = face_imgs[face]
+         if transparent:
+             canvas[row*H:(row+1)*H, col*W:(col+1)*W, alpha_layer] = 255
+
+     # Place bottom in row 2, column 1
+     row, col = 2, 1
+     if channels == 1:
+         canvas[row*H:(row+1)*H, col*W:(col+1)*W, 0] = face_imgs['bottom']
+     else:
+         canvas[row*H:(row+1)*H, col*W:(col+1)*W, :channels] = face_imgs['bottom']
+     if transparent:
+         canvas[row*H:(row+1)*H, col*W:(col+1)*W, alpha_layer] = 255  # set alpha to opaque
+
+     if np.issubdtype(canvas.dtype, np.floating):
+         canvas = np.clip(canvas * 255, 0, 255).astype(np.uint8)
+
+     if channels == 1 and not transparent:
+         # Drop the singleton channel so PIL can build a grayscale ("L") image
+         return Image.fromarray(np.squeeze(canvas, axis=-1), mode="L")
+
+     return Image.fromarray(canvas)
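A sketch of unfolding, assuming FACES is the face-order list defined earlier in this file and six blank 256x256 uint8 faces:

import numpy as np

faces = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(6)]
cross = cubemap_unfold(faces, H=256, W=256, channels=3)
assert cross.size == (4 * 256, 3 * 256)  # PIL reports (width, height)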
viewer.html ADDED
@@ -0,0 +1,65 @@
+ <!DOCTYPE html>
+ <html>
+ <head>
+     <meta charset="utf-8">
+     <title>Panorama Viewer</title>
+     <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/pannellum@2.5.6/build/pannellum.css"/>
+     <script type="text/javascript" src="https://cdn.jsdelivr.net/npm/pannellum@2.5.6/build/pannellum.js"></script>
+     <style>
+         html, body {
+             margin: 0;
+             padding: 0;
+             width: 100%;
+             height: 100%;
+             overflow: hidden;
+         }
+         #panorama {
+             width: 100%;
+             height: 100%;
+         }
+         .pnlm-load-button {
+             display: none !important;
+         }
+     </style>
+ </head>
+ <body>
+     <div id="panorama"></div>
+     <script>
+         // Listen for messages from the parent window
+         window.addEventListener('message', function(event) {
+             if (event.data.type === 'loadPanorama') {
+                 console.log('Received image URL:', event.data.image);
+
+                 // Destroy the existing viewer, if any
+                 if (window.viewer) {
+                     window.viewer.destroy();
+                 }
+
+                 // Create a new viewer
+                 window.viewer = pannellum.viewer('panorama', {
+                     type: 'equirectangular',
+                     panorama: event.data.image,
+                     autoLoad: true,
+                     autoRotate: -2,
+                     compass: true,
+                     northOffset: 0,
+                     showFullscreenCtrl: true,
+                     showControls: true,
+                     mouseZoom: true,
+                     draggable: true,
+                     friction: 0.2,
+                     minHfov: 50,
+                     maxHfov: 120,
+                     hfov: 100,
+                     onLoad: function() {
+                         console.log('Panorama loaded successfully');
+                     },
+                     onError: function(error) {
+                         console.error('Error loading panorama:', error);
+                     }
+                 });
+             }
+         });
+     </script>
+ </body>
+ </html>