schrum2 committed on
Commit
4e6b80a
·
verified ·
1 Parent(s): 02aa4f3

moved back to models

Browse files
Files changed (1) hide show
  1. latent_diffusion_pipeline.py +0 -99
latent_diffusion_pipeline.py DELETED
@@ -1,99 +0,0 @@
1
- from diffusers import DDPMPipeline
2
- import torch
3
- import torch.nn.functional as F
4
- from typing import Optional, Union, List, Tuple
5
- from diffusers.utils.torch_utils import randn_tensor
6
- from diffusers.pipelines.ddpm.pipeline_ddpm import ImagePipelineOutput
7
- import common_settings as common_settings
8
- import os
9
- import json
10
- from general_training_helper import get_scene_from_embeddings
11
-
12
class UnconditionalDDPMPipeline(DDPMPipeline):
    """Unconditional DDPM pipeline for generating tile-grid scenes (e.g. Mario levels).

    Extends :class:`diffusers.DDPMPipeline` with:
      * an optional ``block_embeddings`` tensor that is persisted alongside the
        pipeline and used to decode the denoised output back into a discrete
        scene via ``get_scene_from_embeddings``;
      * optional per-sprite (per-channel) temperature scaling applied to the
        final denoised tensor before decoding.
    """

    def __init__(self, unet, scheduler, block_embeddings=None):
        super().__init__(unet, scheduler)
        # Optional embedding table used to map the continuous model output back
        # to discrete tiles. Assumed compatible with get_scene_from_embeddings;
        # shape/dtype are not constrained here — confirm against that helper.
        self.block_embeddings = block_embeddings

    def save_pretrained(self, save_directory):
        """Save the pipeline, plus the block-embeddings tensor if one is set.

        The tensor is written to ``block_embeddings.pt`` inside
        ``save_directory`` so that :meth:`from_pretrained` can restore it.
        """
        os.makedirs(save_directory, exist_ok=True)
        super().save_pretrained(save_directory)
        # Save block_embeddings tensor if it exists
        if self.block_embeddings is not None:
            torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))

    @classmethod
    def from_pretrained(cls, pretrained_model_path, **kwargs):
        """Load the pipeline and restore ``block_embeddings`` if it was saved.

        Sets ``block_embeddings`` to ``None`` when no saved tensor is found,
        so downstream code can rely on the attribute existing.
        """
        pipeline = super().from_pretrained(pretrained_model_path, **kwargs)
        # Load block_embeddings tensor if it exists
        block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
        if os.path.exists(block_embeds_path):
            # NOTE(review): torch.load can execute arbitrary code from an
            # untrusted checkpoint; acceptable for locally-produced files, but
            # consider weights_only=True if checkpoints come from elsewhere.
            pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
        else:
            pipeline.block_embeddings = None
        return pipeline

    def give_sprite_scaling_factors(self, sprite_scaling_factors):
        """
        Set the sprite scaling factors for the pipeline.
        This is used to apply per-sprite temperature scaling during inference.
        """
        self.sprite_scaling_factors = sprite_scaling_factors

    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
        output_type: Optional[str] = "tensor",  # kept for API compatibility; output is always a tensor
        return_dict: bool = True,
        height: int = common_settings.MARIO_HEIGHT, width: int = common_settings.MARIO_WIDTH,
        latents: Optional[torch.FloatTensor] = None,
        show_progress_bar=True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Run the reverse-diffusion sampling loop.

        Args:
            batch_size: Number of samples to generate (ignored when ``latents``
                is provided).
            generator: A single ``torch.Generator`` or a list of per-sample
                generators for reproducible noise.
            num_inference_steps: Number of scheduler timesteps.
            output_type: Unused; present for diffusers API compatibility.
            return_dict: If False, return a plain ``(tensor,)`` tuple.
            height, width: Spatial dimensions of the generated scene.
            latents: Optional pre-made starting noise; bypasses noise sampling.
            show_progress_bar: Whether to wrap the loop in a progress bar.

        Returns:
            ``ImagePipelineOutput`` (or a 1-tuple) holding either decoded
            scene indices (when ``block_embeddings`` is set) or per-channel
            softmax probabilities, on CPU.
        """
        self.unet.eval()
        with torch.no_grad():
            if latents is not None:
                image = latents.to(self.device)
            else:
                image_shape = (
                    batch_size,
                    self.unet.config.in_channels,
                    height,
                    width
                )
                # randn_tensor (unlike torch.randn) correctly handles a *list*
                # of per-sample generators, which the signature advertises.
                image = randn_tensor(image_shape, generator=generator, device=self.device)

            self.scheduler.set_timesteps(num_inference_steps)

            iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
            for t in iterator:
                model_output = self.unet(image, t).sample
                image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

            # Apply per-sprite temperature scaling if enabled
            if getattr(self, "sprite_scaling_factors", None) is not None:
                image = image / self.sprite_scaling_factors.view(1, -1, 1, 1)

            if self.block_embeddings is not None:
                image = get_scene_from_embeddings(image, self.block_embeddings)
            else:
                image = F.softmax(image, dim=1)
            image = image.detach().cpu()

            if not return_dict:
                return (image,)

            return ImagePipelineOutput(images=image)

    def print_unet_architecture(self):
        """Prints the architecture of the UNet model."""
        print(self.unet)