Jasmeet Singh commited on
Commit
5c5178a
·
verified ·
1 Parent(s): 81777a8

Update generationPipeline.py

Browse files
Files changed (1) hide show
  1. generationPipeline.py +172 -172
generationPipeline.py CHANGED
@@ -1,173 +1,173 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- from sampler import DDPMSampler
5
- from tqdm import tqdm
6
-
7
-
8
- WIDTH = 512
9
- HEIGHT = 512
10
- LATENTS_WIDTH = WIDTH // 8
11
- LATENTS_HEIGHT = HEIGHT // 8
12
-
13
- def generate(
14
- prompt,
15
- uncond_prompt=None,
16
- input_image=None,
17
- strength=0.8,
18
- do_cfg=True,
19
- cfg_scale=7.5,
20
- sampler_name="ddpm",
21
- n_inference_steps=50,
22
- models={},
23
- seed=None,
24
- device=None,
25
- idle_device=None,
26
- tokenizer=None,
27
- ):
28
- with torch.no_grad():
29
- if not 0 < strength <= 1:
30
- raise ValueError("strength must be between 0 and 1")
31
-
32
- if idle_device:
33
- to_idle = lambda x: x.to(idle_device)
34
- else:
35
- to_idle = lambda x: x
36
-
37
- # Initialize random number generator according to the seed specified
38
- generator = torch.Generator(device=device)
39
- if seed is None:
40
- generator.seed()
41
- else:
42
- generator.manual_seed(seed)
43
-
44
- clip = models["clip"]
45
- clip.to(device)
46
-
47
- if do_cfg:
48
- # Convert into a list of length Seq_Len=77
49
- cond_tokens = tokenizer.batch_encode_plus(
50
- [prompt], padding="max_length", max_length=77
51
- ).input_ids
52
- # (Batch_Size, Seq_Len)
53
- cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
54
- # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
55
- cond_context = clip(cond_tokens)
56
- # Convert into a list of length Seq_Len=77
57
- uncond_tokens = tokenizer.batch_encode_plus(
58
- [uncond_prompt], padding="max_length", max_length=77
59
- ).input_ids
60
- # (Batch_Size, Seq_Len)
61
- uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
62
- # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
63
- uncond_context = clip(uncond_tokens)
64
- # (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (2 * Batch_Size, Seq_Len, Dim)
65
- context = torch.cat([cond_context, uncond_context])
66
- else:
67
- # Convert into a list of length Seq_Len=77
68
- tokens = tokenizer.batch_encode_plus(
69
- [prompt], padding="max_length", max_length=77
70
- ).input_ids
71
- # (Batch_Size, Seq_Len)
72
- tokens = torch.tensor(tokens, dtype=torch.long, device=device)
73
- # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
74
- context = clip(tokens)
75
- to_idle(clip)
76
-
77
- if sampler_name == "ddpm":
78
- sampler = DDPMSampler(generator)
79
- sampler.set_inference_timesteps(n_inference_steps)
80
- else:
81
- raise ValueError("Unknown sampler value %s. ")
82
-
83
- latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
84
-
85
- if input_image:
86
- encoder = models["encoder"]
87
- encoder.to(device)
88
-
89
- input_image_tensor = input_image.resize((WIDTH, HEIGHT))
90
- # (Height, Width, Channel)
91
- input_image_tensor = np.array(input_image_tensor)
92
- # (Height, Width, Channel) -> (Height, Width, Channel)
93
- input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
94
- # (Height, Width, Channel) -> (Height, Width, Channel)
95
- input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
96
- # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
97
- input_image_tensor = input_image_tensor.unsqueeze(0)
98
- # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
99
- input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)
100
-
101
- # (Batch_Size, 4, Latents_Height, Latents_Width)
102
- encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
103
- # (Batch_Size, 4, Latents_Height, Latents_Width)
104
- latents = encoder(input_image_tensor, encoder_noise)
105
-
106
- # Add noise to the latents (the encoded input image)
107
- # (Batch_Size, 4, Latents_Height, Latents_Width)
108
- sampler.set_strength(strength=strength)
109
- latents = sampler.add_noise(latents, sampler.timesteps[0])
110
-
111
- to_idle(encoder)
112
- else:
113
- # (Batch_Size, 4, Latents_Height, Latents_Width)
114
- latents = torch.randn(latents_shape, generator=generator, device=device)
115
-
116
- diffusion = models["diffusion"]
117
- diffusion.to(device)
118
-
119
- timesteps = tqdm(sampler.timesteps)
120
- for i, timestep in enumerate(timesteps):
121
- # (1, 320)
122
- time_embedding = get_time_embedding(timestep).to(device)
123
-
124
- # (Batch_Size, 4, Latents_Height, Latents_Width)
125
- model_input = latents
126
-
127
- if do_cfg:
128
- # (Batch_Size, 4, Latents_Height, Latents_Width) -> (2 * Batch_Size, 4, Latents_Height, Latents_Width)
129
- model_input = model_input.repeat(2, 1, 1, 1)
130
-
131
- # model_output is the predicted noise
132
- # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
133
- model_output = diffusion(model_input, context, time_embedding)
134
-
135
- if do_cfg:
136
- output_cond, output_uncond = model_output.chunk(2)
137
- model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
138
-
139
- # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 4, Latents_Height, Latents_Width)
140
- latents = sampler.step(timestep, latents, model_output)
141
-
142
-
143
- to_idle(diffusion)
144
-
145
- decoder = models["decoder"]
146
- decoder.to(device)
147
- # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 3, Height, Width)
148
- images = decoder(latents)
149
- to_idle(decoder)
150
-
151
- images = rescale(images, (-1, 1), (0, 255), clamp=True)
152
- # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
153
- images = images.permute(0, 2, 3, 1)
154
- images = images.to("cpu", torch.uint8).numpy()
155
- return images[0]
156
-
157
- def rescale(x, old_range, new_range, clamp=False):
158
- old_min, old_max = old_range
159
- new_min, new_max = new_range
160
- x -= old_min
161
- x *= (new_max - new_min) / (old_max - old_min)
162
- x += new_min
163
- if clamp:
164
- x = x.clamp(new_min, new_max)
165
- return x
166
-
167
- def get_time_embedding(timestep):
168
- # Shape: (160,)
169
- freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
170
- # Shape: (1, 160)
171
- x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
172
- # Shape: (1, 160 * 2)
173
  return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from sampler import DDPMSampler
5
+ from tqdm import tqdm
6
+
7
+
8
+ WIDTH = 512
9
+ HEIGHT = 512
10
+ LATENTS_WIDTH = WIDTH // 8
11
+ LATENTS_HEIGHT = HEIGHT // 8
12
+
13
def generate(
    prompt,
    uncond_prompt=None,
    input_image=None,
    strength=0.8,
    do_cfg=True,
    cfg_scale=7.5,
    sampler_name="ddpm",
    n_inference_steps=50,
    models=None,
    seed=None,
    device=None,
    idle_device=None,
    tokenizer=None,
):
    """Run the full diffusion pipeline: text-to-image or image-to-image.

    Args:
        prompt: Text prompt conditioning the generation.
        uncond_prompt: Negative/unconditional prompt used for classifier-free
            guidance (required when ``do_cfg`` is True).
        input_image: Optional PIL image for image-to-image generation; when
            None, generation starts from pure noise.
        strength: Amount of noise added to the encoded input image, in (0, 1].
            Higher values let the model deviate more from the input.
        do_cfg: Whether to apply classifier-free guidance.
        cfg_scale: Guidance scale blending conditional and unconditional noise.
        sampler_name: Name of the sampler; only "ddpm" is supported.
        n_inference_steps: Number of denoising steps.
        models: Dict with "clip", "encoder", "diffusion", "decoder" modules.
        seed: RNG seed; None draws a random seed.
        device: Device the models run on.
        idle_device: Optional device to park models on when not in use.
        tokenizer: CLIP tokenizer providing ``batch_encode_plus``.

    Returns:
        A (Height, Width, Channel) uint8 numpy array — the generated image.

    Raises:
        ValueError: If ``strength`` is out of range or ``sampler_name`` is unknown.
    """
    # BUGFIX: avoid a mutable default argument ({}); behavior is unchanged.
    if models is None:
        models = {}
    with torch.no_grad():
        if not 0 < strength <= 1:
            raise ValueError("strength must be between 0 and 1")

        if idle_device:
            to_idle = lambda x: x.to(idle_device)
        else:
            to_idle = lambda x: x

        # Initialize random number generator according to the seed specified
        generator = torch.Generator(device=device)
        if seed is None:
            generator.seed()
        else:
            generator.manual_seed(seed)

        clip = models["clip"]
        clip.to(device)

        if do_cfg:
            # Tokenize to a fixed Seq_Len=77
            cond_tokens = tokenizer.batch_encode_plus(
                [prompt], padding="max_length", max_length=77
            ).input_ids
            # (Batch_Size, Seq_Len)
            cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
            # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
            cond_context = clip(cond_tokens)
            uncond_tokens = tokenizer.batch_encode_plus(
                [uncond_prompt], padding="max_length", max_length=77
            ).input_ids
            # (Batch_Size, Seq_Len)
            uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
            # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
            uncond_context = clip(uncond_tokens)
            # (2 * Batch_Size, Seq_Len, Dim): conditional first, unconditional second
            context = torch.cat([cond_context, uncond_context])
        else:
            # Tokenize to a fixed Seq_Len=77
            tokens = tokenizer.batch_encode_plus(
                [prompt], padding="max_length", max_length=77
            ).input_ids
            # (Batch_Size, Seq_Len)
            tokens = torch.tensor(tokens, dtype=torch.long, device=device)
            # (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
            context = clip(tokens)
        to_idle(clip)

        if sampler_name == "ddpm":
            sampler = DDPMSampler(generator)
            sampler.set_inference_timesteps(n_inference_steps)
        else:
            # BUGFIX: the original message contained an unfilled "%s" placeholder.
            raise ValueError(f"Unknown sampler value {sampler_name!r}")

        latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)

        # BUGFIX: `input_image.any()` raises AttributeError for the default
        # None and for PIL images (which have no .any()); an explicit None
        # check handles both text-to-image and image-to-image paths.
        if input_image is not None:
            encoder = models["encoder"]
            encoder.to(device)

            input_image_tensor = input_image.resize((WIDTH, HEIGHT))
            # (Height, Width, Channel)
            input_image_tensor = np.array(input_image_tensor)
            input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
            # Map pixel values [0, 255] -> [-1, 1], the range the encoder expects
            input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
            # (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
            input_image_tensor = input_image_tensor.unsqueeze(0)
            # (Batch_Size, Height, Width, Channel) -> (Batch_Size, Channel, Height, Width)
            input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)

            # (Batch_Size, 4, Latents_Height, Latents_Width)
            encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
            # (Batch_Size, 4, Latents_Height, Latents_Width)
            latents = encoder(input_image_tensor, encoder_noise)

            # Add noise to the latents (the encoded input image) per `strength`
            sampler.set_strength(strength=strength)
            latents = sampler.add_noise(latents, sampler.timesteps[0])

            to_idle(encoder)
        else:
            # Pure text-to-image: start from random noise
            # (Batch_Size, 4, Latents_Height, Latents_Width)
            latents = torch.randn(latents_shape, generator=generator, device=device)

        diffusion = models["diffusion"]
        diffusion.to(device)

        timesteps = tqdm(sampler.timesteps)
        for timestep in timesteps:
            # (1, 320)
            time_embedding = get_time_embedding(timestep).to(device)

            # (Batch_Size, 4, Latents_Height, Latents_Width)
            model_input = latents

            if do_cfg:
                # Duplicate the latents so one copy is denoised with the prompt
                # context and the other with the unconditional context.
                model_input = model_input.repeat(2, 1, 1, 1)

            # model_output is the predicted noise
            model_output = diffusion(model_input, context, time_embedding)

            if do_cfg:
                output_cond, output_uncond = model_output.chunk(2)
                # Classifier-free guidance: extrapolate away from the unconditional prediction
                model_output = cfg_scale * (output_cond - output_uncond) + output_uncond

            # Remove one step of noise
            latents = sampler.step(timestep, latents, model_output)

        to_idle(diffusion)

        decoder = models["decoder"]
        decoder.to(device)
        # (Batch_Size, 4, Latents_Height, Latents_Width) -> (Batch_Size, 3, Height, Width)
        images = decoder(latents)
        to_idle(decoder)

        images = rescale(images, (-1, 1), (0, 255), clamp=True)
        # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
        images = images.permute(0, 2, 3, 1)
        images = images.to("cpu", torch.uint8).numpy()
        return images[0]
156
+
157
def rescale(x, old_range, new_range, clamp=False):
    """Linearly map values of `x` from `old_range` to `new_range`.

    Args:
        x: Tensor of values assumed to lie in `old_range`.
        old_range: (min, max) of the input values.
        new_range: (min, max) of the desired output values.
        clamp: If True, clamp the result to `new_range`.

    Returns:
        A new tensor with the rescaled values; the input is left unmodified.
    """
    old_min, old_max = old_range
    new_min, new_max = new_range
    # BUGFIX: the original used in-place ops (`x -= ...`, `x *= ...`),
    # silently mutating the caller's tensor; compute on a fresh tensor instead.
    x = (x - old_min) * ((new_max - new_min) / (old_max - old_min)) + new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x
166
+
167
def get_time_embedding(timestep):
    """Build the sinusoidal time embedding for one diffusion timestep.

    Returns a (1, 320) float32 tensor: 160 cosine components followed by
    160 sine components, evaluated at geometrically spaced frequencies
    10000^(-k/160) for k = 0..159.
    """
    # (160,) frequency ladder
    exponents = torch.arange(start=0, end=160, dtype=torch.float32) / 160
    freqs = torch.pow(10000, -exponents)
    # (1, 160): scale every frequency by the timestep
    angles = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    # (1, 320): cosines first, then sines
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)