Spaces:
Sleeping
Sleeping
zimhe
committed on
Commit
·
a521a3f
1
Parent(s):
5790b69
add scripts and examples
Browse files- .gitattributes +2 -0
- examples/examples.json +65 -0
- scripts/__pycache__/analysis.cpython-312.pyc +0 -0
- scripts/__pycache__/cubemap_dataset.cpython-310.pyc +0 -0
- scripts/__pycache__/cubemap_dataset.cpython-312.pyc +0 -0
- scripts/__pycache__/cubemap_diffusion_pipeline.cpython-312.pyc +0 -0
- scripts/__pycache__/cubemap_unet.cpython-310.pyc +0 -0
- scripts/__pycache__/cubemap_unet.cpython-312.pyc +0 -0
- scripts/__pycache__/cubemap_unet_attention.cpython-310.pyc +0 -0
- scripts/__pycache__/cubemap_unet_attention.cpython-312.pyc +0 -0
- scripts/__pycache__/cubemap_unet_attention_processor.cpython-312.pyc +0 -0
- scripts/__pycache__/cubemap_unet_attention_simple.cpython-312.pyc +0 -0
- scripts/__pycache__/cubemap_vae.cpython-310.pyc +0 -0
- scripts/__pycache__/cubemap_vae.cpython-312.pyc +0 -0
- scripts/__pycache__/preprocess.cpython-310.pyc +0 -0
- scripts/__pycache__/train_cubemap_args.cpython-310.pyc +0 -0
- scripts/__pycache__/train_cubemap_args.cpython-312.pyc +0 -0
- scripts/__pycache__/train_cubemap_unet.cpython-312.pyc +0 -0
- scripts/__pycache__/utils.cpython-310.pyc +0 -0
- scripts/__pycache__/utils.cpython-312.pyc +0 -0
- scripts/cubemap_diffusion_pipeline.py +550 -0
- scripts/cubemap_unet.py +79 -0
- scripts/cubemap_unet_attention_processor.py +99 -0
- scripts/cubemap_vae.py +177 -0
- scripts/utils.py +576 -0
- viewer.html +65 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
examples/examples.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"Mountain Fall": {
|
| 3 |
+
"img":"https://huggingface.co/zimhe/SpatialDiffusion/resolve/main/examples/mount-fall.jpg",
|
| 4 |
+
"global": "large wide angle landscape view of rocky mountains at Fall, water lake and forest at the mountain root, blue sky at dusk with scattered golden clouds",
|
| 5 |
+
"front": "view of towering mountains and rivers, with dim clear foggy sky",
|
| 6 |
+
"back": "low mountains, scattered and layered, bright golden cloudy sky",
|
| 7 |
+
"left": "layered mountain peaks with thin cyan and green and creamy color,bright cloudy sky",
|
| 8 |
+
"right": "green water blending with the sky, bright cloudy sky, soft light",
|
| 9 |
+
"top": "bright beiged sky color, thick scattered clouds, faded to background",
|
| 10 |
+
"bottom": "earthy ground with dry grass and scattered stones"
|
| 11 |
+
},
|
| 12 |
+
"Qianli": {
|
| 13 |
+
"img":"https://huggingface.co/zimhe/SpatialDiffusion/resolve/main/examples/qianli.png",
|
| 14 |
+
"global": "Chinese traditional landscape painting, aged beiged color scheme, thin and subtle warm colors, mountains with warm color bottom and with peaks tint thin cyan and green, bright cloudy sky with soft light highly detailed, masterpiece.",
|
| 15 |
+
"front": "view of towering mountains and rivers, with dim clear foggy sky",
|
| 16 |
+
"back": "low mountains, scattered and layered, bright golden cloudy sky",
|
| 17 |
+
"left": "layered mountain peaks with thin cyan and green and creamy color,bright cloudy sky",
|
| 18 |
+
"right": "green water blending with the sky, bright cloudy sky, soft light",
|
| 19 |
+
"top": "bright beiged sky color, thick scattered clouds, faded to background",
|
| 20 |
+
"bottom": "earthy ground with dry grass and scattered stones"
|
| 21 |
+
}
|
| 22 |
+
,
|
| 23 |
+
"Office":{
|
| 24 |
+
"img":"examples/office.png",
|
| 25 |
+
"global": "A modernist open office interior by Louis Khan, low view angle, simple clean structures, spacious and expansive, natural light, concrete and warm wood material, view of the city, high quality, 8k resolution, photorealistic, ultra-detailed, Masterpiece.",
|
| 26 |
+
"front": "office interior with delicate workplace desks",
|
| 27 |
+
"back": "office area with many desks and chairs",
|
| 28 |
+
"left": "large glass partition with dark metal frames",
|
| 29 |
+
"right": "public office area with beautiful curvy structures",
|
| 30 |
+
"top": "clean white ceiling, linear continuous lights",
|
| 31 |
+
"bottom": "clean empty floor with view from above, clean gray gloss material"
|
| 32 |
+
},
|
| 33 |
+
"Forest":{
|
| 34 |
+
"img":"https://huggingface.co/zimhe/SpatialDiffusion/resolve/main/examples/forest.jpg",
|
| 35 |
+
"global": "within the forest with beautiful weaving trees, a tiger inside the bush, small water flowing, very thick tree trunks.",
|
| 36 |
+
"front": "view of thriving forest, with branches covering the sky",
|
| 37 |
+
"back": "a small pathway leading to a small water stream, with trees on both sides",
|
| 38 |
+
"left": "short bushes and trees,",
|
| 39 |
+
"right": "Stone-made pathways, with green leaves and branches, bushes surrounding the bottom",
|
| 40 |
+
"top": "crossing branches and leaves, with a few rays of light peeking through",
|
| 41 |
+
"bottom": "earthy ground with grass and scattered stones and twigs, with a few leaves on the ground"
|
| 42 |
+
},
|
| 43 |
+
|
| 44 |
+
"Las-Meninas":{
|
| 45 |
+
"img":"https://huggingface.co/zimhe/SpatialDiffusion/resolve/main/examples/Las-meninas.png",
|
| 46 |
+
"global": "A 17th-century Spanish royal studio interior with baroque decor, soft natural light, a painter at work, a young princess in a white dress, maids of honor, and a calm dog. Classic, elegant, and atmospheric.",
|
| 47 |
+
"front": "The princess in the center with her maids, the painter beside a large canvas, and a dog lying on the floor. Bright, central focus.",
|
| 48 |
+
"back": "Back of the studio with walls full of paintings and dim shadows. Quiet and deep space.",
|
| 49 |
+
"left": "Side view of the large canvas and the painter, with soft light from the opposite side.",
|
| 50 |
+
"right": "Ladies-in-waiting from the side, detailed dresses, dog in profile, and portraits on the dark wall.",
|
| 51 |
+
"top": "Ceiling with wooden beams and chandeliers, lit softly.",
|
| 52 |
+
"bottom": "Floor with shadows, the princess's gown, and the resting dog."
|
| 53 |
+
},
|
| 54 |
+
"Van-Gogh":{
|
| 55 |
+
"img":"https://huggingface.co/zimhe/SpatialDiffusion/resolve/main/examples/vangoh.png",
|
| 56 |
+
"global": "A cozy bedroom in the style of Vincent van Gogh, The scene features rustic wooden furniture, a bed with a red pillow and yellow blanket,",
|
| 57 |
+
"front": "The Bedroom in Arles by Vincent van Gogh",
|
| 58 |
+
"back": "wall with door and window, with a view of the outside",
|
| 59 |
+
"left": "wooden chairs with yellow cushions, a small wooden table with a water pitcher and a mirror above it,",
|
| 60 |
+
"right": "The walls are painted with a vibrant blue color, with paintings and decorations hanging slightly tilted",
|
| 61 |
+
"top": "a white ceiling with an old lighting, a window with green shutters letting in warm daylight",
|
| 62 |
+
"bottom": "The wooden floor has a textured, painterly look"
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
}
|
scripts/__pycache__/analysis.cpython-312.pyc
ADDED
|
Binary file (3.94 kB). View file
|
|
|
scripts/__pycache__/cubemap_dataset.cpython-310.pyc
ADDED
|
Binary file (2.53 kB). View file
|
|
|
scripts/__pycache__/cubemap_dataset.cpython-312.pyc
ADDED
|
Binary file (7.77 kB). View file
|
|
|
scripts/__pycache__/cubemap_diffusion_pipeline.cpython-312.pyc
ADDED
|
Binary file (26.1 kB). View file
|
|
|
scripts/__pycache__/cubemap_unet.cpython-310.pyc
ADDED
|
Binary file (3.33 kB). View file
|
|
|
scripts/__pycache__/cubemap_unet.cpython-312.pyc
ADDED
|
Binary file (3.67 kB). View file
|
|
|
scripts/__pycache__/cubemap_unet_attention.cpython-310.pyc
ADDED
|
Binary file (2.17 kB). View file
|
|
|
scripts/__pycache__/cubemap_unet_attention.cpython-312.pyc
ADDED
|
Binary file (4.77 kB). View file
|
|
|
scripts/__pycache__/cubemap_unet_attention_processor.cpython-312.pyc
ADDED
|
Binary file (3.99 kB). View file
|
|
|
scripts/__pycache__/cubemap_unet_attention_simple.cpython-312.pyc
ADDED
|
Binary file (3.87 kB). View file
|
|
|
scripts/__pycache__/cubemap_vae.cpython-310.pyc
ADDED
|
Binary file (6.38 kB). View file
|
|
|
scripts/__pycache__/cubemap_vae.cpython-312.pyc
ADDED
|
Binary file (8.71 kB). View file
|
|
|
scripts/__pycache__/preprocess.cpython-310.pyc
ADDED
|
Binary file (2.7 kB). View file
|
|
|
scripts/__pycache__/train_cubemap_args.cpython-310.pyc
ADDED
|
Binary file (8.46 kB). View file
|
|
|
scripts/__pycache__/train_cubemap_args.cpython-312.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
scripts/__pycache__/train_cubemap_unet.cpython-312.pyc
ADDED
|
Binary file (31.7 kB). View file
|
|
|
scripts/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (553 Bytes). View file
|
|
|
scripts/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (26.9 kB). View file
|
|
|
scripts/cubemap_diffusion_pipeline.py
ADDED
|
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 3 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline,StableDiffusionPipelineOutput
|
| 4 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 5 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 6 |
+
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
| 7 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import retrieve_timesteps
|
| 8 |
+
from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
|
| 9 |
+
from .utils import FACES,generate_cubemap_uv
|
| 10 |
+
from diffusers import AutoPipelineForInpainting
|
| 11 |
+
|
| 12 |
+
# Fallback per-face view descriptions. When the caller supplies no
# per-face prompt for a given cubemap face, the corresponding entry
# here is appended to the global prompt instead.
DEFAULT_VIEW_PROMPT = dict(
    front="view to the front, looking forward",
    back="view to the back, looking backward",
    left="view to the left side, looking left",
    right="view to the right side, looking right",
    top="view to above, looking upward, ceiling or sky",
    bottom="view to below, looking downward, floor or ground",
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CubemapDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
| 24 |
+
def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, image_encoder = None, requires_safety_checker = True):
    """Construct the cubemap pipeline by delegating, unchanged, to the
    StableDiffusionInpaintPipeline base-class constructor."""
    super().__init__(
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
        image_encoder,
        requires_safety_checker,
    )
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def prepare_cubemap_mask_condition(self, width: int, height: int, batch_size: int = 6) -> torch.Tensor:
    """Build the inpainting mask batch for the cubemap faces.

    Follows the standard inpainting mask convention: 1.0 = region to be
    repainted, 0.0 = region to keep. Every face starts fully masked
    (to be generated), and the first face — which carries the
    conditioning image — is marked as preserved.

    NOTE(review): the original (Chinese) comments claimed the tensor was
    initialized to all zeros and the first face set to 1; the code does
    the opposite (ones, then face 0 set to 0). The code's behavior is
    kept; only the documentation is corrected.

    Args:
        width: Mask width in pixels.
        height: Mask height in pixels.
        batch_size: Number of cubemap faces; defaults to 6.

    Returns:
        Float32 tensor of shape (batch_size, 1, height, width).
    """
    # All ones: every face is to be repainted by default.
    mask = torch.ones(size=(batch_size, 1, height, width), dtype=torch.float32)
    # Face 0 holds the conditioning image, so keep it (mask value 0).
    mask[0] = 0
    return mask
|
| 34 |
+
|
| 35 |
+
@torch.no_grad()
|
| 36 |
+
def __call__(
|
| 37 |
+
self,
|
| 38 |
+
global_prompt: str = None,
|
| 39 |
+
per_face_prompts: Optional[Dict[str, str]] = None,
|
| 40 |
+
image: PipelineImageInput = None,
|
| 41 |
+
mask_image: PipelineImageInput = None,
|
| 42 |
+
masked_image_latents: torch.Tensor = None,
|
| 43 |
+
height: Optional[int] = None,
|
| 44 |
+
width: Optional[int] = None,
|
| 45 |
+
padding_mask_crop: Optional[int] = None,
|
| 46 |
+
strength: float = 1.0,
|
| 47 |
+
num_inference_steps: int = 50,
|
| 48 |
+
timesteps: List[int] = None,
|
| 49 |
+
sigmas: List[float] = None,
|
| 50 |
+
guidance_scale: float = 7.5,
|
| 51 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 52 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 53 |
+
eta: float = 0.0,
|
| 54 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 55 |
+
latents: Optional[torch.Tensor] = None,
|
| 56 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 57 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 58 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 59 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 60 |
+
output_type: Optional[str] = "pil",
|
| 61 |
+
return_dict: bool = True,
|
| 62 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 63 |
+
clip_skip: int = None,
|
| 64 |
+
callback_on_step_end: Optional[
|
| 65 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 66 |
+
] = None,
|
| 67 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 68 |
+
**kwargs,
|
| 69 |
+
):
|
| 70 |
+
r"""
|
| 71 |
+
The call function to the pipeline for generation.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 75 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 76 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 77 |
+
`Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
|
| 78 |
+
be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
|
| 79 |
+
tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the
|
| 80 |
+
expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
|
| 81 |
+
expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but
|
| 82 |
+
if passing latents directly it is not encoded again.
|
| 83 |
+
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 84 |
+
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
|
| 85 |
+
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
|
| 86 |
+
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
|
| 87 |
+
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
|
| 88 |
+
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
|
| 89 |
+
1)`, or `(H, W)`.
|
| 90 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 91 |
+
The height in pixels of the generated image.
|
| 92 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 93 |
+
The width in pixels of the generated image.
|
| 94 |
+
padding_mask_crop (`int`, *optional*, defaults to `None`):
|
| 95 |
+
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
|
| 96 |
+
image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
|
| 97 |
+
with the same aspect ration of the image and contains all masked area, and then expand that area based
|
| 98 |
+
on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
|
| 99 |
+
resizing to the original image size for inpainting. This is useful when the masked area is small while
|
| 100 |
+
the image is large and contain information irrelevant for inpainting, such as background.
|
| 101 |
+
strength (`float`, *optional*, defaults to 1.0):
|
| 102 |
+
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
| 103 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
| 104 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
| 105 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
| 106 |
+
essentially ignores `image`.
|
| 107 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 108 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 109 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
| 110 |
+
timesteps (`List[int]`, *optional*):
|
| 111 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 112 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 113 |
+
passed will be used. Must be in descending order.
|
| 114 |
+
sigmas (`List[float]`, *optional*):
|
| 115 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 116 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 117 |
+
will be used.
|
| 118 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 119 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 120 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 121 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 122 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 123 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 124 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 125 |
+
The number of images to generate per prompt.
|
| 126 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 127 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
| 128 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 129 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 130 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 131 |
+
generation deterministic.
|
| 132 |
+
latents (`torch.Tensor`, *optional*):
|
| 133 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 134 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 135 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 136 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 137 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 138 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 139 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 140 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 141 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 142 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 143 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 144 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 145 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
| 146 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
| 147 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 148 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 149 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 150 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 151 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 152 |
+
plain tuple.
|
| 153 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 154 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 155 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 156 |
+
clip_skip (`int`, *optional*):
|
| 157 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 158 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 159 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 160 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 161 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
| 162 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 163 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 164 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 165 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 166 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 167 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 168 |
+
Examples:
|
| 169 |
+
|
| 170 |
+
```py
|
| 171 |
+
>>> import PIL
|
| 172 |
+
>>> import requests
|
| 173 |
+
>>> import torch
|
| 174 |
+
>>> from io import BytesIO
|
| 175 |
+
|
| 176 |
+
>>> from diffusers import StableDiffusionInpaintPipeline
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
>>> def download_image(url):
|
| 180 |
+
... response = requests.get(url)
|
| 181 |
+
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
| 185 |
+
>>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
| 186 |
+
|
| 187 |
+
>>> init_image = download_image(img_url).resize((512, 512))
|
| 188 |
+
>>> mask_image = download_image(mask_url).resize((512, 512))
|
| 189 |
+
|
| 190 |
+
>>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
| 191 |
+
... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
|
| 192 |
+
... )
|
| 193 |
+
>>> pipe = pipe.to("cuda")
|
| 194 |
+
|
| 195 |
+
>>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
| 196 |
+
>>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 201 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 202 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
| 203 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
| 204 |
+
"not-safe-for-work" (nsfw) content.
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
callback = kwargs.pop("callback", None)
|
| 208 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 209 |
+
|
| 210 |
+
if callback is not None:
|
| 211 |
+
deprecate(
|
| 212 |
+
"callback",
|
| 213 |
+
"1.0.0",
|
| 214 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
| 215 |
+
)
|
| 216 |
+
if callback_steps is not None:
|
| 217 |
+
deprecate(
|
| 218 |
+
"callback_steps",
|
| 219 |
+
"1.0.0",
|
| 220 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 224 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 225 |
+
|
| 226 |
+
# 0. Default height and width to unet
|
| 227 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 228 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 229 |
+
|
| 230 |
+
# # 1. Check inputs
|
| 231 |
+
# self.check_inputs(
|
| 232 |
+
# prompt,
|
| 233 |
+
# image,
|
| 234 |
+
# mask_image,
|
| 235 |
+
# height,
|
| 236 |
+
# width,
|
| 237 |
+
# strength,
|
| 238 |
+
# callback_steps,
|
| 239 |
+
# output_type,
|
| 240 |
+
# negative_prompt,
|
| 241 |
+
# prompt_embeds,
|
| 242 |
+
# negative_prompt_embeds,
|
| 243 |
+
# ip_adapter_image,
|
| 244 |
+
# ip_adapter_image_embeds,
|
| 245 |
+
# callback_on_step_end_tensor_inputs,
|
| 246 |
+
# padding_mask_crop,
|
| 247 |
+
# )
|
| 248 |
+
|
| 249 |
+
self._guidance_scale = guidance_scale
|
| 250 |
+
self._clip_skip = clip_skip
|
| 251 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 252 |
+
self._interrupt = False
|
| 253 |
+
|
| 254 |
+
batch_size=6
|
| 255 |
+
|
| 256 |
+
if global_prompt is None and per_face_prompts is None:
|
| 257 |
+
raise ValueError(
|
| 258 |
+
"Provide either `global prompt` or `per_face_prompts`, or both. Can't leave them both empty"
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
prompt=[]
|
| 262 |
+
|
| 263 |
+
# 2. Define call parameters
|
| 264 |
+
if global_prompt is not None and per_face_prompts is None:
|
| 265 |
+
prompt=[f"{global_prompt},{DEFAULT_VIEW_PROMPT[f]}" for f in FACES]
|
| 266 |
+
elif global_prompt is not None and per_face_prompts is not None:
|
| 267 |
+
for f in FACES:
|
| 268 |
+
if f in per_face_prompts.keys():
|
| 269 |
+
prompt_text=f"{global_prompt},{per_face_prompts[f]}"
|
| 270 |
+
prompt.append(prompt_text)
|
| 271 |
+
else:
|
| 272 |
+
prompt.append(f"{global_prompt},{DEFAULT_VIEW_PROMPT[f]}")
|
| 273 |
+
|
| 274 |
+
negative_prompt_list = None
|
| 275 |
+
if negative_prompt is not None:
|
| 276 |
+
if isinstance(negative_prompt, list):
|
| 277 |
+
if len(negative_prompt) == batch_size:
|
| 278 |
+
negative_prompt_list = negative_prompt
|
| 279 |
+
else:
|
| 280 |
+
raise ValueError(
|
| 281 |
+
f"Length of `negative_prompt` list ({len(negative_prompt)}) does not match `batch_size` ({batch_size})."
|
| 282 |
+
)
|
| 283 |
+
else:
|
| 284 |
+
negative_prompt_list = [negative_prompt for _ in range(batch_size)]
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
device = self._execution_device
|
| 288 |
+
|
| 289 |
+
# 3. Encode input prompt
|
| 290 |
+
text_encoder_lora_scale = (
|
| 291 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
| 292 |
+
)
|
| 293 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 294 |
+
prompt,
|
| 295 |
+
device,
|
| 296 |
+
num_images_per_prompt,
|
| 297 |
+
self.do_classifier_free_guidance,
|
| 298 |
+
negative_prompt=negative_prompt_list,
|
| 299 |
+
lora_scale=text_encoder_lora_scale,
|
| 300 |
+
clip_skip=self.clip_skip,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 304 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 305 |
+
# to avoid doing two forward passes
|
| 306 |
+
if self.do_classifier_free_guidance:
|
| 307 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 308 |
+
|
| 309 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 310 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 311 |
+
ip_adapter_image,
|
| 312 |
+
ip_adapter_image_embeds,
|
| 313 |
+
device,
|
| 314 |
+
batch_size * num_images_per_prompt,
|
| 315 |
+
self.do_classifier_free_guidance,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
# 4. set timesteps
|
| 319 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 320 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
| 321 |
+
)
|
| 322 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
| 323 |
+
num_inference_steps=num_inference_steps, strength=strength, device=device
|
| 324 |
+
)
|
| 325 |
+
# check that number of inference steps is not < 1 - as this doesn't make sense
|
| 326 |
+
if num_inference_steps < 1:
|
| 327 |
+
raise ValueError(
|
| 328 |
+
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
| 329 |
+
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
| 330 |
+
)
|
| 331 |
+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
| 332 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
| 333 |
+
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
| 334 |
+
is_strength_max = strength == 1.0
|
| 335 |
+
|
| 336 |
+
# 5. Preprocess mask and image
|
| 337 |
+
|
| 338 |
+
# if padding_mask_crop is not None:
|
| 339 |
+
# crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
|
| 340 |
+
# resize_mode = "fill"
|
| 341 |
+
# else:
|
| 342 |
+
# crops_coords = None
|
| 343 |
+
# resize_mode = "default"
|
| 344 |
+
|
| 345 |
+
original_image = image
|
| 346 |
+
cond_image=self.image_processor.preprocess(image,height=height,width=width)
|
| 347 |
+
cond_image=cond_image.to(dtype=torch.float32)
|
| 348 |
+
|
| 349 |
+
empty_face=torch.zeros_like(cond_image,dtype=torch.float32)
|
| 350 |
+
init_image = [cond_image]+[empty_face for _ in range(5)]
|
| 351 |
+
|
| 352 |
+
#stack the condition image with empty faces to make a tensor stack of shape (6,3,H,W)
|
| 353 |
+
init_image=torch.stack(init_image,dim=0)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
# 6. Prepare latent variables
|
| 357 |
+
num_channels_latents = self.vae.config.latent_channels
|
| 358 |
+
num_channels_unet = self.unet.config.in_channels
|
| 359 |
+
return_image_latents = num_channels_unet == 4
|
| 360 |
+
|
| 361 |
+
latents_outputs = self.prepare_latents(
|
| 362 |
+
batch_size,
|
| 363 |
+
num_channels_latents,
|
| 364 |
+
height,
|
| 365 |
+
width,
|
| 366 |
+
prompt_embeds.dtype,
|
| 367 |
+
device,
|
| 368 |
+
generator,
|
| 369 |
+
latents,
|
| 370 |
+
timestep=latent_timestep,
|
| 371 |
+
is_strength_max=is_strength_max,
|
| 372 |
+
return_noise=True,
|
| 373 |
+
return_image_latents=False
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
if return_image_latents:
|
| 377 |
+
latents, noise, image_latents = latents_outputs
|
| 378 |
+
else:
|
| 379 |
+
latents, noise = latents_outputs
|
| 380 |
+
|
| 381 |
+
mask_condition=self.prepare_cubemap_mask_condition(width,height)
|
| 382 |
+
masked_image=init_image
|
| 383 |
+
|
| 384 |
+
mask, masked_image_latents = self.prepare_mask_latents(
|
| 385 |
+
mask_condition,
|
| 386 |
+
masked_image,
|
| 387 |
+
batch_size,
|
| 388 |
+
height,
|
| 389 |
+
width,
|
| 390 |
+
prompt_embeds.dtype,
|
| 391 |
+
device,
|
| 392 |
+
generator,
|
| 393 |
+
self.do_classifier_free_guidance,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
uv_maps=generate_cubemap_uv(height//self.vae_scale_factor,width//self.vae_scale_factor)
|
| 397 |
+
|
| 398 |
+
uv_channels=torch.stack([uv_maps[face] for face in FACES]).to(dtype=prompt_embeds.dtype,device=device)
|
| 399 |
+
uv_channel_input= torch.cat([uv_channels] * 2) if self.do_classifier_free_guidance else uv_channels
|
| 400 |
+
|
| 401 |
+
# # 8. Check that sizes of mask, masked image and latents match
|
| 402 |
+
# if num_channels_unet == 9:
|
| 403 |
+
# # default case for runwayml/stable-diffusion-inpainting
|
| 404 |
+
# num_channels_mask = mask.shape[1]
|
| 405 |
+
# num_channels_masked_image = masked_image_latents.shape[1]
|
| 406 |
+
# if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
| 407 |
+
# raise ValueError(
|
| 408 |
+
# f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
| 409 |
+
# f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
| 410 |
+
# f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
| 411 |
+
# f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
| 412 |
+
# " `pipeline.unet` or your `mask_image` or `image` input."
|
| 413 |
+
# )
|
| 414 |
+
# elif num_channels_unet != 4:
|
| 415 |
+
# raise ValueError(
|
| 416 |
+
# f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
|
| 417 |
+
# )
|
| 418 |
+
|
| 419 |
+
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 420 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 421 |
+
|
| 422 |
+
# 9.1 Add image embeds for IP-Adapter
|
| 423 |
+
added_cond_kwargs = (
|
| 424 |
+
{"image_embeds": image_embeds}
|
| 425 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
| 426 |
+
else None
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# 9.2 Optionally get Guidance Scale Embedding
|
| 430 |
+
timestep_cond = None
|
| 431 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
| 432 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
| 433 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
| 434 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 435 |
+
).to(device=device, dtype=latents.dtype)
|
| 436 |
+
|
| 437 |
+
# 10. Denoising loop
|
| 438 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 439 |
+
self._num_timesteps = len(timesteps)
|
| 440 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 441 |
+
for i, t in enumerate(timesteps):
|
| 442 |
+
if self.interrupt:
|
| 443 |
+
continue
|
| 444 |
+
|
| 445 |
+
# expand the latents if we are doing classifier free guidance
|
| 446 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 447 |
+
|
| 448 |
+
# concat latents, mask, masked_image_latents in the channel dimension
|
| 449 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 450 |
+
|
| 451 |
+
#print("Unet Channels:", num_channels_unet)
|
| 452 |
+
|
| 453 |
+
if num_channels_unet == 11:
|
| 454 |
+
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents,uv_channel_input], dim=1)
|
| 455 |
+
#print("Latent Input Shape:",latent_model_input.shape)
|
| 456 |
+
|
| 457 |
+
# predict the noise residual
|
| 458 |
+
noise_pred = self.unet(
|
| 459 |
+
latent_model_input,
|
| 460 |
+
t,
|
| 461 |
+
encoder_hidden_states=prompt_embeds,
|
| 462 |
+
timestep_cond=timestep_cond,
|
| 463 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 464 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 465 |
+
return_dict=False,
|
| 466 |
+
)[0]
|
| 467 |
+
|
| 468 |
+
# perform guidance
|
| 469 |
+
if self.do_classifier_free_guidance:
|
| 470 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 471 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 472 |
+
|
| 473 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 474 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 475 |
+
if num_channels_unet == 4:
|
| 476 |
+
init_latents_proper = image_latents
|
| 477 |
+
if self.do_classifier_free_guidance:
|
| 478 |
+
init_mask, _ = mask.chunk(2)
|
| 479 |
+
else:
|
| 480 |
+
init_mask = mask
|
| 481 |
+
|
| 482 |
+
if i < len(timesteps) - 1:
|
| 483 |
+
noise_timestep = timesteps[i + 1]
|
| 484 |
+
init_latents_proper = self.scheduler.add_noise(
|
| 485 |
+
init_latents_proper, noise, torch.tensor([noise_timestep])
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
| 489 |
+
|
| 490 |
+
if callback_on_step_end is not None:
|
| 491 |
+
callback_kwargs = {}
|
| 492 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 493 |
+
callback_kwargs[k] = locals()[k]
|
| 494 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 495 |
+
|
| 496 |
+
latents = callback_outputs.pop("latents", latents)
|
| 497 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 498 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 499 |
+
mask = callback_outputs.pop("mask", mask)
|
| 500 |
+
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
|
| 501 |
+
|
| 502 |
+
# call the callback, if provided
|
| 503 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 504 |
+
progress_bar.update()
|
| 505 |
+
if callback is not None and i % callback_steps == 0:
|
| 506 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 507 |
+
callback(step_idx, t, latents)
|
| 508 |
+
|
| 509 |
+
if not output_type == "latent":
|
| 510 |
+
condition_kwargs = {}
|
| 511 |
+
if isinstance(self.vae, AsymmetricAutoencoderKL):
|
| 512 |
+
init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
|
| 513 |
+
init_image_condition = init_image.clone()
|
| 514 |
+
init_image = self._encode_vae_image(init_image, generator=generator)
|
| 515 |
+
mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
|
| 516 |
+
condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
|
| 517 |
+
image = self.vae.decode(
|
| 518 |
+
latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
|
| 519 |
+
)[0]
|
| 520 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
else:
|
| 524 |
+
image = latents
|
| 525 |
+
has_nsfw_concept = None
|
| 526 |
+
|
| 527 |
+
if has_nsfw_concept is None:
|
| 528 |
+
do_denormalize = [True] * image.shape[0]
|
| 529 |
+
else:
|
| 530 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
input_img=self.image_processor.postprocess(cond_image,output_type=output_type,do_denormalize=do_denormalize)[0]
|
| 537 |
+
image[0]=input_img
|
| 538 |
+
|
| 539 |
+
# if padding_mask_crop is not None:
|
| 540 |
+
# image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
|
| 541 |
+
|
| 542 |
+
# Offload all models
|
| 543 |
+
self.maybe_free_model_hooks()
|
| 544 |
+
|
| 545 |
+
if not return_dict:
|
| 546 |
+
return (image, has_nsfw_concept)
|
| 547 |
+
|
| 548 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
| 549 |
+
|
| 550 |
+
|
scripts/cubemap_unet.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from diffusers import UNet2DConditionModel
|
| 4 |
+
from diffusers.models.attention import BasicTransformerBlock
|
| 5 |
+
from .cubemap_unet_attention_processor import InflatedAttentionProcessor
|
| 6 |
+
|
| 7 |
+
class CubemapUNet(UNet2DConditionModel):
    """UNet for cubemap diffusion.

    Wraps a pretrained ``UNet2DConditionModel``: copies its config and
    weights, widens ``conv_in`` to accept ``in_channels`` inputs
    (latents + mask + masked-image latents + UV channels), and installs
    :class:`InflatedAttentionProcessor` on every attn1/attn2 so that
    attention runs jointly over all six cube faces.
    """

    def __init__(self, pretrained_unet, in_channels=11):
        """
        pretrained_unet: loaded ``UNet2DConditionModel`` whose config and
            weights are copied into this model.
        in_channels: new number of input channels for ``conv_in``.
        """
        # Copy the config dict so the pretrained model's config is not mutated.
        unet_config = {**pretrained_unet.config}
        super().__init__(**unet_config)

        # Load the pretrained weights while the architecture still matches
        # (conv_in is widened only afterwards).
        print("开始加载预训练权重")
        self.load_state_dict(pretrained_unet.state_dict(), strict=True)
        print("✅ UNet 预训练权重加载成功!")

        # Widen conv_in to `in_channels` and record it in the config.
        self._expand_conv_in(pretrained_unet, in_channels)
        self.register_to_config(in_channels=in_channels)

        cubemap_attn_processor = InflatedAttentionProcessor()
        self._modify_attn_processor(cubemap_attn_processor)

    def _expand_conv_in(self, pretrained_unet, new_in_channels):
        """Rebuild ``conv_in`` with ``new_in_channels`` input channels.

        The original channels reuse the pretrained weights; the extra
        channels get a small random init. Fix: the widened weight is
        created with the same dtype/device as the pretrained weight (the
        previous code always allocated float32 on CPU, which broke
        fp16/GPU checkpoints).
        """
        old_conv_in = pretrained_unet.conv_in
        old_weight = old_conv_in.weight  # [out_channels, old_in, kH, kW]

        # Build the replacement conv with identical geometry except in_channels.
        new_conv_in = nn.Conv2d(
            new_in_channels,
            old_conv_in.out_channels,
            kernel_size=old_conv_in.kernel_size,
            stride=old_conv_in.stride,
            padding=old_conv_in.padding
        )

        old_channels = old_weight.shape[1]

        # Allocate the widened weight on the same dtype/device as the original.
        new_weight = torch.zeros(
            (new_conv_in.out_channels, new_in_channels, *old_conv_in.kernel_size),
            dtype=old_weight.dtype,
            device=old_weight.device,
        )
        # Copy the pretrained channels.
        new_weight[:, :old_channels, :, :] = old_weight

        # Small random init for the newly added channels.
        new_weight[:, old_channels:, :, :] = torch.randn_like(new_weight[:, old_channels:, :, :]) * 0.01

        new_conv_in.weight = nn.Parameter(new_weight)
        if old_conv_in.bias is not None:
            new_conv_in.bias = nn.Parameter(old_conv_in.bias.clone())

        self.conv_in = new_conv_in
        print(f"✅ `conv_in` 扩展成功!新输入通道: {new_in_channels}")

    def _modify_attn_processor(self, processor):
        """Install `processor` on attn1/attn2 of every BasicTransformerBlock."""
        for name, module in self.named_modules():
            if isinstance(module, BasicTransformerBlock):
                if hasattr(module, 'attn1'):
                    module.attn1.set_processor(processor)

                if hasattr(module, 'attn2'):
                    module.attn2.set_processor(processor)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
scripts/cubemap_unet_attention_processor.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from diffusers.models.attention import BasicTransformerBlock
|
| 5 |
+
from diffusers.models.attention_processor import Attention, AttnProcessor
|
| 6 |
+
from flash_attn.flash_attn_interface import flash_attn_func
|
| 7 |
+
|
| 8 |
+
class InflatedAttentionProcessor(AttnProcessor):
    """Inflated attention (CubeDiff-style): for self-attention, the view
    axis is folded into the sequence axis so that attention spans all
    `num_views` cubemap faces jointly; flash-attention computes the core.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        num_views: int = 6,  # e.g. a cubemap has 6 views
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Implements the Inflated Attention from the CubeDiff paper:
        - reshape the input from `B, N, C` to `B, F*N, C`
        - run self-attention over the `F*N` dimension
        """
        residual = hidden_states

        # 1. Pre-processing
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)


        BXF, N, C = hidden_states.shape # original attention input
        # 2. Reshape `B, N, C` -> `B, F*N, C`
        F = num_views
        B=BXF//F


        # 3. Standard attention computation
        # NOTE(review): the mask prepared here is never passed to
        # flash_attn_func below, so attention masking is effectively
        # ignored — confirm this is intentional.
        attention_mask = attn.prepare_attention_mask(attention_mask, hidden_states.shape[1], B)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        is_self_attn=False

        if encoder_hidden_states is None:
            # Self-attention: fold the view axis into the sequence axis.
            hidden_states = hidden_states.view(B, F, N, C)
            hidden_states = hidden_states.reshape(B, F * N, C)
            encoder_hidden_states = hidden_states
            is_self_attn=True
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Permute to the (batch, seqlen, heads, headdim) layout expected
        # by flash_attn_func.
        query = attn.head_to_batch_dim(query,out_dim=4).permute(0,2,1,3)
        key = attn.head_to_batch_dim(key,out_dim=4).permute(0,2,1,3)
        value = attn.head_to_batch_dim(value,out_dim=4).permute(0,2,1,3)

        hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
        # NOTE: this rebinds B; for self-attention the value is unchanged
        # (the flash output batch dim equals B there).
        B,L,H,D=hidden_states.shape
        hidden_states = hidden_states.view(B,L,H*D)

        # Reference (non-flash) attention path, kept for debugging:
        # query = attn.head_to_batch_dim(query)
        # key = attn.head_to_batch_dim(key)
        # value = attn.head_to_batch_dim(value)

        # attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # hidden_states = torch.bmm(attention_probs, value)
        # hidden_states = attn.batch_to_head_dim(hidden_states)

        # 4. Output projection & dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if is_self_attn:
            # 5. Unfold the view axis: `B, F*N, C` -> `B*F, N, C`
            hidden_states = hidden_states.view(B, F, N, C)
            hidden_states = hidden_states.reshape(BXF, N, C)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
| 98 |
+
|
| 99 |
+
|
scripts/cubemap_vae.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers.models import AutoencoderKL
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torchvision.transforms import ToPILImage
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
class SynchronizedGroupNorm(nn.Module):
    """GroupNorm with statistics synchronized across cubemap views.

    Wraps an existing ``nn.GroupNorm``: the affine parameters are
    inherited, but mean/variance are computed jointly over all
    ``num_views`` faces of a sample (the leading dim is assumed to be
    B * num_views), so every face is normalized with identical statistics.
    """

    def __init__(self, original_group_norm: nn.GroupNorm, num_views: int = 6):
        super().__init__()
        self.num_views = num_views

        # Inherit the original grouping parameters
        self.num_groups = original_group_norm.num_groups
        self.num_channels = original_group_norm.num_channels
        self.eps = original_group_norm.eps

        # Channels per group
        self.group_size = self.num_channels // self.num_groups

        # Inherit the original affine parameters (per-group scale/shift)
        self.weight = nn.Parameter(original_group_norm.weight.detach().view(self.num_groups, self.group_size))
        self.bias = nn.Parameter(original_group_norm.bias.detach().view(self.num_groups, self.group_size))

    def forward(self, x: torch.Tensor):
        """ Supports both 3D (B, C, D) and 4D (B, C, H, W) inputs """

        #print(f"Input shape: {x.shape}") # Debugging

        # Read the leading shape info; leading dim = batch * num_views
        BxT, C = x.shape[:2] # only the first two dims
        B = BxT // self.num_views # recover the batch dimension

        # Handle 3D (B, C, D) input
        if x.dim() == 3:
            D = x.shape[2]
            x = x.view(B, self.num_views, self.num_groups, self.group_size, D)

            # GroupNorm statistics over (views, group channels, spatial)
            mean = x.mean(dim=(1, 3, 4), keepdim=True)
            var = x.var(dim=(1, 3, 4), keepdim=True, unbiased=False)
            x = (x - mean) / torch.sqrt(var + self.eps)
            # Reshape weight and bias so they broadcast correctly
            weight = self.weight.view(1, self.num_groups, self.group_size, 1)
            bias = self.bias.view(1, self.num_groups, self.group_size, 1)
            x = x * weight + bias

            # Restore the original layout
            return x.view(BxT, C, D)

        # Handle 4D (B, C, H, W) input
        elif x.dim() == 4:
            H, W = x.shape[2:]
            x = x.view(B, self.num_views, self.num_groups, self.group_size, H, W)

            # GroupNorm statistics over (views, group channels, H, W)
            mean = x.mean(dim=(1, 3, 4, 5), keepdim=True)
            var = x.var(dim=(1, 3, 4, 5), keepdim=True, unbiased=False)
            x = (x - mean) / torch.sqrt(var + self.eps)
            # Reshape weight and bias so they broadcast correctly
            weight = self.weight.view(1, self.num_groups, self.group_size, 1, 1)
            bias = self.bias.view(1, self.num_groups, self.group_size, 1, 1)
            x = x * weight + bias

            # Restore the original layout
            return x.view(BxT, C, H, W)

        else:
            raise ValueError(f"Unsupported input shape: {x.shape}, expected 3D (B, C, D) or 4D (B, C, H, W).")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class CubemapVAE(AutoencoderKL):
    """AutoencoderKL adapted to 6-face cubemap batches.

    Reuses the pretrained VAE's encoder/decoder and quant convolutions,
    then swaps every GroupNorm for SynchronizedGroupNorm so that
    normalization statistics are shared across the six faces.
    """

    def __init__(self, pretrained_vae, num_views=6, in_channels=3,image_size=512):
        # NOTE(review): the block/channel layout below is hard-coded to the
        # SD-style VAE; `image_size` is accepted but unused here.
        super().__init__( # inherits from AutoencoderKL
            act_fn="silu",
            block_out_channels=[128, 256, 512, 512],
            down_block_types=[
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D"
            ],
            up_block_types=[
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D"
            ],
            latent_channels=pretrained_vae.config.latent_channels,
            in_channels=in_channels,
            out_channels=in_channels
        )
        self.num_views = num_views
        self.in_channels = in_channels


        # --- Replace key modules to adapt to cubemaps ---
        # The stock AutoencoderKL encoder/decoder are reused directly; the
        # custom Cubemap variants below are kept commented for reference.
        #self.encoder = CubemapEncoder(pretrained_encoder=pretrained_vae.encoder,num_views=num_views, in_channels=in_channels)
        #self.decoder = CubemapDecoder(pretrained_decoder=pretrained_vae.decoder, num_views=num_views, out_channels=in_channels,in_channels=4)
        self.encoder=pretrained_vae.encoder
        self.decoder=pretrained_vae.decoder
        self.quant_conv=pretrained_vae.quant_conv
        self.post_quant_conv=pretrained_vae.post_quant_conv
        # Replace the original GroupNorms with the synchronized variant
        replace_group_norm_with_sgn(self, num_views=num_views)

    def encode(self, images,return_dict:bool=True):
        # Flatten (B, V, C, H, W) -> (B*V, C, H, W) so the synchronized
        # GroupNorm can recover the view dimension from the leading dim.
        batch_size, num_views, num_channels, height, width = images.shape
        images = images.view(batch_size*num_views,num_channels, height, width)
        return super().encode(images,return_dict=return_dict)


    def decode(self, latents, return_dict=True, **kwargs):
        """
        Custom VAE decode:
        - drop the UV channels (keep only the first 4 latent channels)
        - then run the original VAE decode path
        """

        print("Decoder Recieve Latent Shape:", latents.shape)
        # Keep only the first 4 latent channels; extra channels are UV maps.
        if latents.shape[1] > 4:
            latents = latents[:, :4, :, :] # keep first 4 channels, drop UV



        return super().decode(latents, return_dict=return_dict, **kwargs)

    def decode_to_tensor(self, latents):
        # Decode and split the (B*6, 3, H, W) output back into 6 face groups.
        decoded = self.decode(latents).sample # (B*6, 3, H, W)

        B = latents.shape[0] // 6
        images = torch.split(decoded, B, dim=0) # split by batch

        return images # tuple of 6 tensors

    def decode_to_pil_images(self, latents:Tensor):
        # Decode and convert the first sample of each face to a PIL image.
        images = self.decode_to_tensor(latents) # the 6 face tensors
        to_pil = ToPILImage()

        return [to_pil(img[0].cpu().detach()) for img in images] # to PIL
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def replace_group_norm_with_sgn(model, num_views):
    """Swap every nn.GroupNorm inside `model` for a SynchronizedGroupNorm."""
    # Collect target module paths first so the module tree is not mutated
    # while named_modules() is still iterating over it.
    targets = [
        module_path
        for module_path, submodule in model.named_modules()
        if isinstance(submodule, nn.GroupNorm)
    ]

    for module_path in targets:
        parent, attr = get_parent_module(model, module_path)
        setattr(parent, attr, SynchronizedGroupNorm(getattr(parent, attr), num_views))
|
| 155 |
+
|
| 156 |
+
def get_parent_module(model, module_name):
    """Resolve the parent module of `module_name` and the leaf attribute name."""
    # Walk every path segment except the last; the last is the attribute
    # to be read/written on the parent.
    *path, leaf = module_name.split(".")
    parent = model
    for segment in path:
        parent = getattr(parent, segment)
    return parent, leaf
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def flatten_face_names(face_names):
    """Flatten a mix of strings and list/tuple groups into a flat string list."""
    flat_list = []
    for item in face_names:
        if isinstance(item, (list, tuple)):
            # A nested group: splice its members in.
            flat_list.extend(item)
        elif isinstance(item, str):
            # A bare face name.
            flat_list.append(item)
        else:
            raise ValueError(f"Unexpected type in face_names: {type(item)}")
    return flat_list
|
| 176 |
+
|
| 177 |
+
|
scripts/utils.py
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
import torchvision.transforms.functional as TF
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import cv2
|
| 7 |
+
import py360convert
|
| 8 |
+
import argparse
|
| 9 |
+
import os
|
| 10 |
+
import numpy as np
|
| 11 |
+
from numpy.typing import NDArray
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
# The six cubemap view directions (front, back, left, right, top, bottom).
FACES = ["front", "back", "left", "right", "top", "bottom"]

# Maps each face name to its caption field name ("caption_<face>") —
# presumably the dataset's per-face caption keys; confirm against the dataset.
FACE_CAPTION_MAP={
    "front": "caption_front",
    "back": "caption_back",
    "left": "caption_left",
    "right": "caption_right",
    "top": "caption_top",
    "bottom": "caption_bottom"
}

# Maps each face name to the single-letter face key used by py360convert
# (F/B/L/R, U = up/top, D = down/bottom).
FACE_KEYS_MAP = {
    "front": "F",
    "back": "B",
    "left": "L",
    "right": "R",
    "top": "U",
    "bottom": "D"
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_cubemap_dict(cubemap_path_dict: dict):
    """
    Load the cubemap face images referenced by `cubemap_path_dict`.

    Returns a dict keyed by the py360convert face letters (F/B/L/R/U/D);
    faces whose file cannot be read are reported and skipped.
    """
    faces = {}
    for face_name, image_path in cubemap_path_dict.items():
        img = cv2.imread(image_path)
        if img is None:
            print(f"❌ 读取失败: {image_path}")
        else:
            faces[FACE_KEYS_MAP[face_name]] = img
    return faces
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def convert_to_cubemap(image, size=512):
    """
    Convert an equirectangular panorama into a 6-face cubemap dict
    (py360convert face keys) with `size` x `size` faces.
    """
    return py360convert.e2c(image, face_w=size, mode="bilinear", cube_format="dict")
|
| 56 |
+
|
| 57 |
+
def to_cubemap_dict(images: list[NDArray]):
    """Pair face images (given in FACES order) with py360convert face keys."""
    # Index explicitly so a short `images` list still raises IndexError,
    # matching the original behavior.
    return {FACE_KEYS_MAP[face_name]: images[idx] for idx, face_name in enumerate(FACES)}
|
| 64 |
+
|
| 65 |
+
def convert_to_equirectangular(cubemap_dict, width=1024,height=512):
    """
    Convert a 6-face cubemap dict into an equirectangular panorama
    (PIL Image) using py360convert.
    """
    equirectangular_image = py360convert.c2e(cubemap_dict, w=width,h=height, mode="bilinear",cube_format="dict")

    # NOTE(review): only float32 output is rescaled; py360convert may
    # return float64 (branch skipped, Image.fromarray can then fail), and
    # the *255 assumes values in [0, 1] — confirm the expected dtype/range.
    if equirectangular_image.dtype == np.float32:
        equirectangular_image = np.clip(equirectangular_image * 255, 0, 255).astype(np.uint8)

    return Image.fromarray(equirectangular_image)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def process_image_e2c(input_path, output_dir, size=512):
    """Read an equirectangular panorama, split into 6 cubemap faces, save each as PNG.

    Silently returns (after a console warning) when the input cannot be read.
    """
    image = cv2.imread(input_path)
    if image is None:
        print(f"❌ 读取失败: {input_path}")
        return

    os.makedirs(output_dir, exist_ok=True)

    # Split into the six faces.
    cubemap_images = convert_to_cubemap(image, size)
    print(cubemap_images.keys())

    # Write one PNG per face, named after the face.
    for face in FACES:
        output_path = os.path.join(output_dir, f"{face}.png")
        cv2.imwrite(output_path, cubemap_images[FACE_KEYS_MAP[face]])
        print(f"✅ {face} 视角已保存: {output_path}")
|
| 99 |
+
|
| 100 |
+
def process_image_c2e(cubemap_path_dict, output_path, width,height):
    """Read 6 cubemap faces, stitch them into an equirectangular panorama, save it.

    Args:
        cubemap_path_dict: mapping of face name -> image file path.
        output_path: destination file for the stitched panorama.
        width, height: output panorama dimensions in pixels.

    Bug fix: convert_to_equirectangular returns a PIL Image, but cv2.imwrite
    requires a numpy array — the image is now converted before writing
    (previously this call raised at runtime).
    """
    cubemap_dict=load_cubemap_dict(cubemap_path_dict)
    # Stitch the faces into a panorama (PIL Image).
    equirectangular_image = convert_to_equirectangular(cubemap_dict, width,height)
    # cv2.imwrite cannot handle PIL Images; pass the underlying pixel array.
    cv2.imwrite(output_path, np.asarray(equirectangular_image))
    print(f"✅ 全景图已保存: {output_path}")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def perspective_transform_patch(patch: torch.Tensor, delta):
    """
    Apply a perspective warp to a patch via torchvision.transforms.functional.perspective.

    Args:
        patch: Tensor of shape (C, H, W).
        delta: per-corner displacement — an iterable of four [dx, dy] pairs
            ordered top-left, top-right, bottom-right, bottom-left; each pair
            is added to the corresponding source corner to form the warp's
            endpoints. (The original docstring described a scalar "offset"
            parameter that does not exist.)

    Returns:
        The warped patch, a Tensor with the same shape as the input.
    """
    C, H, W = patch.shape

    # Source corner coordinates (order: top-left, top-right, bottom-right, bottom-left).
    startpoints = [
        [0, 0],  # top-left
        [W, 0],  # top-right
        [W, H],  # bottom-right
        [0, H]   # bottom-left
    ]

    # Destination corners = source corners shifted by the per-corner delta.
    endpoints=[[sp_i + d_i for sp_i, d_i in zip(sp, d)] for sp, d in zip(startpoints, delta)]
    # TF.perspective accepts startpoints/endpoints as List[List[float]];
    # tensors would also work, but plain lists suffice here.
    return TF.perspective(patch, startpoints, endpoints, interpolation=TF.InterpolationMode.BILINEAR)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def stretch_edge_patch(patch,pad_width,edge_key):
    """
    Resize a neighbouring-face edge strip to the padded canvas size and warp it
    so that it tapers toward the face it is glued onto.

    Args:
        patch: Tensor (C, H, W) — the raw edge strip cut from a neighbour face.
        pad_width: padding width in pixels; also the strip's thickness after resize.
        edge_key: which side of the padded face the strip attaches to —
            one of "top", "bottom", "left", "right".

    Returns:
        The resized, perspective-warped strip as a Tensor.

    Raises:
        ValueError: for an unrecognised edge_key (the original silently
            returned None, which surfaced later as an opaque TypeError).
    """
    C,H,W=patch.shape
    H_new=H + 2*pad_width
    W_new=W + 2*pad_width
    if edge_key=="top":
        top_edge=TF.resize(patch,(pad_width,W_new))
        # Taper the inner (bottom) corners toward the face.
        delta_top = [
            [0, 0],            # top-left
            [0, 0],            # top-right
            [-pad_width , 0],  # bottom-right
            [pad_width , 0]    # bottom-left
        ]
        return perspective_transform_patch(top_edge, delta_top)

    elif edge_key=="bottom":
        bottom_edge=TF.resize(patch,(pad_width,W_new))
        # Taper the inner (top) corners toward the face.
        delta_bottom=[
            [pad_width, 0],    # top-left
            [-pad_width, 0],   # top-right
            [0, 0],            # bottom-right
            [0 , 0]            # bottom-left
        ]
        return perspective_transform_patch(bottom_edge,delta_bottom)
    elif edge_key=="left":
        left_edge=TF.resize(patch,(H_new,pad_width))
        # Taper the inner (right) corners toward the face.
        delta_left=[
            [0, 0],            # top-left
            [0, pad_width],    # top-right
            [0, -pad_width],   # bottom-right
            [0 , 0]            # bottom-left
        ]
        return perspective_transform_patch(left_edge,delta_left)
    elif edge_key=="right":
        right_edge=TF.resize(patch,(H_new,pad_width))
        # Taper the inner (left) corners toward the face.
        delta_right=[
            [0, pad_width],    # top-left
            [0, 0],            # top-right
            [0, 0],            # bottom-right
            [0 , -pad_width]   # bottom-left
        ]
        return perspective_transform_patch(right_edge,delta_right)
    raise ValueError(f"Invalid edge_key: {edge_key!r}; expected 'top', 'bottom', 'left' or 'right'.")
|
| 181 |
+
|
| 182 |
+
# -------------------------------
|
| 183 |
+
# 定义各面拼接函数
|
| 184 |
+
# -------------------------------
|
| 185 |
+
|
| 186 |
+
def pad_front(face:Tensor, faces, pad_width):
    """
    Pad the front face with warped strips from its cubemap neighbours:
      top    : bottom edge of the top face    (top[:, -w:, :])
      bottom : top edge of the bottom face    (bottom[:, :w, :])
      left   : right edge of the left face    (left[:, :, -w:])
      right  : left edge of the right face    (right[:, :, :w])

    Returns a (C, H + 2*pad_width, W + 2*pad_width) tensor with the face centred
    and the neighbour strips added (via +=) into the border regions.
    """
    C, H, W = face.shape
    H_new=H + 2*pad_width
    W_new=W + 2*pad_width
    padded = torch.zeros((C, H_new, W_new), dtype=face.dtype, device=face.device)
    # Centre the front face itself.
    padded[:, pad_width:pad_width+H, pad_width:pad_width+W] = face
    # Top border.
    top_edge=faces['top'][:, -pad_width:, :]
    padded[:, 0:pad_width, 0:W+2*pad_width] += stretch_edge_patch(top_edge,pad_width,"top")

    # Bottom border.
    bottom_edge=faces['bottom'][:, :pad_width, :]
    padded[:, H+pad_width:H+2*pad_width, :] += stretch_edge_patch(bottom_edge,pad_width,"bottom")

    # Left border.
    left_edge=faces['left'][:, :, -pad_width:]
    padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge,pad_width,"left")

    # Right border.
    right_edge=faces['right'][:, :, :pad_width]
    padded[:, :, W+pad_width:W+2*pad_width] += stretch_edge_patch(right_edge,pad_width,"right")
    return padded
|
| 216 |
+
|
| 217 |
+
def pad_right(face, faces, pad_width):
    """
    Pad the right face with warped strips from its cubemap neighbours:
      left   : right edge of the front face   (front[:, :, -w:])
      right  : left edge of the back face     (back[:, :, :w])
      top    : right edge of the top face     (top[:, :, -w:]), rotated 90° clockwise
      bottom : right edge of the bottom face  (bottom[:, :, -w:]), rotated 90° counter-clockwise

    Returns a (C, H + 2*pad_width, W + 2*pad_width) tensor.
    """
    C, H, W = face.shape
    H_new=H + 2*pad_width
    W_new=W + 2*pad_width
    padded = torch.zeros((C, H_new, W_new), dtype=face.dtype, device=face.device)
    padded[:, pad_width:pad_width+H, pad_width:pad_width+W] = face

    left_edge=faces['front'][:, :, -pad_width:]
    padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge,pad_width,"left")

    right_edge=faces['back'][:, :, :pad_width]
    padded[:, :, W+pad_width:W+2*pad_width] += stretch_edge_patch(right_edge,pad_width,"right")

    # Top: right edge of the top face, rotated 90° clockwise so the
    # (C, H, w) strip becomes (C, w, H) and lies horizontally.
    top_edge = torch.rot90(faces['top'][:, :, -pad_width:], k=3, dims=(1,2))

    padded[:, 0:pad_width, :] += stretch_edge_patch(top_edge,pad_width,"top")

    # Bottom: right edge of the bottom face, rotated 90° counter-clockwise
    # ((C, H, w) strip becomes (C, w, H)).
    bottom_edge = torch.rot90(faces['bottom'][:, :, -pad_width:], k=1, dims=(1,2))
    padded[:, H+pad_width:H+2*pad_width, :] += stretch_edge_patch(bottom_edge,pad_width,"bottom")
    return padded
|
| 248 |
+
|
| 249 |
+
def pad_back(face, faces, pad_width):
    """
    Pad the back face with warped strips from its cubemap neighbours:
      left   : right edge of the right face   (right[:, :, -w:])
      right  : left edge of the left face     (left[:, :, :w])
      top    : top edge of the top face       (top[:, :w, :]), rotated 180°
      bottom : bottom edge of the bottom face (bottom[:, -w:, :]), rotated 180°

    Returns a (C, H + 2*pad_width, W + 2*pad_width) tensor.
    """
    C, H, W = face.shape
    H_new=H + 2*pad_width
    W_new=W + 2*pad_width
    padded = torch.zeros((C, H_new, W_new), dtype=face.dtype, device=face.device)
    padded[:, pad_width:pad_width+H, pad_width:pad_width+W] = face

    left_edge=faces['right'][:, :, -pad_width:]
    padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge,pad_width,"left")

    right_edge=faces['left'][:, :, :pad_width]
    padded[:, :, W+pad_width:W+2*pad_width] += stretch_edge_patch(right_edge,pad_width,"right")

    # Top: top edge of the top face, rotated 180°
    # (torch.rot90 with k=2 on the spatial dims).
    top_edge = torch.rot90(faces['top'][:, :pad_width, :], k=2, dims=(1,2))
    padded[:, 0:pad_width, :] +=stretch_edge_patch(top_edge,pad_width,"top")

    # Bottom: bottom edge of the bottom face, rotated 180°.
    bottom_edge = torch.rot90(faces['bottom'][:, -pad_width:, :], k=2, dims=(1,2))
    padded[:, H+pad_width:H+2*pad_width, :] += stretch_edge_patch(bottom_edge,pad_width,"bottom")
    return padded
|
| 278 |
+
|
| 279 |
+
def pad_left(face, faces, pad_width):
    """
    Pad the left face with warped strips from its cubemap neighbours:
      left   : right edge of the back face   (back[:, :, -w:])
      right  : left edge of the front face   (front[:, :, :w])
      top    : left edge of the top face     (top[:, :, :w]), rotated 90° counter-clockwise
      bottom : left edge of the bottom face  (bottom[:, :, :w]), rotated 90° clockwise

    Returns a (C, H + 2*pad_width, W + 2*pad_width) tensor.
    """
    C, H, W = face.shape
    padded = torch.zeros((C, H + 2*pad_width, W + 2*pad_width), dtype=face.dtype, device=face.device)
    padded[:, pad_width:pad_width+H, pad_width:pad_width+W] = face

    left_edge=faces['back'][:, :, -pad_width:]
    padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge,pad_width,"left")

    right_edge=faces['front'][:, :, :pad_width]
    padded[:, :, W+pad_width:W+2*pad_width] += stretch_edge_patch(right_edge,pad_width,"right")

    # Top: left edge of the top face, rotated 90° counter-clockwise.
    top_edge=torch.rot90(faces['top'][:, :, :pad_width],k=1,dims=(1,2))
    padded[:, 0:pad_width, :] += stretch_edge_patch(top_edge,pad_width,"top")

    # Bottom: left edge of the bottom face, rotated 90° clockwise.
    bottom_edge=torch.rot90(faces['bottom'][:, :, :pad_width],k=3,dims=(1,2))
    padded[:, H+pad_width:H+2*pad_width, :] += stretch_edge_patch(bottom_edge,pad_width,"bottom")
    return padded
|
| 303 |
+
|
| 304 |
+
def pad_top(face, faces, pad_width):
    """
    Pad the top face with warped strips from its cubemap neighbours:
      bottom : top edge of the front face (front[:, :w, :])
      left   : top edge of the left face  (left[:, :w, :]), rotated 90° clockwise
      right  : top edge of the right face (right[:, :w, :]), rotated 90° counter-clockwise
      top    : top edge of the back face  (back[:, :w, :]), rotated 180°

    Returns a (C, H + 2*pad_width, W + 2*pad_width) tensor.
    """
    C, H, W = face.shape
    padded = torch.zeros((C, H + 2*pad_width, W + 2*pad_width), dtype=face.dtype, device=face.device)
    padded[:, pad_width:pad_width+H, pad_width:pad_width+W] = face

    # Bottom border: top edge of the front face.
    bottom_edge=faces['front'][:, :pad_width, :]
    padded[:, H+pad_width:H+2*pad_width, :] +=stretch_edge_patch(bottom_edge,pad_width,"bottom")

    # Left border: top edge of the left face, rotated 90° clockwise.
    left_edge=torch.rot90(faces['left'][:, :pad_width, :],k=3,dims=(1,2))
    padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge,pad_width,"left")

    # Right border: top edge of the right face, rotated 90° counter-clockwise.
    right_edge=torch.rot90(faces['right'][:, :pad_width, :],k=1,dims=(1,2))
    padded[:, :, W+pad_width:W+2*pad_width]+=stretch_edge_patch(right_edge,pad_width,"right")

    # Top border: top edge of the back face, rotated 180°.
    top_edge=torch.rot90(faces['back'][:, :pad_width, :], k=2, dims=(1,2))

    padded[:, 0:pad_width, :] +=stretch_edge_patch(top_edge,pad_width,"top")
    return padded
|
| 329 |
+
|
| 330 |
+
def pad_bottom(face, faces, pad_width):
    """
    Pad the bottom face with warped strips from its cubemap neighbours:
      top    : bottom edge of the front face (front[:, -w:, :])
      left   : bottom edge of the left face  (left[:, -w:, :]), rotated 90° counter-clockwise
      right  : bottom edge of the right face (right[:, -w:, :]), rotated 90° clockwise
      bottom : bottom edge of the back face  (back[:, -w:, :]), rotated 180°

    Returns a (C, H + 2*pad_width, W + 2*pad_width) tensor.
    """
    C, H, W = face.shape
    padded = torch.zeros((C, H + 2*pad_width, W + 2*pad_width), dtype=face.dtype, device=face.device)
    padded[:, pad_width:pad_width+H, pad_width:pad_width+W] = face

    # Top border: bottom edge of the front face.
    top_edge=faces['front'][:, -pad_width:, :]
    padded[:, 0:pad_width, :] += stretch_edge_patch(top_edge,pad_width,"top")

    # Left border: bottom edge of the left face, rotated 90° counter-clockwise.
    left_edge=torch.rot90(faces['left'][:, -pad_width:, :],k=1,dims=(1,2))
    padded[:, :, 0:pad_width] += stretch_edge_patch(left_edge,pad_width,"left")

    # Right border: bottom edge of the right face, rotated 90° clockwise.
    right_edge=torch.rot90(faces['right'][:, -pad_width:, :],k=3,dims=(1,2))
    padded[:, :, W+pad_width:W+2*pad_width] += stretch_edge_patch(right_edge,pad_width,"right")

    # Bottom border: bottom edge of the back face, rotated 180°.
    bottom_edge=torch.rot90(faces['back'][:, -pad_width:, :],k=2,dims=(1,2))
    padded[:, H+pad_width:H+2*pad_width, :] += stretch_edge_patch(bottom_edge,pad_width,"bottom")
    return padded
|
| 354 |
+
|
| 355 |
+
# Dispatch table: cubemap face name -> its padding function (see pad_face).
pad_funcs = {
    "front": pad_front,
    "right": pad_right,
    "back": pad_back,
    "left": pad_left,
    "top": pad_top,
    "bottom": pad_bottom,
}
|
| 363 |
+
|
| 364 |
+
def pad_face(faces: dict, width: int, face_name: str)->Tensor:
    """Dispatch to the per-face padding function registered in pad_funcs.

    Args:
        faces: mapping of face name -> face tensor.
        width: padding width in pixels.
        face_name: which face to pad.

    Raises:
        ValueError: if face_name is not a known cubemap face.
    """
    pad_func = pad_funcs.get(face_name)
    if pad_func is None:
        raise ValueError(f"Invalid face name: {face_name}. Must be one of {list(pad_funcs.keys())}.")
    return pad_func(faces[face_name], faces, width)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def prepare_mask(image,facename):
    """
    Build a single-channel mask matching the image's spatial size.

    Returns an all-ZEROS mask for the "front" face and an all-ONES mask for
    every other face, shaped (1, H, W) with the image's dtype and device.

    NOTE(review): the original (Chinese) docstring claimed the opposite
    polarity ("front" -> all 1, others -> 0), but the code below clearly does
    front -> zeros. Confirm which polarity downstream consumers expect.

    Args:
        image (torch.Tensor): tensor of shape (C, H, W) or (N, C, H, W).
        facename (str): face name, e.g. "front", "back".

    Returns:
        torch.Tensor: mask of shape (1, H, W).

    Raises:
        ValueError: if image is not 3- or 4-dimensional.
    """
    # (C, H, W): spatial dims are at indices 1, 2.
    # (N, C, H, W): spatial dims are at indices 2, 3.
    if image.ndim == 3:
        H, W = image.shape[1], image.shape[2]
    elif image.ndim == 4:
        H, W = image.shape[2], image.shape[3]
    else:
        raise ValueError("Unsupported image shape")

    mask_shape = (1, H, W)
    if facename == "front":
        return torch.zeros(mask_shape, dtype=image.dtype, device=image.device)
    else:
        return torch.ones(mask_shape, dtype=image.dtype, device=image.device)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def generate_cubemap_uv(H, W):
    """Compute per-texel equirectangular UV coordinates for each cube face.

    For every texel of each face, the normalised 3D direction (x, y, z) is
    formed, then mapped to (u, v) via atan2. Returns a dict of face name ->
    tensor of shape (2, H, W).
    """
    H, W = int(H), int(W)

    # Per-texel coordinates on the face plane, spanning [-1, 1].
    us = torch.linspace(-1, 1, W).view(1, -1).expand(H, -1)   # H x W
    vs = torch.linspace(-1, 1, H).view(-1, 1).expand(-1, W)   # H x W
    ones = torch.ones_like(us)

    # Normalised (x, y, z) direction for each of the six faces.
    face_dirs = {
        "front": (us, vs, ones),        # (x, y, z=1)
        "back": (-us, vs, -ones),       # (-x, y, z=-1)
        "left": (-ones, vs, us),        # (-1, y, x)
        "right": (ones, vs, -us),       # (1, y, -x)
        "top": (us, -ones, vs),         # (x, -1, y)
        "bottom": (us, ones, -vs),      # (x, 1, -y)
    }

    # Map each direction to equirectangular (u, v).
    uv_faces = {}
    for name, (x, y, z) in face_dirs.items():
        u = torch.atan2(x, z) / (2 * torch.pi) + 0.5
        v = torch.atan2(y, torch.sqrt(x ** 2 + z ** 2)) / (2 * torch.pi) + 0.5
        uv_faces[name] = torch.stack([u, v], dim=0)  # shape: (2, H, W)

    return uv_faces
|
| 431 |
+
|
| 432 |
+
import torch
|
| 433 |
+
|
| 434 |
+
def generate_cubemap_uv_padding(H, W, padding_pixels=0):
    """Per-face equirectangular UV maps computed on a padded grid, resampled to (H, W).

    The face-plane grid is extended by padding_pixels on every side — the
    coordinate range becomes [-1 - p, 1 + p] with p = padding_pixels / W —
    then the resulting UV map is bilinearly resized back to (2, H, W).
    """
    H, W = int(H), int(W)

    # Fraction of the face half-width covered by the padding (e.g. 50/512 ≈ 0.0977).
    padding_ratio = padding_pixels / W

    H_pad = H + 2 * padding_pixels
    W_pad = W + 2 * padding_pixels

    # Extended grid spanning [-1 - padding_ratio, 1 + padding_ratio].
    us = torch.linspace(-1 - padding_ratio, 1 + padding_ratio, W_pad).view(1, -1).expand(H_pad, -1)
    vs = torch.linspace(-1 - padding_ratio, 1 + padding_ratio, H_pad).view(-1, 1).expand(-1, W_pad)
    ones = torch.ones_like(us)

    # Normalised (x, y, z) direction for each face.
    face_dirs = {
        "front": (us, vs, ones),
        "back": (-us, vs, -ones),
        "left": (-ones, vs, us),
        "right": (ones, vs, -us),
        "top": (us, -ones, vs),
        "bottom": (us, ones, -vs),
    }

    uv_faces = {}
    for name, (x, y, z) in face_dirs.items():
        u = torch.atan2(x, z) / (2 * torch.pi) + 0.5
        v = torch.atan2(y, torch.sqrt(x ** 2 + z ** 2)) / (2 * torch.pi) + 0.5
        uv = torch.stack([u, v], dim=0)  # shape: (2, H_pad, W_pad)
        # Bilinearly resample the padded UV map back to the requested resolution.
        uv_faces[name] = F.interpolate(uv.unsqueeze(0), size=(H, W), mode='bilinear', align_corners=True).squeeze(0)

    return uv_faces
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def merge_uv_with_latent(latent, uv_maps,dim=1):
    """Resize uv_maps to the latent's spatial size and concatenate along `dim`.

    The UV maps are bilinearly resampled (align_corners=False) to match
    latent's last two dimensions, then concatenated — by default on the
    channel dimension (dim=1).
    """
    target_size = latent.shape[-2:]
    uv_resized = F.interpolate(uv_maps, size=target_size, mode="bilinear", align_corners=False)
    return torch.cat([latent, uv_resized], dim=dim)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def resize_and_crop(image: np.ndarray, padding: int) -> np.ndarray:
    """
    Resize the image up to (H + 2*padding, W + 2*padding), then crop `padding`
    pixels off every side, restoring the original (H, W).

    Args:
        image (np.ndarray): input image of shape (H, W, C).
        padding (int): border width added by the resize and removed by the crop.

    Returns:
        np.ndarray: processed image, still shaped (H, W, C).

    Raises:
        ValueError: if image is not a numpy array.
    """
    if not isinstance(image, np.ndarray):
        raise ValueError("输入图片必须是 numpy 数组格式")

    H, W = image.shape[:2]

    # Enlarge by `padding` on each side; cv2.resize takes (width, height).
    enlarged = cv2.resize(image, (W + 2 * padding, H + 2 * padding), interpolation=cv2.INTER_LINEAR)

    # Trim the border back off to restore the original size.
    return enlarged[padding:H + padding, padding:W + padding]
|
| 510 |
+
|
| 511 |
+
def cubemap_unfold(cubemaps,H:int=512,W:int=512,channels:int=3,transparent:bool=False)->Image.Image:
    """
    Lay the six cubemap faces out on a 3x4 "cross" canvas and return a PIL Image.

    Layout (row, col): top at (0, 1); left/front/right/back across row 1
    (cols 0..3); bottom at (2, 1). `cubemaps` is indexed in FACES order.
    When `transparent` is True an extra alpha channel is appended and set
    opaque (255) only where faces were placed.

    Bug fix vs. the original: the single-channel canvas was squeezed to 2-D
    *before* faces were pasted with a 3-index slice, which raised an
    IndexError for channels == 1 (and a squeeze error for channels == 1 with
    transparent=True). The canvas now stays 3-D during assembly and the
    channel axis is only dropped at the very end.

    Args:
        cubemaps: sequence of six face arrays (H, W) or (H, W, channels), FACES order.
        H, W: face height/width in pixels.
        channels: colour channels per face (1 for grayscale, 3 for colour).
        transparent: append an alpha channel marking face regions opaque.

    Returns:
        PIL.Image.Image: the unfolded cubemap cross ("L"/"LA" for 1 channel).
    """
    # Canvas: 3 rows x 4 columns of H x W cells.
    canvas_H = 3 * H
    canvas_W = 4 * W

    num_channels = channels if transparent == False else channels + 1

    # Keep the canvas 3-D while pasting; squeeze only at the very end.
    canvas = np.zeros(shape=(canvas_H, canvas_W, num_channels), dtype=cubemaps[0].dtype)

    face_imgs = {face: cubemaps[i] for i, face in enumerate(FACES)}
    alpha_layer = num_channels - 1

    def _paste(row, col, face):
        # Place one face into the H x W cell at (row, col) of the 3x4 grid.
        rs = slice(row * H, (row + 1) * H)
        cs = slice(col * W, (col + 1) * W)
        if channels == 1:
            canvas[rs, cs, 0] = face_imgs[face]
        else:
            canvas[rs, cs, :channels] = face_imgs[face]
        if transparent:
            canvas[rs, cs, alpha_layer] = 255  # opaque where a face exists

    # top at (0, 1); left/front/right/back across row 1; bottom at (2, 1).
    _paste(0, 1, 'top')
    for i, face in enumerate(['left', 'front', 'right', 'back']):
        _paste(1, i, face)
    _paste(2, 1, 'bottom')

    # Float canvases are assumed to be in [0, 1]; convert to 8-bit for PIL.
    if np.issubdtype(canvas.dtype, np.floating):
        canvas = np.clip(canvas * 255, 0, 255).astype(np.uint8)

    if channels == 1:
        if transparent:
            return Image.fromarray(canvas, mode="LA")
        return Image.fromarray(np.squeeze(canvas, axis=-1), mode="L")

    return Image.fromarray(canvas)
|
viewer.html
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>全景图查看器</title>
<!-- pannellum panorama viewer, loaded from CDN -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/pannellum@2.5.6/build/pannellum.css"/>
<script type="text/javascript" src="https://cdn.jsdelivr.net/npm/pannellum@2.5.6/build/pannellum.js"></script>
<style>
html, body {
    margin: 0;
    padding: 0;
    width: 100%;
    height: 100%;
    overflow: hidden;
}
#panorama {
    width: 100%;
    height: 100%;
}
.pnlm-load-button {
    display: none !important;
}
</style>
</head>
<body>
<div id="panorama"></div>
<script>
// Listen for messages from the parent window; the embedding app posts
// { type: 'loadPanorama', image: <url> } to display a new panorama.
// NOTE(review): event.origin is never validated, so any window that can
// post messages here can inject an image URL — confirm whether the
// embedding context (e.g. a same-origin Gradio iframe) makes this acceptable.
window.addEventListener('message', function(event) {
    if (event.data.type === 'loadPanorama') {
        console.log('Received image URL:', event.data.image);

        // Tear down the previous viewer (if any) before re-creating it.
        if (window.viewer) {
            window.viewer.destroy();
        }

        // Create a fresh equirectangular pannellum viewer for the new image.
        window.viewer = pannellum.viewer('panorama', {
            type: 'equirectangular',
            panorama: event.data.image,
            autoLoad: true,
            autoRotate: -2,
            compass: true,
            northOffset: 0,
            showFullscreenCtrl: true,
            showControls: true,
            mouseZoom: true,
            draggable: true,
            friction: 0.2,
            minHfov: 50,
            maxHfov: 120,
            hfov: 100,
            onLoad: function() {
                console.log('Panorama loaded successfully');
            },
            onError: function(error) {
                console.error('Error loading panorama:', error);
            }
        });
    }
});
</script>
</body>
</html>
|