# MarioDiffusion-MLM-regular0 / latent_diffusion_pipeline.py
# Uploaded by schrum2 in commit a09cfc1 (verified):
#   "Loading into root will supposedly make them easier to find"
# File size: 3.94 kB
from diffusers import DDPMPipeline
import torch
import torch.nn.functional as F
from typing import Optional, Union, List, Tuple
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.ddpm.pipeline_ddpm import ImagePipelineOutput
import common_settings as common_settings
import os
import json
from general_training_helper import get_scene_from_embeddings
class UnconditionalDDPMPipeline(DDPMPipeline):
    """Unconditional DDPM sampling pipeline that produces per-channel scene tensors.

    Extends ``diffusers.DDPMPipeline`` with two optional post-processing steps
    applied to the final denoised sample:

    * per-channel ("per-sprite") temperature scaling, configured through
      :meth:`give_sprite_scaling_factors`;
    * decoding through ``block_embeddings`` (via ``get_scene_from_embeddings``)
      when embeddings are present, otherwise a softmax over the channel axis.

    ``block_embeddings`` is not a registered diffusers component, so it is
    persisted manually as ``block_embeddings.pt`` next to the standard files.
    """

    def __init__(self, unet, scheduler, block_embeddings=None):
        """Create the pipeline.

        Args:
            unet: Denoising model; must expose ``config.in_channels`` and return
                an object with a ``.sample`` attribute when called.
            scheduler: Diffusion scheduler providing ``set_timesteps``,
                ``timesteps`` and ``step``.
            block_embeddings: Optional embeddings used to decode the final
                sample into a scene; ``None`` means softmax decoding.
        """
        super().__init__(unet, scheduler)
        # Plain attribute, not a registered module: handled by hand in
        # save_pretrained / from_pretrained.
        self.block_embeddings = block_embeddings
        # Optional per-channel divisors; set via give_sprite_scaling_factors.
        self.sprite_scaling_factors = None

    def save_pretrained(self, save_directory, **kwargs):
        """Save the pipeline components plus ``block_embeddings`` (if set).

        Args:
            save_directory: Target directory; created if it does not exist.
            **kwargs: Forwarded unchanged to ``DiffusionPipeline.save_pretrained``
                (e.g. ``safe_serialization``, ``variant``).
        """
        os.makedirs(save_directory, exist_ok=True)
        super().save_pretrained(save_directory, **kwargs)
        if self.block_embeddings is not None:
            torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))

    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kwargs):
        """Load a pipeline and, if present, its ``block_embeddings.pt``.

        ``block_embeddings.pt`` is looked up only as a local file inside
        ``pretrained_model_path``. When it is absent, ``block_embeddings`` is
        set to ``None`` and sampling falls back to softmax decoding.

        Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        pipeline = super().from_pretrained(pretrained_model_path, **kwargs)
        block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
        if os.path.exists(block_embeds_path):
            pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
        else:
            pipeline.block_embeddings = None
        return pipeline

    def give_sprite_scaling_factors(self, sprite_scaling_factors):
        """Set per-sprite (per-channel) temperature scaling factors.

        At the end of sampling the denoised output is divided by these factors,
        reshaped to ``(1, -1, 1, 1)`` so they broadcast along dimension 1.
        They may be a tensor or anything ``torch.as_tensor`` accepts. They are
        moved to the sample's device and dtype when used. Pass ``None`` to
        disable scaling.
        """
        self.sprite_scaling_factors = sprite_scaling_factors

    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        height: int = common_settings.MARIO_HEIGHT,
        width: int = common_settings.MARIO_WIDTH,
        latents: Optional[torch.FloatTensor] = None,
        show_progress_bar=True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Run the reverse diffusion loop and return the decoded samples.

        Args:
            batch_size: Number of samples to draw when ``latents`` is not given.
            generator: A single generator, or one per sample. It is used for the
                initial noise and forwarded to ``scheduler.step``.
            num_inference_steps: Number of denoising steps.
            output_type: Accepted for API compatibility; not used.
            return_dict: If False, return a 1-tuple instead of an
                ``ImagePipelineOutput``.
            height: Spatial height of the sampled noise when ``latents`` is None.
            width: Spatial width of the sampled noise when ``latents`` is None.
            latents: Optional starting noise. When given, it is used as-is and
                ``batch_size``/``height``/``width`` are ignored.
            show_progress_bar: Wrap the timestep loop in the pipeline's
                progress bar.

        Returns:
            ``ImagePipelineOutput(images=...)``, or ``(images,)`` when
            ``return_dict`` is False. The result is detached and on CPU.
        """
        self.unet.eval()
        with torch.no_grad():
            if latents is not None:
                image = latents.to(self.device)
            else:
                image_shape = (batch_size, self.unet.config.in_channels, height, width)
                # Unlike torch.randn, randn_tensor accepts a list of per-sample
                # generators and a CPU generator while targeting a CUDA device.
                image = randn_tensor(image_shape, generator=generator, device=self.device)

            self.scheduler.set_timesteps(num_inference_steps)
            timesteps = self.scheduler.timesteps
            iterator = self.progress_bar(timesteps) if show_progress_bar else timesteps
            for t in iterator:
                model_output = self.unet(image, t).sample
                image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

            # Per-sprite temperature scaling: divide each channel by its factor.
            # getattr keeps instances that never ran __init__ working.
            scaling = getattr(self, "sprite_scaling_factors", None)
            if scaling is not None:
                # Match the sample's device/dtype so CPU factors work on GPU runs.
                scaling = torch.as_tensor(scaling, device=image.device, dtype=image.dtype)
                image = image / scaling.view(1, -1, 1, 1)

            if self.block_embeddings is not None:
                image = get_scene_from_embeddings(image, self.block_embeddings)
            else:
                image = F.softmax(image, dim=1)
            image = image.detach().cpu()

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)

    def print_unet_architecture(self):
        """Print the architecture of the UNet model."""
        print(self.unet)