moved back to models
Browse files- latent_diffusion_pipeline.py +0 -99
latent_diffusion_pipeline.py
DELETED
|
@@ -1,99 +0,0 @@
|
|
| 1 |
-
from diffusers import DDPMPipeline
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
from typing import Optional, Union, List, Tuple
|
| 5 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 6 |
-
from diffusers.pipelines.ddpm.pipeline_ddpm import ImagePipelineOutput
|
| 7 |
-
import common_settings as common_settings
|
| 8 |
-
import os
|
| 9 |
-
import json
|
| 10 |
-
from general_training_helper import get_scene_from_embeddings
|
| 11 |
-
|
| 12 |
-
class UnconditionalDDPMPipeline(DDPMPipeline):
    """DDPM pipeline for unconditional generation of tile/sprite level scenes.

    Extends :class:`diffusers.DDPMPipeline` with two project-specific features:

    * an optional ``block_embeddings`` tensor, persisted next to the pipeline
      weights as ``block_embeddings.pt``; when present, the generated latent
      image is decoded into a scene via ``get_scene_from_embeddings`` instead
      of a channel-wise softmax;
    * optional per-sprite temperature scaling (``sprite_scaling_factors``)
      applied to the final sample before decoding.
    """

    def __init__(self, unet, scheduler, block_embeddings=None):
        super().__init__(unet, scheduler)
        # Optional embedding table used to map latent channels back to
        # blocks/sprites; saved and restored alongside the pipeline.
        self.block_embeddings = block_embeddings
        # Per-sprite temperature factors are configured later via
        # give_sprite_scaling_factors(); default to "not configured" so
        # __call__ never has to probe with hasattr().
        self.sprite_scaling_factors = None

    def save_pretrained(self, save_directory):
        """Save the pipeline and, if set, the block-embeddings tensor.

        Args:
            save_directory: Target directory (created if missing).
        """
        os.makedirs(save_directory, exist_ok=True)
        super().save_pretrained(save_directory)
        if self.block_embeddings is not None:
            torch.save(
                self.block_embeddings,
                os.path.join(save_directory, "block_embeddings.pt"),
            )

    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kwargs):
        """Load the pipeline and restore ``block_embeddings`` if it was saved.

        Args:
            pretrained_model_path: Directory previously written by
                :meth:`save_pretrained`.
            **kwargs: Forwarded to ``DDPMPipeline.from_pretrained``.

        Returns:
            The loaded pipeline, with ``block_embeddings`` set to the saved
            tensor (on CPU) or ``None`` when no tensor file exists.
        """
        pipeline = super().from_pretrained(pretrained_model_path, **kwargs)
        block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
        if os.path.exists(block_embeds_path):
            pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
        else:
            pipeline.block_embeddings = None
        return pipeline

    def give_sprite_scaling_factors(self, sprite_scaling_factors):
        """Set the sprite scaling factors for the pipeline.

        This is used to apply per-sprite temperature scaling during inference.

        Args:
            sprite_scaling_factors: Tensor with one factor per latent channel
                (viewed as ``(1, -1, 1, 1)`` at sampling time), or ``None``
                to disable scaling.
        """
        self.sprite_scaling_factors = sprite_scaling_factors

    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        height: int = common_settings.MARIO_HEIGHT, width: int = common_settings.MARIO_WIDTH,
        latents: Optional[torch.FloatTensor] = None,
        show_progress_bar=True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Run the reverse-diffusion sampling loop.

        Args:
            batch_size: Number of samples to generate when ``latents`` is not
                supplied.
            generator: Torch generator (or list of generators, one per batch
                element) for reproducible noise.
            num_inference_steps: Number of scheduler timesteps.
            output_type: Accepted for API compatibility with DDPMPipeline but
                currently unused — the output is always a tensor.
            return_dict: If False, return a plain ``(image,)`` tuple.
            height / width: Spatial size of the generated latent image.
            latents: Optional pre-made starting noise; overrides batch_size.
            show_progress_bar: Whether to wrap the timestep loop in a
                progress bar.

        Returns:
            ``ImagePipelineOutput`` (or a tuple) holding either decoded scenes
            (when ``block_embeddings`` is set) or channel-softmaxed images.
        """
        self.unet.eval()
        with torch.no_grad():
            if latents is not None:
                image = latents.to(self.device)
            else:
                image_shape = (
                    batch_size,
                    self.unet.config.in_channels,
                    height,
                    width
                )
                # Use diffusers' randn_tensor rather than torch.randn: it
                # supports a *list* of generators (one per batch element),
                # which the annotated `generator` type allows, and handles
                # device-specific RNG quirks.
                image = randn_tensor(image_shape, generator=generator, device=self.device)

            self.scheduler.set_timesteps(num_inference_steps)

            iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
            for t in iterator:
                model_output = self.unet(image, t).sample
                image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

            # Apply per-sprite temperature scaling if enabled. getattr keeps
            # this safe even for instances created before __init__ set the
            # attribute (e.g. unpickled pipelines).
            scaling = getattr(self, "sprite_scaling_factors", None)
            if scaling is not None:
                image = image / scaling.view(1, -1, 1, 1)

            if self.block_embeddings is not None:
                # Decode latent channels back into a scene via the embedding table.
                image = get_scene_from_embeddings(image, self.block_embeddings)
            else:
                # No embedding table: interpret channels as per-tile logits.
                image = F.softmax(image, dim=1)
            image = image.detach().cpu()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

    def print_unet_architecture(self):
        """Prints the architecture of the UNet model."""
        print(self.unet)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|