NZUONG commited on
Commit
1913ec5
·
verified ·
1 Parent(s): 7083e95

Upload 27 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/sample.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,24 @@
1
- ---
2
- title: Mude
3
- emoji: 🐢
4
- colorFrom: pink
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.42.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+
3
+ title: AI Audio Editor
4
+
5
+ emoji: 🎵
6
+
7
+ colorFrom: blue
8
+
9
+ colorTo: purple
10
+
11
+ sdk: gradio
12
+
13
+ sdk_version: 4.0.0
14
+
15
+ app_file: app.py
16
+
17
+ pinned: false
18
+
19
+ license: mit
20
+
21
+ ---
22
+
23
+
24
+
app.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ from tqdm import tqdm
7
+ from transformers import UMT5EncoderModel, AutoTokenizer
8
+ from huggingface_hub import hf_hub_download, snapshot_download
9
+ import json
10
+ import numpy as np
11
+ import tempfile
12
+ from io import BytesIO
13
+ import warnings
14
+ warnings.filterwarnings("ignore")
15
+
16
+ # Import model components
17
+ from model.ae.music_dcae import MusicDCAE
18
+ from model.ldm.editing_unet import EditingUNet
19
+ from model.ldm.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
20
+
21
# Configuration
# Prefer the GPU when one is visible; otherwise run everything on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 only when the hardware supports it — float32 is the safe fallback.
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32

# Model repository - UPDATE THIS TO YOUR MODEL REPO
MODEL_REPO = "NZUONG/mude"  # Your uploaded model repository

# DDPM Parameters: linear beta schedule used to rebuild the training
# noise schedule at sampling time.
DDPM_NUM_TIMESTEPS = 1000
DDPM_BETA_START = 0.0001
DDPM_BETA_END = 0.02
32
+
33
class AttrDict(dict):
    """Dict whose keys are also accessible as attributes.

    Aliasing the instance ``__dict__`` to the dict itself makes
    ``d.key`` and ``d["key"]`` address the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self
37
+
38
def download_models():
    """Fetch the model repository from the Hugging Face Hub.

    Downloads every file of ``MODEL_REPO`` into ``./checkpoints``.

    Returns:
        bool: True when the snapshot download succeeded, False otherwise.
    """
    print("🔄 Downloading models from Hugging Face Hub...")

    # Make sure the target directory exists before snapshot_download writes into it.
    os.makedirs("checkpoints", exist_ok=True)

    try:
        # One call pulls the entire repository snapshot.
        snapshot_path = snapshot_download(
            repo_id=MODEL_REPO,
            cache_dir="./cache",
            local_dir="./checkpoints",
            repo_type="model",
        )
        print(f"✅ Models downloaded to: {snapshot_path}")
        return True
    except Exception as e:
        # Best-effort: report and let the caller decide how to proceed.
        print(f"❌ Error downloading models: {e}")
        return False
58
+
59
class AudioEditor:
    """Instruction-guided audio editor.

    Holds the DCAE autoencoder, the UMT5 text encoder and the editing UNet,
    and runs DPM-Solver++ sampling to edit a 10-second audio clip according
    to a natural-language instruction. Models are loaded lazily on first use
    and shuttled between CPU and GPU during processing to limit peak VRAM.
    """

    def __init__(self):
        # All models are loaded lazily by load_models().
        self.dcae = None          # MusicDCAE: mel <-> latent <-> waveform
        self.tokenizer = None     # UMT5 tokenizer
        self.text_encoder = None  # UMT5EncoderModel
        self.model = None         # EditingUNet denoiser
        self.is_loaded = False

    def load_models(self):
        """Load all models once at startup.

        Downloads checkpoints from the Hub when they are missing locally.

        Returns:
            bool: True when every model loaded successfully.
        """
        if self.is_loaded:
            return True

        # Download models if not present
        if not os.path.exists("checkpoints/music_dcae_f8c8"):
            print("📥 Models not found locally, downloading...")
            if not download_models():
                return False

        print("🔄 Loading models...")

        try:
            # Model paths
            dcae_path = "checkpoints/music_dcae_f8c8"
            vocoder_path = "checkpoints/music_vocoder"
            t5_path = "checkpoints/umt5-base"
            unet_config_path = "model/ldm/exp_config.json"
            trained_model_path = "checkpoints/fm_checkpoint_epoch_9.pt"

            # Load DCAE (autoencoder + vocoder)
            self.dcae = MusicDCAE(
                dcae_checkpoint_path=dcae_path,
                vocoder_checkpoint_path=vocoder_path
            ).to(DEVICE).eval()

            # Load text encoder
            self.tokenizer = AutoTokenizer.from_pretrained(t5_path)
            self.text_encoder = UMT5EncoderModel.from_pretrained(t5_path).to(DEVICE, dtype=DTYPE).eval()

            # Load UNet config
            with open(unet_config_path, 'r') as f:
                unet_config = AttrDict(json.load(f)['model']['unet'])

            # UNet starts on CPU; it is moved to DEVICE only while sampling.
            self.model = EditingUNet(unet_config, use_flow_matching=False).to("cpu", dtype=DTYPE).eval()

            # Load checkpoint; strip torch.compile's `_orig_mod.` prefix if present.
            checkpoint = torch.load(trained_model_path, map_location="cpu")
            model_state_dict = checkpoint.get('model_state_dict', checkpoint)
            if any(key.startswith('_orig_mod.') for key in model_state_dict.keys()):
                model_state_dict = {key.replace('_orig_mod.', ''): value
                                    for key, value in model_state_dict.items()}
            self.model.load_state_dict(model_state_dict, strict=False)

            self.is_loaded = True
            print("✅ All models loaded successfully!")
            return True

        except Exception as e:
            print(f"❌ Error loading models: {e}")
            return False

    def dpm_solver_sampling(self, model, source_latent, instruction_embedding, uncond_embedding,
                            strength=1.0, steps=25, guidance_scale=7.5, seed=42):
        """Run DPM-Solver++ sampling from a partially-noised source latent.

        Args:
            model: EditingUNet noise-prediction network (DDPM objective).
            source_latent: encoded latent of the source audio.
            instruction_embedding: UMT5 embedding of the edit instruction.
            uncond_embedding: UMT5 embedding of the empty string (CFG null).
            strength: fraction of the noise schedule to traverse (1.0 = full).
            steps: number of solver steps.
            guidance_scale: classifier-free guidance weight.
            seed: RNG seed for the initial noise.

        Returns:
            torch.Tensor: the denoised (edited) latent.
        """
        print(f"🚀 Starting DPM-Solver++ sampling with {steps} steps...")

        # Setup noise schedule (linear betas, matching training).
        # BUG FIX: the original referenced the misspelled DDMP_NUM_TIMESTEPS,
        # which raised NameError at runtime; the constant is DDPM_NUM_TIMESTEPS.
        betas = torch.linspace(DDPM_BETA_START, DDPM_BETA_END, DDPM_NUM_TIMESTEPS, dtype=torch.float32)
        alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        noise_schedule = NoiseScheduleVP(schedule='discrete', alphas_cumprod=alphas_cumprod)

        # Setup model wrapper with classifier-free guidance.
        model_fn = model_wrapper(
            model,
            noise_schedule,
            model_type="noise",  # DDPM objective only
            model_kwargs={
                "source_latent": source_latent,
            },
            guidance_type="classifier-free",
            condition=instruction_embedding,
            unconditional_condition=uncond_embedding,
            guidance_scale=guidance_scale,
        )

        # Initialize DPM-Solver++
        dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")

        # Calculate time range: strength controls how far back toward pure
        # noise we start (1.0 = the full schedule).
        t_end = noise_schedule.T / noise_schedule.total_N
        t_start = t_end + strength * (noise_schedule.T - t_end)

        # Add initial noise (seeded for reproducibility).
        torch.manual_seed(seed)
        noise = torch.randn_like(source_latent)
        latents = dpm_solver.add_noise(source_latent, torch.tensor([t_start], device=DEVICE), noise)
        latents = latents.to(DTYPE)

        # Run DPM solver sampling.
        # FIX: autocast now targets the actual device instead of hard-coded "cuda".
        with torch.amp.autocast(device_type=DEVICE, dtype=DTYPE, enabled=(DTYPE != torch.float32)):
            with torch.no_grad():
                final_latent, _ = dpm_solver.sample(
                    latents,
                    steps=steps,
                    t_start=t_start,
                    t_end=t_end,
                    order=2,
                    method="multistep",
                    skip_type="time_uniform",
                    lower_order_final=True,
                    return_intermediate=True,
                )

        return final_latent

    def process_audio(self, audio_file, instruction, guidance_scale, steps, strength, seed):
        """Edit an audio file according to a text instruction.

        The clip is resampled to 44.1 kHz, trimmed/padded to exactly 10 s,
        upmixed to stereo, encoded to a latent, edited via DPM-Solver++,
        and decoded back to a waveform.

        Returns:
            tuple: (output_wav_path | None, comparison_png_path | None, status message).
        """
        try:
            if not self.load_models():
                return None, None, "❌ Failed to load models. Please try again."

            # Load and preprocess audio
            print(f"🎵 Processing audio: {audio_file}")
            audio, sr = torchaudio.load(audio_file)
            TARGET_SR_DCAE = 44100
            TARGET_LEN_DCAE = TARGET_SR_DCAE * 10  # fixed 10-second window

            if sr != TARGET_SR_DCAE:
                audio = torchaudio.transforms.Resample(sr, TARGET_SR_DCAE)(audio)

            # Trim or zero-pad to exactly 10 s.
            if audio.shape[1] > TARGET_LEN_DCAE:
                audio = audio[:, :TARGET_LEN_DCAE]
            elif audio.shape[1] < TARGET_LEN_DCAE:
                audio = torch.nn.functional.pad(audio, (0, TARGET_LEN_DCAE - audio.shape[1]))

            # Mono -> stereo by channel duplication.
            if audio.shape[0] == 1:
                audio = audio.repeat(2, 1)

            # Encode audio to the normalised latent space.
            with torch.no_grad():
                source_latent_scaled, _ = self.dcae.encode(audio.to(DEVICE).unsqueeze(0))

            # Prepare text embeddings (conditional + unconditional for CFG).
            with torch.no_grad(), torch.amp.autocast(device_type=DEVICE, dtype=DTYPE, enabled=(DTYPE != torch.float32)):
                text_input = self.tokenizer([instruction], max_length=32, padding="max_length",
                                            truncation=True, return_tensors="pt")
                instruction_embedding = self.text_encoder(text_input.input_ids.to(DEVICE))[0]

                uncond_input = self.tokenizer([""], max_length=32, padding="max_length",
                                              truncation=True, return_tensors="pt")
                uncond_embedding = self.text_encoder(uncond_input.input_ids.to(DEVICE))[0]

            # Swap models: DCAE off the device, UNet on, to limit peak VRAM.
            self.dcae = self.dcae.cpu()
            torch.cuda.empty_cache()
            self.model = self.model.to(DEVICE, dtype=DTYPE)

            # Generate
            print("🎨 Generating edited audio...")
            with torch.amp.autocast(device_type=DEVICE, dtype=DTYPE, enabled=(DTYPE != torch.float32)):
                with torch.no_grad():
                    final_latent = self.dpm_solver_sampling(
                        model=self.model,
                        source_latent=source_latent_scaled,
                        instruction_embedding=instruction_embedding,
                        uncond_embedding=uncond_embedding,
                        strength=strength,
                        steps=int(steps),
                        guidance_scale=guidance_scale,
                        seed=int(seed)
                    )

            # Swap back for decoding.
            self.model = self.model.cpu()
            torch.cuda.empty_cache()
            self.dcae = self.dcae.to(DEVICE)

            # Undo latent normalisation for the mel-only decode path.
            final_latent_unscaled = (final_latent.float() / self.dcae.scale_factor) + self.dcae.shift_factor
            source_latent_raw = (source_latent_scaled / self.dcae.scale_factor) + self.dcae.shift_factor

            with torch.no_grad():
                source_mel = self.dcae.decode_to_mel(source_latent_raw)
                edited_mel = self.dcae.decode_to_mel(final_latent_unscaled)
                # decode() expects the *scaled* latent; it un-scales internally.
                _, pred_wavs = self.dcae.decode(latents=final_latent.float(), sr=44100)
                edited_audio = pred_wavs[0]

            # Create comparison plot
            comparison_plot = self.create_mel_comparison(source_mel, edited_mel, instruction)

            # Save output audio
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                torchaudio.save(tmp_file.name, edited_audio.cpu().float(), 44100)
                output_path = tmp_file.name

            # Cleanup
            self.dcae = self.dcae.cpu()
            torch.cuda.empty_cache()

            return output_path, comparison_plot, f"✅ Audio editing completed! Instruction: '{instruction}'"

        except Exception as e:
            import traceback
            error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return None, None, error_msg

    def create_mel_comparison(self, source_mel, edited_mel, instruction):
        """Render original vs. edited mel-spectrograms into a PNG.

        Args:
            source_mel: mel tensor of the original audio (batch, ch, mel, time).
            edited_mel: mel tensor of the edited audio, same layout.
            instruction: instruction text shown in the plot title.

        Returns:
            str | None: path to the saved PNG, or None on failure.
        """
        try:
            # Plot the first channel of the first batch item.
            source_mel_np = source_mel.squeeze(0)[0].cpu().float().numpy()
            edited_mel_np = edited_mel.squeeze(0)[0].cpu().float().numpy()

            fig, axs = plt.subplots(2, 1, figsize=(12, 8), sharex=True, sharey=True)
            fig.suptitle(f'Mel-Spectrogram Comparison', fontsize=14)

            # Plot source
            im1 = axs[0].imshow(source_mel_np, aspect='auto', origin='lower', cmap='viridis')
            axs[0].set_title('Original Audio')
            axs[0].set_ylabel('Mel Bins')
            plt.colorbar(im1, ax=axs[0])

            # Plot edited
            im2 = axs[1].imshow(edited_mel_np, aspect='auto', origin='lower', cmap='viridis')
            axs[1].set_title(f'Edited Audio: "{instruction}"')
            axs[1].set_ylabel('Mel Bins')
            axs[1].set_xlabel('Time Frames')
            plt.colorbar(im2, ax=axs[1])

            plt.tight_layout()

            # Save to temporary file for Gradio
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
                plt.savefig(tmp_file.name, dpi=100, bbox_inches='tight')
                plt.close()
                return tmp_file.name

        except Exception as e:
            print(f"Error creating plot: {e}")
            plt.close()
            return None
298
+
299
+ # Initialize the audio editor
300
+ audio_editor = AudioEditor()
301
+
302
def gradio_interface(audio_file, instruction, guidance_scale, steps, strength, seed):
    """Validate inputs, then delegate to the shared AudioEditor instance.

    Returns:
        tuple: (audio_path | None, plot_path | None, status message).
    """
    # Guard clauses: reject missing inputs before touching any model.
    if audio_file is None:
        return None, None, "Please upload an audio file"
    if not instruction.strip():
        return None, None, "Please provide an editing instruction"

    return audio_editor.process_audio(
        audio_file, instruction, guidance_scale, steps, strength, seed
    )
311
+
312
# ---------------------------------------------------------------------------
# Gradio UI definition + launch entry point
# ---------------------------------------------------------------------------
with gr.Blocks(title="🎵 AI Audio Editor", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.HTML("""
<div style="text-align: center; margin-bottom: 20px;">
<h1>🎵 AI Audio Editor</h1>
<p>Upload an audio file and provide instructions to edit it using AI.<br/>
The model uses DPM-Solver++ for fast, high-quality generation.</p>
</div>
""")

    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="📁 Upload Audio File", type="filepath")

            instruction_input = gr.Textbox(
                label="✏️ Editing Instruction",
                placeholder="e.g., 'Add drums', 'Make it more energetic', 'Remove vocals'",
                lines=2,
            )

            # Sampling hyper-parameters, hidden behind an accordion.
            with gr.Accordion("🔧 Advanced Settings", open=False):
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                    label="Guidance Scale",
                    info="Higher values follow the instruction more closely",
                )
                steps = gr.Slider(
                    minimum=10, maximum=50, value=25, step=5,
                    label="Sampling Steps",
                    info="More steps = better quality, slower generation",
                )
                strength = gr.Slider(
                    minimum=0.1, maximum=1.0, value=1.0, step=0.1,
                    label="Denoising Strength",
                    info="1.0 = full denoising, lower = more conservative editing",
                )
                seed = gr.Number(
                    value=42, label="Seed",
                    info="For reproducible results",
                )

            generate_btn = gr.Button("🎨 Generate Edited Audio", variant="primary", size="lg")

        # Right column: status and results.
        with gr.Column(scale=1):
            status_output = gr.Textbox(label="📊 Status", interactive=False)
            audio_output = gr.Audio(label="🎵 Generated Audio")
            plot_output = gr.Image(label="📈 Mel-Spectrogram Comparison")

    # Usage notes footer.
    gr.HTML("""
<div style="margin-top: 20px; padding: 20px; background-color: #f0f0f0; border-radius: 10px;">
<h3>📝 Usage Tips:</h3>
<ul>
<li><b>Audio Length:</b> Files are automatically processed to 10 seconds</li>
<li><b>Instructions:</b> Be specific (e.g., "Add heavy drums" vs "Add drums")</li>
<li><b>Guidance Scale:</b> Start with 7.5, increase for stronger effects</li>
<li><b>Steps:</b> 25 steps provide good quality/speed balance</li>
</ul>
</div>
""")

    # Wire the button to the processing function.
    generate_btn.click(
        fn=gradio_interface,
        inputs=[audio_input, instruction_input, guidance_scale, steps, strength, seed],
        outputs=[audio_output, plot_output, status_output],
        show_progress=True,
    )

# Launch only when executed as a script (not when imported).
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
examples/sample.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b94ed3260e322a90dc10b88a3fd1c4d1ad5da50a7f40d62d976d7a59a495eee9
3
+ size 3528078
model/__pycache__/scheduler.cpython-310.pyc ADDED
Binary file (4.22 kB). View file
 
model/ae/__pycache__/music_dcae.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
model/ae/__pycache__/music_log_mel.cpython-310.pyc ADDED
Binary file (2.95 kB). View file
 
model/ae/__pycache__/music_vocoder.cpython-310.pyc ADDED
Binary file (15.7 kB). View file
 
model/ae/music_dcae.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from diffusers import AutoencoderDC
4
+ import torchaudio
5
+ import torchvision.transforms as transforms
6
+
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+ from diffusers.loaders import FromOriginalModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from tqdm import tqdm
11
+
12
+ try:
13
+ from .music_vocoder import ADaMoSHiFiGANV1
14
+ except ImportError:
15
+ from music_vocoder import ADaMoSHiFiGANV1
16
+
17
# Default checkpoint locations, resolved relative to the package root so the
# module works regardless of the current working directory.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DEFAULT_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_dcae_f8c8")
VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder")
# Removed: stray debug `print(DEFAULT_PRETRAINED_PATH)` — importing a module
# should not produce console output.
22
+
23
class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """Deep-compression autoencoder for music.

    Pipeline: waveform -> log-mel (via the vocoder's mel transform) -> DCAE
    latent, and back through the DCAE decoder + HiFi-GAN vocoder. Latents are
    affinely normalised with ``scale_factor``/``shift_factor`` so diffusion
    models downstream operate in a roughly unit-scale space.
    """

    @register_to_config
    def __init__(
        self,
        source_sample_rate=None,
        # BUG FIX: the defaults were machine-specific Windows paths
        # ("D:\\do an\\checkpoints\\...") written with invalid escape
        # sequences (\d, \c, \m) that break on any other machine. The
        # repo-relative constants (left commented beside them) were the
        # clear intent, and match the layout app.py downloads into.
        dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH,
        vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH,
    ):
        super(MusicDCAE, self).__init__()

        self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
        self.vocoder = ADaMoSHiFiGANV1.from_pretrained(vocoder_checkpoint_path)

        if source_sample_rate is None:
            source_sample_rate = 48000

        # Everything internal runs at 44.1 kHz.
        self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)

        # Maps mels from [0, 1] to [-1, 1] ((x - 0.5) / 0.5).
        self.transform = transforms.Compose(
            [
                transforms.Normalize(0.5, 0.5),
            ]
        )
        # Empirical log-mel dynamic range used for min/max normalisation.
        self.min_mel_value = -11.0
        self.max_mel_value = 3.0
        # 1024 mel frames x 512 hop at 44.1 kHz, expressed in 48 kHz samples.
        self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
        self.mel_chunk_size = 1024
        # The DCAE downsamples the mel time axis by 8x.
        self.time_dimention_multiple = 8
        self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple
        # Latent normalisation: latent_norm = (latent - shift) * scale.
        self.scale_factor = 0.1786
        self.shift_factor = -1.9091

    def load_audio(self, audio_path):
        """Load an audio file and upmix mono to stereo by duplication.

        Returns:
            (audio, sr): a (2, T) waveform tensor and its sample rate.
        """
        audio, sr = torchaudio.load(audio_path)
        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        return audio, sr

    def forward_mel(self, audios):
        """Compute the vocoder's log-mel spectrogram for each batch item."""
        mels = []
        for i in range(len(audios)):
            image = self.vocoder.mel_transform(audios[i])
            mels.append(image)
        mels = torch.stack(mels)
        return mels

    @torch.no_grad()
    def encode(self, audios, audio_lengths=None, sr=None):
        """Encode stereo waveforms into normalised DCAE latents.

        Args:
            audios: (N, 2, T) batch of stereo waveforms.
            audio_lengths: per-item sample counts; defaults to the full length.
            sr: input sample rate; defaults to 48000 (uses the cached resampler).

        Returns:
            (latents, latent_lengths): normalised latents and their valid
            frame counts per item.
        """
        if audio_lengths is None:
            audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
            audio_lengths = audio_lengths.to(audios.device)

        # audios: N x 2 x T, 48kHz
        device = audios.device
        dtype = audios.dtype

        if sr is None:
            sr = 48000
            resampler = self.resampler
        else:
            # Build an ad-hoc resampler for non-default input rates.
            resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)

        audio = resampler(audios)

        # Pad so the mel length divides evenly by the DCAE stride
        # (8 latent frames x 512-sample hop).
        max_audio_len = audio.shape[-1]
        if max_audio_len % (8 * 512) != 0:
            audio = torch.nn.functional.pad(
                audio, (0, 8 * 512 - max_audio_len % (8 * 512))
            )

        mels = self.forward_mel(audio)
        # Min/max-normalise to [0, 1], then to [-1, 1] via self.transform.
        mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
        mels = self.transform(mels)
        latents = []
        for mel in mels:
            latent = self.dcae.encoder(mel.unsqueeze(0))
            latents.append(latent)
        latents = torch.cat(latents, dim=0)
        latent_lengths = (
            audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple
        ).long()
        # Normalise latents for downstream diffusion.
        latents = (latents - self.shift_factor) * self.scale_factor
        return latents, latent_lengths

    @torch.no_grad()
    def decode(self, latents, audio_lengths=None, sr=None):
        """Decode normalised latents to waveforms via mel + vocoder.

        Args:
            latents: normalised latents as produced by :meth:`encode`.
            audio_lengths: optional per-item sample counts used to trim output.
            sr: optional output sample rate; defaults to 44100 (no resampling).

        Returns:
            (sr, pred_wavs): the output sample rate and a list of (2, T)
            stereo waveform tensors on CPU.
        """
        # Undo latent normalisation.
        latents = latents / self.scale_factor + self.shift_factor

        pred_wavs = []

        for latent in latents:
            mels = self.dcae.decoder(latent.unsqueeze(0))
            # Map [-1, 1] -> [0, 1] -> original log-mel range.
            mels = mels * 0.5 + 0.5
            mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value

            # Decode each channel separately to reduce VRAM footprint.
            wav_ch1 = self.vocoder.decode(mels[:, 0, :, :]).squeeze(1).cpu()
            wav_ch2 = self.vocoder.decode(mels[:, 1, :, :]).squeeze(1).cpu()
            wav = torch.cat([wav_ch1, wav_ch2], dim=0)

            if sr is not None:
                resampler = (
                    torchaudio.transforms.Resample(44100, sr)
                )
                wav = resampler(wav.cpu().float())
            else:
                sr = 44100
            pred_wavs.append(wav)

        if audio_lengths is not None:
            pred_wavs = [
                wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)
            ]
        return sr, pred_wavs

    @torch.no_grad()
    def decode_to_mel(self, latents):
        """Decode *raw* (un-normalised) latents into log-mel spectrograms.

        Unlike :meth:`decode`, this expects latents that have already been
        un-scaled by the caller (latent / scale + shift).

        Args:
            latents (torch.Tensor): batch of raw latent tensors.

        Returns:
            torch.Tensor: decoded mel-spectrograms, concatenated over batch.
        """
        # Convert to float32 to match the decoder parameters' dtype.
        latents = latents.float()

        # Process each latent individually, mirroring decode().
        mels_list = []
        for latent in latents:
            mel = self.dcae.decoder(latent.unsqueeze(0))
            mel = mel * 0.5 + 0.5
            mel = mel * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
            mels_list.append(mel)

        if len(mels_list) == 1:
            return mels_list[0]
        else:
            return torch.cat(mels_list, dim=0)
169
+
model/ae/music_log_mel.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step: A Step Towards Music Generation Foundation Model
3
+
4
+ https://github.com/ace-step/ACE-Step
5
+
6
+ Apache 2.0 License
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+ from torchaudio.transforms import MelScale
13
+
14
+
15
class LinearSpectrogram(nn.Module):
    """Magnitude spectrogram via STFT with centered reflect padding.

    With ``mode="pow2_sqrt"`` the output is sqrt(|STFT|^2 + 1e-6), i.e. a
    numerically-stabilised magnitude; otherwise the raw real/imag pair from
    ``torch.view_as_real`` is returned.
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        center=False,
        mode="pow2_sqrt",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.mode = mode

        # Hann window registered as a buffer so it moves with the module.
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, y: Tensor) -> Tensor:
        # Accept (B, 1, T) by dropping the channel axis.
        if y.ndim == 3:
            y = y.squeeze(1)

        # Manual reflect padding so frames stay aligned with hop boundaries.
        left = (self.win_length - self.hop_length) // 2
        right = (self.win_length - self.hop_length + 1) // 2
        y = torch.nn.functional.pad(
            y.unsqueeze(1), (left, right), mode="reflect"
        ).squeeze(1)

        in_dtype = y.dtype
        # STFT in float32 for numerical stability, cast back afterwards.
        stft_out = torch.stft(
            y.float(),
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(stft_out)

        if self.mode == "pow2_sqrt":
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

        return spec.to(in_dtype)
65
+
66
+
67
class LogMelSpectrogram(nn.Module):
    """Log-compressed mel spectrogram built on :class:`LinearSpectrogram`.

    The linear magnitude spectrogram is projected onto a Slaney-normalised
    mel filterbank, then log-compressed with a 1e-5 floor.
    """

    def __init__(
        self,
        sample_rate=44100,
        n_fft=2048,
        win_length=2048,
        hop_length=512,
        n_mels=128,
        center=False,
        f_min=0.0,
        f_max=None,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.n_mels = n_mels
        self.f_min = f_min
        # Default the upper band edge to Nyquist.
        self.f_max = f_max or sample_rate // 2

        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            "slaney",
            "slaney",
        )

    def compress(self, x: Tensor) -> Tensor:
        """Log-compress with a small floor to avoid log(0)."""
        return torch.log(torch.clamp(x, min=1e-5))

    def decompress(self, x: Tensor) -> Tensor:
        """Inverse of :meth:`compress` (ignoring the clamp)."""
        return torch.exp(x)

    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
        linear = self.spectrogram(x)
        mel = self.compress(self.mel_scale(linear))
        if return_linear:
            # Also hand back the compressed linear spectrogram.
            return mel, self.compress(linear)
        return mel
model/ae/music_vocoder.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step: A Step Towards Music Generation Foundation Model
3
+
4
+ https://github.com/ace-step/ACE-Step
5
+
6
+ Apache 2.0 License
7
+ """
8
+
9
+ import librosa
10
+ import torch
11
+ from torch import nn
12
+
13
+ from functools import partial
14
+ from math import prod
15
+ from typing import Callable, Tuple, List
16
+
17
+ import numpy as np
18
+ import torch.nn.functional as F
19
+ from torch.nn import Conv1d
20
+ from torch.nn.utils import weight_norm
21
+ from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ from diffusers.loaders import FromOriginalModelMixin
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+
26
+
27
+ try:
28
+ from music_log_mel import LogMelSpectrogram
29
+ except ImportError:
30
+ from .music_log_mel import LogMelSpectrogram
31
+
32
+
33
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample, for residual-block main paths.

    Each sample in the batch is independently zeroed with probability
    ``drop_prob``; surviving samples are rescaled by 1/keep_prob when
    ``scale_by_keep`` so the expected value is preserved. A no-op outside
    training or when ``drop_prob`` is zero.
    """
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over the remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        mask.div_(keep_prob)
    return x * mask
56
+
57
+
58
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (per-sample stochastic depth)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # self.training toggles the drop behaviour (train vs eval).
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
71
+
72
+
73
class LayerNorm(nn.Module):
    r"""LayerNorm supporting ``channels_last`` (default) and ``channels_first``.

    channels_last normalises the trailing dimension with ``F.layer_norm``;
    channels_first normalises dim 1 manually, with the affine parameters
    broadcast along the remaining axes.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_first":
            # Manual normalisation over the channel axis (dim 1).
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            normed = (x - mean) / torch.sqrt(var + self.eps)
            return self.weight[:, None] * normed + self.bias[:, None]
        return F.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )
101
+
102
+
103
class ConvNeXtBlock(nn.Module):
    r"""1-D ConvNeXt block: depthwise conv -> LayerNorm -> MLP, with
    optional layer-scale (``gamma``) and stochastic depth.

    Implemented in channels-last form for the MLP (permute to (N, L, C),
    apply linear layers, permute back), which is slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        # Depthwise conv: mixes along time only, one group per channel.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=int(dilation * (kernel_size - 1) / 2),
            groups=dim,
        )
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs expressed as linear layers on (N, L, C).
        self.pwconv1 = nn.Linear(dim, int(mlp_ratio * dim))
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        # Optional per-channel layer-scale parameter.
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        shortcut = x

        h = self.dwconv(x)
        h = h.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        h = self.pwconv2(self.act(self.pwconv1(self.norm(h))))

        if self.gamma is not None:
            h = self.gamma * h

        h = h.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        h = self.drop_path(h)

        return shortcut + h if apply_residual else h
169
+
170
+
171
class ParallelConvNeXtBlock(nn.Module):
    """Several ConvNeXt blocks with different kernel sizes, run in parallel.

    Each branch is evaluated without its internal residual; the branch outputs
    plus the input itself are summed, realising one shared residual connection
    across all parallel branches.
    """

    def __init__(self, kernel_sizes: List[int], *args, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                ConvNeXtBlock(kernel_size=size, *args, **kwargs)
                for size in kernel_sizes
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch_outputs = [branch(x, apply_residual=False) for branch in self.blocks]
        branch_outputs.append(x)
        return torch.stack(branch_outputs, dim=1).sum(dim=1)
186
+
187
+
188
class ConvNeXtEncoder(nn.Module):
    """1-D ConvNeXt feature encoder: a conv stem followed by ``len(depths)`` stages.

    :param input_channels: channels of the (N, C, L) input tensor.
    :param depths: number of ConvNeXt blocks in each stage.
    :param dims: channel width of each stage (same length as ``depths``).
    :param drop_path_rate: max stochastic-depth rate, ramped linearly over blocks.
    :param layer_scale_init_value: layer-scale init passed to each block.
    :param kernel_sizes: single size -> plain ConvNeXtBlock; several sizes ->
        ParallelConvNeXtBlock with one branch per kernel size.
    """

    def __init__(
        self,
        input_channels=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        kernel_sizes: Tuple[int] = (7,),
    ):
        super().__init__()
        assert len(depths) == len(dims)

        # channel_layers[i] adapts the channel count before stage i:
        # conv stem first, then LayerNorm + 1x1 conv between stages.
        self.channel_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv1d(
                input_channels,
                dims[0],
                kernel_size=7,
                padding=3,
                padding_mode="replicate",
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.channel_layers.append(stem)

        for i in range(len(depths) - 1):
            mid_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
            )
            self.channel_layers.append(mid_layer)

        block_fn = (
            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
            if len(kernel_sizes) == 1
            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
        )

        self.stages = nn.ModuleList()
        # linear stochastic-depth schedule across ALL blocks of the network
        drop_path_rates = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]

        cur = 0
        for i in range(len(depths)):
            stage = nn.Sequential(
                *[
                    block_fn(
                        dim=dims[i],
                        drop_path=drop_path_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal weights and zero bias for every conv / linear layer
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Map (N, input_channels, L) to normalised (N, dims[-1], L) features."""
        for channel_layer, stage in zip(self.channel_layers, self.stages):
            x = channel_layer(x)
            x = stage(x)

        return self.norm(x)
264
+
265
+
266
def init_weights(m, mean=0.0, std=0.01):
    """Initialise conv-layer weights in place from N(mean, std).

    Intended for ``module.apply``; modules whose class name does not
    contain "Conv" are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
270
+
271
+
272
def get_padding(kernel_size, dilation=1):
    """Padding that keeps a stride-1 dilated conv length-preserving (odd kernels)."""
    return dilation * (kernel_size - 1) // 2
274
+
275
+
276
class ResBlock1(torch.nn.Module):
    """HiFi-GAN residual block built from weight-normalised 1-D convolutions.

    For each entry in ``dilation``, one residual branch applies
    SiLU -> dilated conv -> SiLU -> undilated conv and adds the result
    back onto the running signal.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def make_conv(d):
            # stride-1, length-preserving, weight-normalised conv
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
            )

        self.convs1 = nn.ModuleList([make_conv(d) for d in dilation])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1) for _ in dilation])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for dilated, plain in zip(self.convs1, self.convs2):
            residual = plain(F.silu(dilated(F.silu(x))))
            x = x + residual
        return x

    def remove_weight_norm(self):
        """Strip weight norm from every conv (call before inference export)."""
        for conv in [*self.convs1, *self.convs2]:
            remove_weight_norm(conv)
366
+
367
+
368
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN generator: upsampling ConvTranspose stack with multi-receptive-
    field fusion resblocks, mapping features (N, num_mels, T) to a waveform
    (N, 1, T * hop_length).

    :param hop_length: total upsampling factor; must equal prod(upsample_rates).
    :param upsample_rates: per-layer upsampling factors.
    :param upsample_kernel_sizes: transposed-conv kernel sizes, one per layer.
    :param resblock_kernel_sizes: kernel sizes of the parallel ResBlock1 set.
    :param resblock_dilation_sizes: dilation tuples, one per resblock kernel.
    :param num_mels: number of input feature channels.
    :param upsample_initial_channel: channels before the first upsample
        (halved at every layer).
    :param use_template: if True, inject an external excitation template
        (downsampled by noise_convs) after every upsample layer.
    :param pre_conv_kernel_size: kernel of the input conv.
    :param post_conv_kernel_size: kernel of the output conv.
    :param post_activation: activation factory applied before the output conv.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: Tuple[Tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            if not use_template:
                continue

            # Template-injection conv: strided so the template is brought down
            # to this layer's temporal resolution (product of remaining rates).
            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1 :])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        # num_kernels parallel resblocks per upsample layer (MRF fusion)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock1(ch, k, d))

        self.activation_post = post_activation()
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        """Generate a waveform from features.

        :param x: input features (N, num_mels, T).
        :param template: excitation signal (N, 1, T * hop_length); required
            when ``use_template`` is True, ignored otherwise.
        :return: waveform (N, 1, T * hop_length) in [-1, 1].
        """
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            # average the parallel multi-kernel resblock outputs
            xs = None

            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)

            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight norm from all convolutions (before inference export)."""
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
490
+
491
+
492
class ADaMoSHiFiGANV1(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """ConvNeXt-backbone + HiFi-GAN-head neural vocoder (diffusers-style model).

    ``encode`` turns a waveform into log-mel features; ``decode``/``forward``
    turn mel features back into a waveform. The constructor arguments are
    captured by ``register_to_config`` for serialisation.
    """

    @register_to_config
    def __init__(
        self,
        input_channels: int = 128,
        depths: List[int] = [3, 3, 9, 3],
        dims: List[int] = [128, 256, 384, 512],
        drop_path_rate: float = 0.0,
        kernel_sizes: Tuple[int] = (7,),
        upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
        resblock_dilation_sizes: Tuple[Tuple[int]] = (
            (1, 3, 5),
            (1, 3, 5),
            (1, 3, 5),
            (1, 3, 5),
        ),
        num_mels: int = 512,
        upsample_initial_channel: int = 1024,
        use_template: bool = False,
        pre_conv_kernel_size: int = 13,
        post_conv_kernel_size: int = 13,
        sampling_rate: int = 44100,
        n_fft: int = 2048,
        win_length: int = 2048,
        hop_length: int = 512,
        f_min: int = 40,
        f_max: int = 16000,
        n_mels: int = 128,
    ):
        super().__init__()

        # mel features -> high-level representation
        self.backbone = ConvNeXtEncoder(
            input_channels=input_channels,
            depths=depths,
            dims=dims,
            drop_path_rate=drop_path_rate,
            kernel_sizes=kernel_sizes,
        )

        # representation -> waveform
        self.head = HiFiGANGenerator(
            hop_length=hop_length,
            upsample_rates=upsample_rates,
            upsample_kernel_sizes=upsample_kernel_sizes,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            num_mels=num_mels,
            upsample_initial_channel=upsample_initial_channel,
            use_template=use_template,
            pre_conv_kernel_size=pre_conv_kernel_size,
            post_conv_kernel_size=post_conv_kernel_size,
        )
        self.sampling_rate = sampling_rate
        # waveform -> log-mel features (used by `encode`); LogMelSpectrogram is
        # defined elsewhere in this file — presumably torchaudio-based, verify.
        self.mel_transform = LogMelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
        )
        # inference-only by default
        self.eval()

    @torch.no_grad()
    def decode(self, mel):
        """Mel features (N, C, T) -> waveform, without gradients."""
        y = self.backbone(mel)
        y = self.head(y)
        return y

    @torch.no_grad()
    def encode(self, x):
        """Waveform -> log-mel features, without gradients."""
        return self.mel_transform(x)

    def forward(self, mel):
        """Differentiable mel -> waveform pass (same computation as decode)."""
        y = self.backbone(mel)
        y = self.head(y)
        return y
572
+
573
+
574
if __name__ == "__main__":
    # Smoke test: load the pretrained vocoder, round-trip a wav file through
    # mel encode -> decode, and write the reconstruction to disk.
    import soundfile as sf

    x = "test_audio.wav"
    model = ADaMoSHiFiGANV1.from_pretrained(
        "./checkpoints/music_vocoder", local_files_only=True
    )

    # NOTE(review): `librosa` is used here but is not imported in the visible
    # file header — confirm the module imports it.
    wav, sr = librosa.load(x, sr=44100, mono=True)
    wav = torch.from_numpy(wav).float()[None]
    mel = model.encode(wav)

    # decode returns (N, 1, T); [0].mT transposes to (T, 1) for soundfile
    wav = model.decode(mel)[0].mT
    sf.write("test_audio_vocoder_rec.wav", wav.cpu().numpy(), 44100)
model/ldm/__pycache__/attention.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
model/ldm/__pycache__/audioldm.cpython-310.pyc ADDED
Binary file (23.7 kB). View file
 
model/ldm/__pycache__/customer_attention_processor.cpython-310.pyc ADDED
Binary file (9.44 kB). View file
 
model/ldm/__pycache__/dpm_solver_pytorch.cpython-310.pyc ADDED
Binary file (54.8 kB). View file
 
model/ldm/__pycache__/editing_unet.cpython-310.pyc ADDED
Binary file (1.59 kB). View file
 
model/ldm/__pycache__/linear_attention_block.cpython-310.pyc ADDED
Binary file (3.53 kB). View file
 
model/ldm/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (4.21 kB). View file
 
model/ldm/attention.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from inspect import isfunction
7
+ import math
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn, einsum
11
+ from einops import rearrange, repeat
12
+ from diffusers.models.attention import Attention as DiffusersAttention
13
+ from diffusers.models.attention_processor import AttnProcessor2_0
14
+ from .customer_attention_processor import CustomLiteLACrossAttnProcessor2_0, CustomLiteLAProcessor2_0
15
class CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing: run the wrapped function without storing
    activations in the forward pass, then recompute them during backward
    (trades extra compute for reduced memory)."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # The first `length` args are the tensors fed to run_function; the
        # rest are parameters it closes over, kept so autograd can reach them.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Re-run the function with grad enabled to rebuild the graph.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # (None, None) covers the non-tensor run_function and length args.
        return (None, None) + input_grads
45
+
46
+
47
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)
    all_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *all_args)
62
+
63
+
64
def exists(val):
    """Return True when *val* is not None."""
    if val is None:
        return False
    return True
66
+
67
+
68
def uniq(arr):
    """Return a keys view of the unique elements of *arr* in first-seen order."""
    return dict.fromkeys(arr).keys()
70
+
71
+
72
def default(val, d):
    """Return *val* when it is not None, else the fallback *d*
    (calling it first when it is a plain function / lambda factory)."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
76
+
77
+
78
def max_neg_value(t):
    """Most negative finite value representable in *t*'s dtype."""
    finfo = torch.finfo(t.dtype)
    return -finfo.max
80
+
81
+
82
def init_(tensor):
    """Uniformly fill *tensor* in place with +/- 1/sqrt(last_dim); return it."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
87
+
88
+
89
+ # feedforward
90
# feedforward
class GEGLU(nn.Module):
    """Gated GELU projection: maps to 2*dim_out and gates one half by the
    GELU of the other half."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
98
+
99
+
100
class FeedForward(nn.Module):
    """Transformer MLP: Linear+GELU (or GEGLU when ``glu``), dropout, then a
    linear projection back to ``dim_out`` (defaults to ``dim``)."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        # `default(dim_out, dim)` inlined
        if dim_out is None:
            dim_out = dim
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)
117
+
118
+
119
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    with torch.no_grad():
        for p in module.parameters():
            p.zero_()
    return module
126
+
127
+
128
def Normalize(in_channels):
    """32-group GroupNorm with eps=1e-6 and learnable affine parameters."""
    return torch.nn.GroupNorm(
        num_channels=in_channels, num_groups=32, affine=True, eps=1e-6
    )
132
+
133
+
134
class LinearAttention(nn.Module):
    """Linear (kernelised) attention over 2-D feature maps.

    Keys are softmax-normalised over spatial positions and first aggregated
    with the values into a (d x d) context matrix, which is then applied to
    the queries — O(N·d²) instead of the O(N²·d) of full attention.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        # split the stacked qkv projection and flatten the spatial dims
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
        # context: per-head (d x d) key-value summary over all positions
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)
155
+
156
+
157
class SpatialSelfAttention(nn.Module):
    """Full (quadratic) self-attention over a 2-D feature map using 1x1 convs
    for the q/k/v/out projections, with a residual connection to the input."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention: (b, hw, c) x (b, c, hw) -> (b, hw, hw) logits
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)

        # 1/sqrt(d) scaling before the softmax over key positions
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        # residual connection
        return x + h_
200
+
201
+
202
class CrossAttention(nn.Module):
    """Standard multi-head (cross-)attention.

    Attends to ``context`` when given; otherwise degenerates to
    self-attention over ``x``. An optional boolean ``mask`` (True = keep)
    blanks out padded context positions before the softmax.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        # self-attention when no context is provided
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # fold the head dimension into the batch dimension
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            # set masked-out positions to the most negative finite value so
            # they vanish after the softmax
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
243
+
244
+
245
class BasicTransformerBlock(nn.Module):
    """Transformer block: linear self-attention, standard cross-attention to
    the conditioning sequence, then a (optionally gated) feed-forward — each
    with pre-LayerNorm and a residual connection."""

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
    ):
        super().__init__()

        # UNet BasicTransformerBlock with Linear Attention for both Self and Cross attention

        # 1. Self-Attention with Linear Attention for efficiency
        self.attn1 = DiffusersAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            processor=CustomLiteLAProcessor2_0()  # Linear attention for self-attention
        )

        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)

        # 2. Cross-Attention with Standard Attention for optimal text conditioning
        # Using AttnProcessor2_0 for better text-audio alignment and conditioning quality
        self.attn2 = DiffusersAttention(
            query_dim=dim,
            cross_attention_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            processor=AttnProcessor2_0()  # Standard attention for best cross-attention performance
        )

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # Amphion's custom checkpoint function may not be fully compatible
        # here; use PyTorch's checkpoint if needed, but for simplicity
        # gradient checkpointing is skipped for now.
        # return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
        return self._forward(x, context)

    def _forward(self, x, context=None):
        # 1. Self-Attention
        # NOTE(review): this unpack assumes CustomLiteLAProcessor2_0 returns a
        # (tensor, extra) tuple, while attn2 below is used as returning a plain
        # tensor — confirm against the processor implementation.
        out1, _ = self.attn1(self.norm1(x))
        x = out1 + x

        # 2. Cross-Attention
        # out2, _ = self.attn2(self.norm2(x), encoder_hidden_states=context)
        x = self.attn2(self.norm2(x), encoder_hidden_states=context) + x

        # 3. Feed-forward
        x = self.ff(self.norm3(x)) + x
        return x
306
+
307
+
308
+
309
+
310
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(
        self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None
    ):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
                )
                for d in range(depth)
            ]
        )

        # zero-initialised output projection: the whole module is an identity
        # mapping at initialisation (residual below dominates)
        self.proj_out = zero_module(
            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")  # image -> sequence form
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)  # sequence -> image
        x = self.proj_out(x)
        return x + x_in
model/ldm/audioldm.py ADDED
@@ -0,0 +1,946 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from abc import abstractmethod
7
+ from functools import partial
8
+ import math
9
+ from typing import Iterable
10
+
11
+ import os
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ from einops import repeat
17
+ from torch.utils.checkpoint import checkpoint as pt_checkpoint
18
+
19
+ from .attention import SpatialTransformer
20
+
21
+ # from attention import SpatialTransformer
22
+
23
+
24
class CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing (duplicate of the implementation in
    attention.py): skip storing activations in forward and recompute them
    during backward, trading compute for memory."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # First `length` args are the tensors fed to run_function; the rest
        # are parameters it closes over, kept so autograd can reach them.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Re-run the function with grad enabled to rebuild the graph.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # (None, None) covers the non-tensor run_function and length args.
        return (None, None) + input_grads
54
+
55
+
56
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)
    all_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *all_args)
71
+
72
+
73
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    with torch.no_grad():
        for p in module.parameters():
            p.zero_()
    return module
80
+
81
+
82
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, simply tile the raw timestep value.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)

    half = dim // 2
    # geometric frequency ladder from 1 down to ~1/max_period
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # odd dim: pad with a zero column so the output is exactly [N, dim]
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
107
+
108
+
109
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that casts its input to the parameters' dtype first.

    The parameters' dtype is authoritative: casting the input to it keeps
    mixed-precision activations from tripping a dtype mismatch inside
    F.group_norm.
    """

    def forward(self, x):
        return F.group_norm(
            x.to(self.weight.dtype),
            self.num_groups,
            self.weight,
            self.bias,
            self.eps,
        )
123
+
124
+
125
def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    # 32 groups is the conventional diffusion-UNet choice; GroupNorm32 also
    # casts inputs to the parameters' dtype for mixed-precision safety.
    return GroupNorm32(32, channels)
132
+
133
+
134
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # Two matmuls of identical cost: one forming the attention weight
    # matrix, one combining the value vectors.
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += torch.DoubleTensor([matmul_ops])
152
+
153
+
154
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims not in conv_classes:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_classes[dims](*args, **kwargs)
165
+
166
+
167
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    pool_classes = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims not in pool_classes:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_classes[dims](*args, **kwargs)
178
+
179
+
180
class QKVAttention(nn.Module):
    """
    QKV attention that splits q/k/v *before* splitting heads (the "new"
    split order, as opposed to QKVAttentionLegacy).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, seq_len = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)  # each [N x (H * C) x T]
        # Scale q and k symmetrically by ch^(-1/4): more stable in fp16 than
        # dividing the logits afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        logits = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(batch * self.n_heads, head_dim, seq_len),
            (k * scale).view(batch * self.n_heads, head_dim, seq_len),
        )
        # Softmax in fp32, then cast back to the working dtype.
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        combined = torch.einsum(
            "bts,bcs->bct", probs, v.reshape(batch * self.n_heads, head_dim, seq_len)
        )
        return combined.reshape(batch, -1, seq_len)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
215
+
216
+
217
class QKVAttentionLegacy(nn.Module):
    """
    QKV attention matching the legacy head-first split order
    (heads are separated before q/k/v are split apart).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, seq_len = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        # Fold heads into the batch dimension, then split q/k/v per head.
        q, k, v = qkv.reshape(batch * self.n_heads, head_dim * 3, seq_len).split(
            head_dim, dim=1
        )
        # Symmetric ch^(-1/4) scaling: more stable in fp16 than dividing logits.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        logits = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        combined = torch.einsum("bts,bcs->bct", probs, v)
        return combined.reshape(batch, -1, seq_len)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
247
+
248
+
249
class AttentionPool2d(nn.Module):
    """
    Attention-based global pooling.
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # One positional vector per spatial location, plus one for the
        # prepended mean token.
        self.positional_embedding = nn.Parameter(
            torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        batch, channels, *_ = x.shape
        flat = x.reshape(batch, channels, -1)  # N x C x (HW)
        # Prepend the spatial mean as a CLS-like pooling token: N x C x (HW+1)
        tokens = torch.cat([flat.mean(dim=-1, keepdim=True), flat], dim=-1)
        tokens = tokens + self.positional_embedding[None, :, :].to(tokens.dtype)
        tokens = self.attention(self.qkv_proj(tokens))
        tokens = self.c_proj(tokens)
        # Only the pooled (first) token is returned.
        return tokens[:, :, 0]
279
+
280
+
281
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    # Marker base class: TimestepEmbedSequential dispatches on isinstance
    # checks against it to decide whether to pass `emb` along.
    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
291
+
292
+
293
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential container that forwards the timestep embedding (and, for
    spatial transformers, the cross-attention context) to the children
    that accept them; all other children receive the activations only.
    """

    def forward(self, x, emb, context=None):
        out = x
        for child in self:
            if isinstance(child, TimestepBlock):
                out = child(out, emb)
            elif isinstance(child, SpatialTransformer):
                out = child(out, context)
            else:
                out = child(out)
        return out
308
+
309
+
310
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D signals keep the depth dimension; only the inner two
            # (spatial) dimensions are doubled.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            upsampled = F.interpolate(x, target_size, mode="nearest")
        else:
            upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
341
+
342
+
343
class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding."""

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        # A stride-2 transposed convolution learns the upsampling kernel.
        self.up = nn.ConvTranspose2d(
            self.channels, self.out_channels, kernel_size=ks, stride=2
        )

    def forward(self, x):
        return self.up(x)
357
+
358
+
359
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D signals only stride the inner-two (spatial) dimensions.
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
391
+
392
+
393
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # norm -> SiLU -> conv; split apart in _forward when up/down so the
        # resampling can run between the activation and the convolution.
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the timestep embedding to the block's width; twice the
        # width when it supplies a (scale, shift) pair for FiLM conditioning.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # Final conv is zero-initialized so the block starts as a no-op
        # residual (output == skip connection at init).
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.use_checkpoint:
            # Use PyTorch's native checkpointing to trade compute for memory.
            return pt_checkpoint(self._forward, x, emb, use_reentrant=False)
        else:
            return self._forward(x, emb)


    def _forward(self, x, emb):
        if self.updown:
            # Resample between the activation and the conv so both the
            # residual branch and the skip path see the new resolution.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over all spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: h = norm(h) * (1 + scale) + shift.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
508
+
509
+
510
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        # A single 1x1 conv produces q, k, and v in one pass.
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        # Zero-initialized projection: the block starts as an identity residual.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        if self.use_checkpoint:
            # Use PyTorch's native checkpointing to save activation memory.
            return pt_checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)


    def _forward(self, x):
        # Flatten all spatial dims into one sequence axis for 1D attention,
        # then restore the original shape on the way out.
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
561
+
562
+
563
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param image_size: spatial size of the input; stored for reference only
        (not read by forward here).
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    :param use_spatial_transformer: if True, use SpatialTransformer
        (cross-attention) blocks instead of plain AttentionBlocks.
    :param context_dim: dimensionality of the cross-attention conditioning;
        required when use_spatial_transformer is True.
    :param n_embed: if set, predict discrete codebook ids instead of the
        usual continuous output (first-stage VQ support).
    :param legacy: if True, recompute dim_head with the legacy rule below.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=True,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            from omegaconf.listconfig import ListConfig

            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        # NOTE(review): use_fp16 is currently unused — the dtype cast below
        # was deliberately disabled; dtype handling now happens per-layer.
        #self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # Timestep MLP: model_channels -> 4 * model_channels.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # ----- Encoder (downsampling) path -----
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        # input_block_chans records the channel count after every input block
        # so the decoder can pop matching skip-connection widths.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling rate
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        )
                        if not use_spatial_transformer
                        else SpatialTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample between levels (but not after the last one).
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ----- Bottleneck -----
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            (
                AttentionBlock(
                    ch,
                    use_checkpoint=use_checkpoint,
                    num_heads=num_heads,
                    num_head_channels=dim_head,
                    use_new_attention_order=use_new_attention_order,
                )
                if not use_spatial_transformer
                else SpatialTransformer(
                    ch,
                    num_heads,
                    dim_head,
                    depth=transformer_depth,
                    context_dim=context_dim,
                )
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # ----- Decoder (upsampling) path; mirrors the encoder in reverse -----
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # Pop the matching skip-connection width recorded above.
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        )
                        if not use_spatial_transformer
                        else SpatialTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                        )
                    )
                # Upsample on the last block of each level (except the top).
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
            )

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        # Ensure t_emb matches the dtype of time_embed layer weights
        emb = self.time_embed(t_emb.to(self.time_embed[0].weight.dtype))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x#.type(self.dtype)
        # Encoder: collect activations for the decoder's skip connections.
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        # Decoder: concatenate each skip tensor before its output block.
        for module in self.output_blocks:
            # print(h.shape, hs[-1].shape)
            # Crop h down to the skip tensor's spatial size: odd input sizes
            # can make the upsampled h one element larger than the stored skip.
            if h.shape != hs[-1].shape:
                if h.shape[-1] > hs[-1].shape[-1]:
                    h = h[:, :, :, : hs[-1].shape[-1]]
                if h.shape[-2] > hs[-1].shape[-2]:
                    h = h[:, :, : hs[-1].shape[-2], :]
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        # print(h.shape)
        #h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
922
+
923
+
924
class AudioLDM(nn.Module):
    """Thin wrapper that builds the diffusion UNet from a config object.

    `cfg` is expected to expose the attributes referenced below
    (image_size, in_channels, ..., legacy); see UNetModel for their meaning.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        unet_config = dict(
            image_size=cfg.image_size,
            in_channels=cfg.in_channels,
            out_channels=cfg.out_channels,
            model_channels=cfg.model_channels,
            attention_resolutions=cfg.attention_resolutions,
            num_res_blocks=cfg.num_res_blocks,
            channel_mult=cfg.channel_mult,
            num_heads=cfg.num_heads,
            use_spatial_transformer=cfg.use_spatial_transformer,
            transformer_depth=cfg.transformer_depth,
            context_dim=cfg.context_dim,
            use_checkpoint=cfg.use_checkpoint,
            legacy=cfg.legacy,
        )
        self.unet = UNetModel(**unet_config)

    def forward(self, x, timesteps=None, context=None, y=None):
        # Delegate directly to the UNet; see UNetModel.forward for semantics.
        return self.unet(x=x, timesteps=timesteps, context=context, y=y)
model/ldm/customer_attention_processor.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Union, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import logging
21
+ from diffusers.models.attention_processor import Attention
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+ # ADD THIS NEW CLASS to the end of customer_attention_processor.py
26
+
27
class CustomLiteLACrossAttnProcessor2_0:
    """
    Attention processor for LINEAR CROSS-ATTENTION.

    Queries are projected from `hidden_states`; keys and values are projected
    from `encoder_hidden_states` (falling back to self-attention when no
    encoder states are given). Uses a ReLU-kernel linear-attention
    formulation instead of softmax attention.
    """

    def __init__(self):
        # ReLU kernel feature map used by the linear-attention formulation.
        self.kernel_func = nn.ReLU(inplace=False)
        # Numerical floor for the normalizer denominator.
        self.eps = 1e-15
        # Constant padded into V's extra row to compute the normalizer.
        self.pad_val = 1.0

    def apply_rotary_emb(self, x, freqs_cis):
        """Apply rotary position embeddings to a [B, H, S, D] tensor.

        :param x: query or key tensor, laid out [B, H, S, D].
        :param freqs_cis: (cos, sin) tensors, each [S, D].
        :return: tensor of the same shape and dtype with RoPE applied.
        """
        cos, sin = freqs_cis
        cos, sin = cos[None, None].to(x.device), sin[None, None].to(x.device)
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        # Extra args accepted for interface compatibility with other processors
        **kwargs,
    ) -> torch.FloatTensor:

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = hidden_states.shape[0]
        # Remember the activation dtype: the linear-attention math below is
        # promoted to fp32 for stability, and we must cast back before the
        # output projection (whose weights are in the original dtype).
        dtype = hidden_states.dtype

        # Q is from the query stream; K and V come from the conditioning.
        query = attn.to_q(hidden_states)

        # Use encoder_hidden_states for K and V
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # Fallback to self-attention

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # [B, S, H*D] -> [B, H, D, S]; K additionally transposed to [B, H, S, D].
        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

        # RoPE operates on [B, H, S, D]; round-trip the query layout.
        query = query.permute(0, 1, 3, 2)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            # Cross-attention may carry separate frequencies for the key
            # stream; fall back to the query frequencies if none were given.
            key_freqs = kwargs.get("rotary_freqs_cis_cross", rotary_freqs_cis)
            key = self.apply_rotary_emb(key, key_freqs)

        # Reshape query back
        query = query.permute(0, 1, 3, 2)

        # Linear attention: out = (V_pad @ K) @ Q with a ReLU kernel, where V
        # is padded with a constant row so one matmul also yields the
        # normalizer (recovered from the last row below).
        query = self.kernel_func(query)
        key = self.kernel_func(key)

        query, key, value = query.float(), key.float(), value.float()
        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
        vk = torch.matmul(value, key)
        hidden_states = torch.matmul(vk, query)

        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
        hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)

        # BUG FIX: the original cast to `query.dtype`, but `query` had already
        # been promoted to float32 above, so the cast was a no-op and
        # half-precision models crashed at the output projection. Restore the
        # dtype captured before promotion (matching CustomLiteLAProcessor2_0).
        hidden_states = hidden_states.to(dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states
119
class CustomLiteLAProcessor2_0:
    """Linear-attention ("LiteLA") processor for SD3-like joint self-attention.

    Adds optional RMS norm for query/key and applies RoPE, then computes
    linear attention with a ReLU kernel: instead of softmax, value is padded
    with an extra "ones" row so a single matmul accumulates both numerator and
    normalizer, which are divided at the end.
    """

    def __init__(self):
        # ReLU feature map used as the linear-attention kernel function.
        self.kernel_func = nn.ReLU(inplace=False)
        # Numerical floor added to the normalizer to avoid division by zero.
        self.eps = 1e-15
        # Constant appended as an extra row of `value` (the "ones" trick):
        # after value @ key, the last row holds the per-position normalizer.
        self.pad_val = 1.0

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
        tensors contain rotary embeddings and are returned as real tensors.

        Args:
            x (`torch.Tensor`):
                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
        """
        cos, sin = freqs_cis  # [S, D]
        # Add broadcast dims for batch and heads: [S, D] -> [1, 1, S, D].
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        # Treat consecutive pairs of the last dim as (real, imag) components.
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        # Rotation is done in float32 for stability, then cast back.
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # Remember the sample-stream length so the joint sequence can be split
        # back into (sample, context) after attention.
        hidden_states_len = hidden_states.shape[1]

        # Flatten spatial 4-D inputs [B, C, H, W] -> [B, H*W, C].
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        if encoder_hidden_states is not None:
            context_input_ndim = encoder_hidden_states.ndim
            if context_input_ndim == 4:
                batch_size, channel, height, width = encoder_hidden_states.shape
                encoder_hidden_states = encoder_hidden_states.view(
                    batch_size, channel, height * width
                ).transpose(1, 2)

        batch_size = hidden_states.shape[0]

        # `sample` projections.
        dtype = hidden_states.dtype
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections (only present on joint-attention modules).
        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj")
            and hasattr(attn, "add_k_proj")
            and hasattr(attn, "add_v_proj")
        )
        if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            # attention
            if not attn.is_cross_attention:
                # Joint self-attention: concatenate sample and context streams
                # along the sequence dimension.
                query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
                key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
                value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
            else:
                # NOTE(review): this branch discards ALL projections computed
                # above (to_q/to_k/to_v and add_*_proj) and uses the raw
                # hidden states directly — confirm this is intentional.
                query = hidden_states
                key = encoder_hidden_states
                value = encoder_hidden_states

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Reshape to multi-head layout. query/value end up as [B, H, D, S];
        # key gets an extra transpose so it is [B, H, S, D].
        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
        key = (
            key.transpose(-1, -2)
            .reshape(batch_size, attn.heads, head_dim, -1)
            .transpose(-1, -2)
        )
        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)

        # RoPE expects [B, H, S, D] input.
        # query is currently [B, H, D, S]; convert to [B, H, S, D] first.
        query = query.permute(0, 1, 3, 2)  # [B, H, S, D] (from [B, H, D, S])

        # Apply query and key normalization if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                # Cross-attention keys come from the context stream, which has
                # its own rotary frequencies.
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        # query is now [B, H, S, D]; restore the [B, H, D, S] layout used by
        # the linear-attention matmuls below.
        query = query.permute(0, 1, 3, 2)  # [B, H, D, S]

        if attention_mask is not None:
            # attention_mask: [B, S] -> [B, 1, S, 1]
            attention_mask = attention_mask[:, None, :, None].to(
                key.dtype
            )  # [B, 1, S, 1]
            # Zero out padded positions in query (mask permuted to match the
            # [B, H, D, S] layout).
            query = query * attention_mask.permute(
                0, 1, 3, 2
            )  # [B, H, S, D] * [B, 1, S, 1]
            if not attn.is_cross_attention:
                key = (
                    key * attention_mask
                )  # key: [B, h, S, D] multiplied by mask [B, 1, S, 1]
                value = value * attention_mask.permute(
                    0, 1, 3, 2
                )  # value is [B, h, D, S], so the mask is permuted to match S

        if (
            attn.is_cross_attention
            and encoder_attention_mask is not None
            and has_encoder_hidden_state_proj
        ):
            encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(
                key.dtype
            )  # [B, 1, S_enc, 1]
            # Here key: [B, h, S_enc, D], value: [B, h, D, S_enc]
            key = key * encoder_attention_mask  # [B, h, S_enc, D] * [B, 1, S_enc, 1]
            value = value * encoder_attention_mask.permute(
                0, 1, 3, 2
            )  # [B, h, D, S_enc] * [B, 1, 1, S_enc]

        # ReLU feature map replaces the softmax of standard attention.
        query = self.kernel_func(query)
        key = self.kernel_func(key)

        # Linear attention is accumulated in float32 for numerical stability.
        query, key, value = query.float(), key.float(), value.float()

        # Append a row of pad_val so value@key also accumulates the normalizer.
        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)

        vk = torch.matmul(value, key)

        hidden_states = torch.matmul(vk, query)

        if hidden_states.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.float()

        # Divide the numerator (all rows but the last) by the accumulated
        # normalizer (last row).
        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)

        # Merge heads: [B, H*D, S] -> [B, S, H*D].
        hidden_states = hidden_states.view(
            batch_size, attn.heads * head_dim, -1
        ).permute(0, 2, 1)

        hidden_states = hidden_states.to(dtype)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(dtype)

        # Split the attention outputs.
        if (
            encoder_hidden_states is not None
            and not attn.is_cross_attention
            and has_encoder_hidden_state_proj
        ):
            hidden_states, encoder_hidden_states = (
                hidden_states[:, :hidden_states_len],
                hidden_states[:, hidden_states_len:],
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        # NOTE(review): the context stream is intentionally NOT passed through
        # attn.to_add_out here (the call was removed in this variant):
        # if (
        #     encoder_hidden_states is not None
        #     and not attn.context_pre_only
        #     and not attn.is_cross_attention
        #     and hasattr(attn, "to_add_out")
        # ):
        #     encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # Restore spatial 4-D shape if the input was 4-D.
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if encoder_hidden_states is not None and context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        # Clamp to the fp16 representable range when running under fp16
        # autocast to avoid inf/NaN overflow.
        if torch.get_autocast_gpu_dtype() == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
            if encoder_hidden_states is not None:
                encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return hidden_states, encoder_hidden_states
338
+
339
+
340
class CustomerAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    Adds optional q/k normalization and RoPE on top of the standard diffusers
    AttnProcessor2_0, and supports a combined (self + encoder) attention mask
    for cross-attention.
    """

    def __init__(self):
        # F.scaled_dot_product_attention only exists on PyTorch >= 2.0.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def apply_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
        to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
        reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
        tensors contain rotary embeddings and are returned as real tensors.

        Args:
            x (`torch.Tensor`):
                Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
            freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
        """
        cos, sin = freqs_cis  # [S, D]
        # Add broadcast dims for batch and heads: [S, D] -> [1, 1, S, D].
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        # Treat consecutive pairs of the last dim as (real, imag) components.
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        # Rotation is done in float32 for stability, then cast back.
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:

        residual = hidden_states
        input_ndim = hidden_states.ndim

        # Flatten spatial 4-D inputs [B, C, H, W] -> [B, H*W, C].
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        # sequence_length is the KEY sequence length (context length for
        # cross-attention), as expected by prepare_attention_mask below.
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        has_encoder_hidden_state_proj = (
            hasattr(attn, "add_q_proj")
            and hasattr(attn, "add_k_proj")
            and hasattr(attn, "add_v_proj")
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        # Self-attention falls back to the sample stream as key/value source.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split heads: [B, S, H*D] -> [B, H, S, D].
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if rotary_freqs_cis is not None:
            query = self.apply_rotary_emb(query, rotary_freqs_cis)
            if not attn.is_cross_attention:
                key = self.apply_rotary_emb(key, rotary_freqs_cis)
            elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
                # Cross-attention keys use the context stream's frequencies.
                key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)

        if (
            attn.is_cross_attention
            and encoder_attention_mask is not None
            and has_encoder_hidden_state_proj
        ):
            # attention_mask: N x S1
            # encoder_attention_mask: N x S2
            # For cross-attention, combine attention_mask and
            # encoder_attention_mask into a single additive mask.
            # NOTE(review): this branch indexes attention_mask, so it assumes
            # attention_mask is not None whenever encoder_attention_mask is
            # provided — confirm against callers.
            combined_mask = (
                attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
            )
            # Allowed pairs get 0 bias; disallowed pairs get -inf.
            attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
            attention_mask = (
                attention_mask[:, None, :, :]
                .expand(-1, attn.heads, -1, -1)
                .to(query.dtype)
            )

        elif not attn.is_cross_attention and attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge heads back: [B, H, S, D] -> [B, S, H*D].
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Restore spatial 4-D shape if the input was 4-D.
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
model/ldm/dpm_solver_pytorch.py ADDED
@@ -0,0 +1,1307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+
6
class NoiseScheduleVP:
    """Wrapper for the forward VP-type SDE noise schedule.

    Supports two parameterizations:

    * ``schedule='discrete'`` — a discrete-time DPM described by either
      ``betas`` or ``alphas_cumprod`` (the DDPM ``alpha_bar`` array). Discrete
      steps n = 0..N-1 are mapped to continuous times t_n = (n + 1) / N, and
      log(alpha_t) is obtained by piecewise-linear interpolation.
    * ``schedule='linear'`` — the continuous-time linear VPSDE with endpoints
      ``continuous_beta_0`` and ``continuous_beta_1``.

    Exposed quantities, all as functions of continuous time t in [0, T]:
    ``marginal_log_mean_coeff`` (log alpha_t), ``marginal_alpha`` (alpha_t),
    ``marginal_std`` (sigma_t), ``marginal_lambda`` (the half-logSNR
    lambda_t = log alpha_t - log sigma_t) and its inverse ``inverse_lambda``.

    Example:
        >>> ns = NoiseScheduleVP('discrete', betas=betas)
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)

    Raises:
        ValueError: if ``schedule`` is neither 'discrete' nor 'linear'.
    """

    def __init__(
        self,
        schedule='discrete',
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.,
        dtype=torch.float32,
    ):
        if schedule not in ['discrete', 'linear']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule))

        self.schedule = schedule
        self.T = 1.
        if schedule == 'discrete':
            # DDPM notation: alpha_bar_n = prod(1 - beta_i), and
            # log(alpha_{t_n}) = 0.5 * log(alpha_bar_n).
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            # Clip the near-t=T tail where the logSNR is numerically unstable.
            clipped = self.numerical_clip_alpha(log_alphas)
            self.log_alpha_array = clipped.reshape((1, -1,)).to(dtype=dtype)
            self.total_N = self.log_alpha_array.shape[1]
            # Discrete step n maps to continuous time t_n = (n + 1) / N.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1

    def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
        """Drop trailing entries whose half-logSNR falls below ``clipped_lambda``.

        Some beta schedules (e.g. cosine, as used by i-DDPM, guided-diffusion
        and GLIDE) have numerically unstable logSNR near t = T; clipping at
        -5.1 keeps the interpolation stable.
        """
        log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
        half_logsnr = log_alphas - log_sigmas
        # half_logsnr is decreasing in t, so flip it to search ascending.
        idx = torch.searchsorted(torch.flip(half_logsnr, [0]), clipped_lambda)
        if idx > 0:
            log_alphas = log_alphas[:-idx]
        return log_alphas

    def marginal_log_mean_coeff(self, t):
        """Return log(alpha_t) for continuous time t in [0, T]."""
        if self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        # 'discrete': piecewise-linear interpolation of the log-alpha table.
        return interpolate_fn(
            t.reshape((-1, 1)),
            self.t_array.to(t.device),
            self.log_alpha_array.to(t.device),
        ).reshape((-1))

    def marginal_alpha(self, t):
        """Return alpha_t for continuous time t in [0, T]."""
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """Return sigma_t = sqrt(1 - alpha_t^2) for continuous time t in [0, T]."""
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """Return the half-logSNR lambda_t = log(alpha_t) - log(sigma_t)."""
        log_alpha_t = self.marginal_log_mean_coeff(t)
        log_sigma_t = 0.5 * torch.log(1. - torch.exp(2. * log_alpha_t))
        return log_alpha_t - log_sigma_t

    def inverse_lambda(self, lamb):
        """Return the time t in [0, T] whose half-logSNR equals ``lamb``."""
        if self.schedule == 'linear':
            # Solve the quadratic in t given by marginal_lambda(t) = lamb.
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        # 'discrete': invert via interpolation of the flipped tables.
        log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
        t = interpolate_fn(
            log_alpha.reshape((-1, 1)),
            torch.flip(self.log_alpha_array.to(lamb.device), [1]),
            torch.flip(self.t_array.to(lamb.device), [1]),
        )
        return t.reshape((-1,))
168
+
169
+
170
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.,
    classifier_fn=None,
    classifier_kwargs=None,
):
    """Wrap ``model`` into a continuous-time noise prediction function for DPM-Solver.

    DPM-Solver solves continuous-time diffusion ODEs, so models trained on
    discrete time labels are wrapped to accept continuous time in [1/N, 1]
    (mapped internally to [0, 1000 * (N - 1) / N]); continuous-time models
    receive the time unchanged.

    Args:
        model: the diffusion model, called as
            ``model(noisy_target_latent=x, timesteps=t, [context=cond,] **model_kwargs)``
            and returning noise / x_start / v / score per ``model_type``.
        noise_schedule: a noise schedule object such as ``NoiseScheduleVP``.
        model_type: parameterization of ``model``:
            "noise" (epsilon prediction), "x_start" (data prediction),
            "v" (velocity prediction, Salimans & Ho 2022), or
            "score" (marginal score; noise = -sigma_t * score).
        model_kwargs: extra keyword arguments forwarded to ``model``
            (default: no extras).
        guidance_type: "uncond", "classifier" (Dhariwal & Nichol 2021, needs
            ``classifier_fn``), or "classifier-free" (Ho & Salimans 2022).
        condition: conditioning tensor for guided sampling.
        unconditional_condition: negative/unconditional conditioning for
            classifier-free guidance.
        guidance_scale: guidance strength; 1. disables classifier-free mixing.
        classifier_fn: classifier called as
            ``classifier_fn(x, t_input, condition, **classifier_kwargs)``;
            only used for "classifier" guidance.
        classifier_kwargs: extra keyword arguments for ``classifier_fn``.

    Returns:
        ``model_fn(x, t_continuous) -> noise`` suitable for ``DPM_Solver``.

    Raises:
        AssertionError: if ``model_type`` or ``guidance_type`` is invalid.
    """
    # Validate configuration before building the closures so misuse fails at
    # wrap time with a clear location (the original checked after the defs).
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]

    # Avoid mutable default arguments: a shared module-level dict would leak
    # state between calls to model_wrapper.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """Call the model and convert its output to a noise prediction."""
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            # For EditingUNet: (noisy_target_latent, source_latent, context, timesteps)
            output = model(noisy_target_latent=x, timesteps=t_input, **model_kwargs)
        else:
            # For EditingUNet with condition: (noisy_target_latent, source_latent, context, timesteps)
            output = model(noisy_target_latent=x, context=cond, timesteps=t_input, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            # epsilon = (x - alpha_t * x0) / sigma_t
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
        elif model_type == "v":
            # epsilon = alpha_t * v + sigma_t * x
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
        elif model_type == "score":
            # epsilon = -sigma_t * score
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -expand_dims(sigma_t, x.dim()) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Batch the unconditional and conditional passes together,
                # then mix: eps_uncond + s * (eps_cond - eps_uncond).
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
337
+
338
+
339
+ class DPM_Solver:
340
+ def __init__(
341
+ self,
342
+ model_fn,
343
+ noise_schedule,
344
+ algorithm_type="dpmsolver++",
345
+ correcting_x0_fn=None,
346
+ correcting_xt_fn=None,
347
+ thresholding_max_val=1.,
348
+ dynamic_thresholding_ratio=0.995,
349
+ ):
350
+ """Construct a DPM-Solver.
351
+
352
+ We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
353
+
354
+ We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
355
+ can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
356
+ dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
357
+ DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
358
+ DPMs (such as stable-diffusion).
359
+
360
+ To support advanced algorithms in image-to-image applications, we also support corrector functions for
361
+ both x0 and xt.
362
+
363
+ Args:
364
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
365
+ ``
366
+ def model_fn(x, t_continuous):
367
+ return noise
368
+ ``
369
+ The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
370
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
371
+ algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
372
+ correcting_x0_fn: A `str` or a function with the following format:
373
+ ```
374
+ def correcting_x0_fn(x0, t):
375
+ x0_new = ...
376
+ return x0_new
377
+ ```
378
+ This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
379
+ ```
380
+ x0_pred = data_pred_model(xt, t)
381
+ if correcting_x0_fn is not None:
382
+ x0_pred = correcting_x0_fn(x0_pred, t)
383
+ xt_1 = update(x0_pred, xt, t)
384
+ ```
385
+ If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
386
+ correcting_xt_fn: A function with the following format:
387
+ ```
388
+ def correcting_xt_fn(xt, t, step):
389
+ x_new = ...
390
+ return x_new
391
+ ```
392
+ This function is to correct the intermediate samples xt at each sampling step. e.g.,
393
+ ```
394
+ xt = ...
395
+ xt = correcting_xt_fn(xt, t, step)
396
+ ```
397
+ thresholding_max_val: A `float`. The max value for thresholding.
398
+ Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
399
+ dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
400
+ Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
401
+
402
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
403
+ Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
404
+ with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
405
+ """
406
+ self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
407
+ self.noise_schedule = noise_schedule
408
+ assert algorithm_type in ["dpmsolver", "dpmsolver++"]
409
+ self.algorithm_type = algorithm_type
410
+ if correcting_x0_fn == "dynamic_thresholding":
411
+ self.correcting_x0_fn = self.dynamic_thresholding_fn
412
+ else:
413
+ self.correcting_x0_fn = correcting_x0_fn
414
+ self.correcting_xt_fn = correcting_xt_fn
415
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
416
+ self.thresholding_max_val = thresholding_max_val
417
+
418
+ def dynamic_thresholding_fn(self, x0, t):
419
+ """
420
+ The dynamic thresholding method.
421
+ """
422
+ dims = x0.dim()
423
+ p = self.dynamic_thresholding_ratio
424
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
425
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
426
+ x0 = torch.clamp(x0, -s, s) / s
427
+ return x0
428
+
429
+ def noise_prediction_fn(self, x, t):
430
+ """
431
+ Return the noise prediction model.
432
+ """
433
+ return self.model(x, t)
434
+
435
+ def data_prediction_fn(self, x, t):
436
+ """
437
+ Return the data prediction model (with corrector).
438
+ """
439
+ noise = self.noise_prediction_fn(x, t)
440
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
441
+ x0 = (x - sigma_t * noise) / alpha_t
442
+ if self.correcting_x0_fn is not None:
443
+ x0 = self.correcting_x0_fn(x0, t)
444
+ return x0
445
+
446
+ def model_fn(self, x, t):
447
+ """
448
+ Convert the model to the noise prediction model or the data prediction model.
449
+ """
450
+ if self.algorithm_type == "dpmsolver++":
451
+ return self.data_prediction_fn(x, t)
452
+ else:
453
+ return self.noise_prediction_fn(x, t)
454
+
455
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
456
+ """Compute the intermediate time steps for sampling.
457
+
458
+ Args:
459
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
460
+ - 'logSNR': uniform logSNR for the time steps.
461
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
462
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
463
+ t_T: A `float`. The starting time of the sampling (default is T).
464
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
465
+ N: A `int`. The total number of the spacing of the time steps.
466
+ device: A torch device.
467
+ Returns:
468
+ A pytorch tensor of the time steps, with the shape (N + 1,).
469
+ """
470
+ if skip_type == 'logSNR':
471
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
472
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
473
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
474
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
475
+ elif skip_type == 'time_uniform':
476
+ return torch.linspace(t_T, t_0, N + 1).to(device)
477
+ elif skip_type == 'time_quadratic':
478
+ t_order = 2
479
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
480
+ return t
481
+ else:
482
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
483
+
484
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
485
+ """
486
+ Get the order of each step for sampling by the singlestep DPM-Solver.
487
+
488
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
489
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
490
+ - If order == 1:
491
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
492
+ - If order == 2:
493
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
494
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
495
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
496
+ - If order == 3:
497
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
498
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
499
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
500
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
501
+
502
+ ============================================
503
+ Args:
504
+ order: A `int`. The max order for the solver (2 or 3).
505
+ steps: A `int`. The total number of function evaluations (NFE).
506
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
507
+ - 'logSNR': uniform logSNR for the time steps.
508
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
509
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
510
+ t_T: A `float`. The starting time of the sampling (default is T).
511
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
512
+ device: A torch device.
513
+ Returns:
514
+ orders: A list of the solver order of each step.
515
+ """
516
+ if order == 3:
517
+ K = steps // 3 + 1
518
+ if steps % 3 == 0:
519
+ orders = [3,] * (K - 2) + [2, 1]
520
+ elif steps % 3 == 1:
521
+ orders = [3,] * (K - 1) + [1]
522
+ else:
523
+ orders = [3,] * (K - 1) + [2]
524
+ elif order == 2:
525
+ if steps % 2 == 0:
526
+ K = steps // 2
527
+ orders = [2,] * K
528
+ else:
529
+ K = steps // 2 + 1
530
+ orders = [2,] * (K - 1) + [1]
531
+ elif order == 1:
532
+ K = steps
533
+ orders = [1,] * steps
534
+ else:
535
+ raise ValueError("'order' must be '1' or '2' or '3'.")
536
+ if skip_type == 'logSNR':
537
+ # To reproduce the results in DPM-Solver paper
538
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
539
+ else:
540
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
541
+ return timesteps_outer, orders
542
+
543
+ def denoise_to_zero_fn(self, x, s):
544
+ """
545
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
546
+ """
547
+ return self.data_prediction_fn(x, s)
548
+
549
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
550
+ """
551
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
552
+
553
+ Args:
554
+ x: A pytorch tensor. The initial value at time `s`.
555
+ s: A pytorch tensor. The starting time, with the shape (1,).
556
+ t: A pytorch tensor. The ending time, with the shape (1,).
557
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
558
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
559
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
560
+ Returns:
561
+ x_t: A pytorch tensor. The approximated solution at time `t`.
562
+ """
563
+ ns = self.noise_schedule
564
+ dims = x.dim()
565
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
566
+ h = lambda_t - lambda_s
567
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
568
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
569
+ alpha_t = torch.exp(log_alpha_t)
570
+
571
+ if self.algorithm_type == "dpmsolver++":
572
+ phi_1 = torch.expm1(-h)
573
+ if model_s is None:
574
+ model_s = self.model_fn(x, s)
575
+ x_t = (
576
+ sigma_t / sigma_s * x
577
+ - alpha_t * phi_1 * model_s
578
+ )
579
+ if return_intermediate:
580
+ return x_t, {'model_s': model_s}
581
+ else:
582
+ return x_t
583
+ else:
584
+ phi_1 = torch.expm1(h)
585
+ if model_s is None:
586
+ model_s = self.model_fn(x, s)
587
+ x_t = (
588
+ torch.exp(log_alpha_t - log_alpha_s) * x
589
+ - (sigma_t * phi_1) * model_s
590
+ )
591
+ if return_intermediate:
592
+ return x_t, {'model_s': model_s}
593
+ else:
594
+ return x_t
595
+
596
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'):
597
+ """
598
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
599
+
600
+ Args:
601
+ x: A pytorch tensor. The initial value at time `s`.
602
+ s: A pytorch tensor. The starting time, with the shape (1,).
603
+ t: A pytorch tensor. The ending time, with the shape (1,).
604
+ r1: A `float`. The hyperparameter of the second-order solver.
605
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
606
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
607
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
608
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
609
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
610
+ Returns:
611
+ x_t: A pytorch tensor. The approximated solution at time `t`.
612
+ """
613
+ if solver_type not in ['dpmsolver', 'taylor']:
614
+ raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
615
+ if r1 is None:
616
+ r1 = 0.5
617
+ ns = self.noise_schedule
618
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
619
+ h = lambda_t - lambda_s
620
+ lambda_s1 = lambda_s + r1 * h
621
+ s1 = ns.inverse_lambda(lambda_s1)
622
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
623
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
624
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
625
+
626
+ if self.algorithm_type == "dpmsolver++":
627
+ phi_11 = torch.expm1(-r1 * h)
628
+ phi_1 = torch.expm1(-h)
629
+
630
+ if model_s is None:
631
+ model_s = self.model_fn(x, s)
632
+ x_s1 = (
633
+ (sigma_s1 / sigma_s) * x
634
+ - (alpha_s1 * phi_11) * model_s
635
+ )
636
+ model_s1 = self.model_fn(x_s1, s1)
637
+ if solver_type == 'dpmsolver':
638
+ x_t = (
639
+ (sigma_t / sigma_s) * x
640
+ - (alpha_t * phi_1) * model_s
641
+ - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
642
+ )
643
+ elif solver_type == 'taylor':
644
+ x_t = (
645
+ (sigma_t / sigma_s) * x
646
+ - (alpha_t * phi_1) * model_s
647
+ + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s)
648
+ )
649
+ else:
650
+ phi_11 = torch.expm1(r1 * h)
651
+ phi_1 = torch.expm1(h)
652
+
653
+ if model_s is None:
654
+ model_s = self.model_fn(x, s)
655
+ x_s1 = (
656
+ torch.exp(log_alpha_s1 - log_alpha_s) * x
657
+ - (sigma_s1 * phi_11) * model_s
658
+ )
659
+ model_s1 = self.model_fn(x_s1, s1)
660
+ if solver_type == 'dpmsolver':
661
+ x_t = (
662
+ torch.exp(log_alpha_t - log_alpha_s) * x
663
+ - (sigma_t * phi_1) * model_s
664
+ - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
665
+ )
666
+ elif solver_type == 'taylor':
667
+ x_t = (
668
+ torch.exp(log_alpha_t - log_alpha_s) * x
669
+ - (sigma_t * phi_1) * model_s
670
+ - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)
671
+ )
672
+ if return_intermediate:
673
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
674
+ else:
675
+ return x_t
676
+
677
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'):
678
+ """
679
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
680
+
681
+ Args:
682
+ x: A pytorch tensor. The initial value at time `s`.
683
+ s: A pytorch tensor. The starting time, with the shape (1,).
684
+ t: A pytorch tensor. The ending time, with the shape (1,).
685
+ r1: A `float`. The hyperparameter of the third-order solver.
686
+ r2: A `float`. The hyperparameter of the third-order solver.
687
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
688
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
689
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
690
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
691
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
692
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
693
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
694
+ Returns:
695
+ x_t: A pytorch tensor. The approximated solution at time `t`.
696
+ """
697
+ if solver_type not in ['dpmsolver', 'taylor']:
698
+ raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
699
+ if r1 is None:
700
+ r1 = 1. / 3.
701
+ if r2 is None:
702
+ r2 = 2. / 3.
703
+ ns = self.noise_schedule
704
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
705
+ h = lambda_t - lambda_s
706
+ lambda_s1 = lambda_s + r1 * h
707
+ lambda_s2 = lambda_s + r2 * h
708
+ s1 = ns.inverse_lambda(lambda_s1)
709
+ s2 = ns.inverse_lambda(lambda_s2)
710
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
711
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
712
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
713
+
714
+ if self.algorithm_type == "dpmsolver++":
715
+ phi_11 = torch.expm1(-r1 * h)
716
+ phi_12 = torch.expm1(-r2 * h)
717
+ phi_1 = torch.expm1(-h)
718
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
719
+ phi_2 = phi_1 / h + 1.
720
+ phi_3 = phi_2 / h - 0.5
721
+
722
+ if model_s is None:
723
+ model_s = self.model_fn(x, s)
724
+ if model_s1 is None:
725
+ x_s1 = (
726
+ (sigma_s1 / sigma_s) * x
727
+ - (alpha_s1 * phi_11) * model_s
728
+ )
729
+ model_s1 = self.model_fn(x_s1, s1)
730
+ x_s2 = (
731
+ (sigma_s2 / sigma_s) * x
732
+ - (alpha_s2 * phi_12) * model_s
733
+ + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
734
+ )
735
+ model_s2 = self.model_fn(x_s2, s2)
736
+ if solver_type == 'dpmsolver':
737
+ x_t = (
738
+ (sigma_t / sigma_s) * x
739
+ - (alpha_t * phi_1) * model_s
740
+ + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
741
+ )
742
+ elif solver_type == 'taylor':
743
+ D1_0 = (1. / r1) * (model_s1 - model_s)
744
+ D1_1 = (1. / r2) * (model_s2 - model_s)
745
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
746
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
747
+ x_t = (
748
+ (sigma_t / sigma_s) * x
749
+ - (alpha_t * phi_1) * model_s
750
+ + (alpha_t * phi_2) * D1
751
+ - (alpha_t * phi_3) * D2
752
+ )
753
+ else:
754
+ phi_11 = torch.expm1(r1 * h)
755
+ phi_12 = torch.expm1(r2 * h)
756
+ phi_1 = torch.expm1(h)
757
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
758
+ phi_2 = phi_1 / h - 1.
759
+ phi_3 = phi_2 / h - 0.5
760
+
761
+ if model_s is None:
762
+ model_s = self.model_fn(x, s)
763
+ if model_s1 is None:
764
+ x_s1 = (
765
+ (torch.exp(log_alpha_s1 - log_alpha_s)) * x
766
+ - (sigma_s1 * phi_11) * model_s
767
+ )
768
+ model_s1 = self.model_fn(x_s1, s1)
769
+ x_s2 = (
770
+ (torch.exp(log_alpha_s2 - log_alpha_s)) * x
771
+ - (sigma_s2 * phi_12) * model_s
772
+ - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
773
+ )
774
+ model_s2 = self.model_fn(x_s2, s2)
775
+ if solver_type == 'dpmsolver':
776
+ x_t = (
777
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
778
+ - (sigma_t * phi_1) * model_s
779
+ - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
780
+ )
781
+ elif solver_type == 'taylor':
782
+ D1_0 = (1. / r1) * (model_s1 - model_s)
783
+ D1_1 = (1. / r2) * (model_s2 - model_s)
784
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
785
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
786
+ x_t = (
787
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
788
+ - (sigma_t * phi_1) * model_s
789
+ - (sigma_t * phi_2) * D1
790
+ - (sigma_t * phi_3) * D2
791
+ )
792
+
793
+ if return_intermediate:
794
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
795
+ else:
796
+ return x_t
797
+
798
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
799
+ """
800
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
801
+
802
+ Args:
803
+ x: A pytorch tensor. The initial value at time `s`.
804
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
805
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
806
+ t: A pytorch tensor. The ending time, with the shape (1,).
807
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
808
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
809
+ Returns:
810
+ x_t: A pytorch tensor. The approximated solution at time `t`.
811
+ """
812
+ if solver_type not in ['dpmsolver', 'taylor']:
813
+ raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
814
+ ns = self.noise_schedule
815
+ model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
816
+ t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
817
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
818
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
819
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
820
+ alpha_t = torch.exp(log_alpha_t)
821
+
822
+ h_0 = lambda_prev_0 - lambda_prev_1
823
+ h = lambda_t - lambda_prev_0
824
+ r0 = h_0 / h
825
+ D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
826
+ if self.algorithm_type == "dpmsolver++":
827
+ phi_1 = torch.expm1(-h)
828
+ if solver_type == 'dpmsolver':
829
+ x_t = (
830
+ (sigma_t / sigma_prev_0) * x
831
+ - (alpha_t * phi_1) * model_prev_0
832
+ - 0.5 * (alpha_t * phi_1) * D1_0
833
+ )
834
+ elif solver_type == 'taylor':
835
+ x_t = (
836
+ (sigma_t / sigma_prev_0) * x
837
+ - (alpha_t * phi_1) * model_prev_0
838
+ + (alpha_t * (phi_1 / h + 1.)) * D1_0
839
+ )
840
+ else:
841
+ phi_1 = torch.expm1(h)
842
+ if solver_type == 'dpmsolver':
843
+ x_t = (
844
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
845
+ - (sigma_t * phi_1) * model_prev_0
846
+ - 0.5 * (sigma_t * phi_1) * D1_0
847
+ )
848
+ elif solver_type == 'taylor':
849
+ x_t = (
850
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
851
+ - (sigma_t * phi_1) * model_prev_0
852
+ - (sigma_t * (phi_1 / h - 1.)) * D1_0
853
+ )
854
+ return x_t
855
+
856
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
857
+ """
858
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
859
+
860
+ Args:
861
+ x: A pytorch tensor. The initial value at time `s`.
862
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
863
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
864
+ t: A pytorch tensor. The ending time, with the shape (1,).
865
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
866
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
867
+ Returns:
868
+ x_t: A pytorch tensor. The approximated solution at time `t`.
869
+ """
870
+ ns = self.noise_schedule
871
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
872
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
873
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
874
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
875
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
876
+ alpha_t = torch.exp(log_alpha_t)
877
+
878
+ h_1 = lambda_prev_1 - lambda_prev_2
879
+ h_0 = lambda_prev_0 - lambda_prev_1
880
+ h = lambda_t - lambda_prev_0
881
+ r0, r1 = h_0 / h, h_1 / h
882
+ D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
883
+ D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
884
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
885
+ D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
886
+ if self.algorithm_type == "dpmsolver++":
887
+ phi_1 = torch.expm1(-h)
888
+ phi_2 = phi_1 / h + 1.
889
+ phi_3 = phi_2 / h - 0.5
890
+ x_t = (
891
+ (sigma_t / sigma_prev_0) * x
892
+ - (alpha_t * phi_1) * model_prev_0
893
+ + (alpha_t * phi_2) * D1
894
+ - (alpha_t * phi_3) * D2
895
+ )
896
+ else:
897
+ phi_1 = torch.expm1(h)
898
+ phi_2 = phi_1 / h - 1.
899
+ phi_3 = phi_2 / h - 0.5
900
+ x_t = (
901
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
902
+ - (sigma_t * phi_1) * model_prev_0
903
+ - (sigma_t * phi_2) * D1
904
+ - (sigma_t * phi_3) * D2
905
+ )
906
+ return x_t
907
+
908
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None):
909
+ """
910
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
911
+
912
+ Args:
913
+ x: A pytorch tensor. The initial value at time `s`.
914
+ s: A pytorch tensor. The starting time, with the shape (1,).
915
+ t: A pytorch tensor. The ending time, with the shape (1,).
916
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
917
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
918
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
919
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
920
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
921
+ r2: A `float`. The hyperparameter of the third-order solver.
922
+ Returns:
923
+ x_t: A pytorch tensor. The approximated solution at time `t`.
924
+ """
925
+ if order == 1:
926
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
927
+ elif order == 2:
928
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1)
929
+ elif order == 3:
930
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2)
931
+ else:
932
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
933
+
934
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
935
+ """
936
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
937
+
938
+ Args:
939
+ x: A pytorch tensor. The initial value at time `s`.
940
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
941
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
942
+ t: A pytorch tensor. The ending time, with the shape (1,).
943
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
944
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
945
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
946
+ Returns:
947
+ x_t: A pytorch tensor. The approximated solution at time `t`.
948
+ """
949
+ if order == 1:
950
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
951
+ elif order == 2:
952
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
953
+ elif order == 3:
954
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
955
+ else:
956
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
957
+
958
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'):
959
+ """
960
+ The adaptive step size solver based on singlestep DPM-Solver.
961
+
962
+ Args:
963
+ x: A pytorch tensor. The initial value at time `t_T`.
964
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
965
+ t_T: A `float`. The starting time of the sampling (default is T).
966
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
967
+ h_init: A `float`. The initial step size (for logSNR).
968
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
969
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
970
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
971
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
972
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
973
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
974
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
975
+ Returns:
976
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
977
+
978
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
979
+ """
980
+ ns = self.noise_schedule
981
+ s = t_T * torch.ones((1,)).to(x)
982
+ lambda_s = ns.marginal_lambda(s)
983
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
984
+ h = h_init * torch.ones_like(s).to(x)
985
+ x_prev = x
986
+ nfe = 0
987
+ if order == 2:
988
+ r1 = 0.5
989
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
990
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
991
+ elif order == 3:
992
+ r1, r2 = 1. / 3., 2. / 3.
993
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type)
994
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
995
+ else:
996
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
997
+ while torch.abs((s - t_0)).mean() > t_err:
998
+ t = ns.inverse_lambda(lambda_s + h)
999
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
1000
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
1001
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
1002
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
1003
+ E = norm_fn((x_higher - x_lower) / delta).max()
1004
+ if torch.all(E <= 1.):
1005
+ x = x_higher
1006
+ s = t
1007
+ x_prev = x_lower
1008
+ lambda_s = ns.marginal_lambda(s)
1009
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
1010
+ nfe += order
1011
+ print('adaptive solver nfe', nfe)
1012
+ return x
1013
+
1014
+ def add_noise(self, x, t, noise=None):
1015
+ """
1016
+ Compute the noised input xt = alpha_t * x + sigma_t * noise.
1017
+
1018
+ Args:
1019
+ x: A `torch.Tensor` with shape `(batch_size, *shape)`.
1020
+ t: A `torch.Tensor` with shape `(t_size,)`.
1021
+ Returns:
1022
+ xt with shape `(t_size, batch_size, *shape)`.
1023
+ """
1024
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
1025
+ if noise is None:
1026
+ noise = torch.randn((t.shape[0], *x.shape), device=x.device)
1027
+ x = x.reshape((-1, *x.shape))
1028
+ xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
1029
+ if t.shape[0] == 1:
1030
+ return xt.squeeze(0)
1031
+ else:
1032
+ return xt
1033
+
1034
+ def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
1035
+ method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
1036
+ atol=0.0078, rtol=0.05, return_intermediate=False,
1037
+ ):
1038
+ """
1039
+ Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
1040
+ For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
1041
+ """
1042
+ t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start
1043
+ t_T = self.noise_schedule.T if t_end is None else t_end
1044
+ assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1045
+ return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type,
1046
+ method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type,
1047
+ atol=atol, rtol=rtol, return_intermediate=return_intermediate)
1048
+
1049
def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
    method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
    atol=0.0078, rtol=0.05, return_intermediate=False,
):
    """
    Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.

    Supported `method` values (for both noise-prediction and data-prediction models):
    - 'singlestep': singlestep DPM-Solver ("DPM-Solver-fast" in the paper). All
      singlestep solvers with order <= `order` are combined so the total number of
      function evaluations (NFE) == `steps`. E.g. for `order` == 3, with
      K = steps // 3 + 1: steps % 3 == 0 uses (K - 2) DPM-Solver-3 steps plus one
      order-2 and one order-1 step; steps % 3 == 1 uses (K - 1) order-3 steps plus
      one order-1 step; steps % 3 == 2 uses (K - 1) order-3 steps plus one order-2 step.
    - 'multistep': multistep DPM-Solver of the given `order`; NFE == `steps`. The
      first `order` values are initialized by lower-order multistep solvers.
    - 'singlestep_fixed': fixed-order singlestep DPM-Solver
      (DPM-Solver-1 / -2 / -3); NFE == (steps // order) * order.
    - 'adaptive': adaptive step-size DPM-Solver ("DPM-Solver-12"/"-23" in the
      paper); `steps` is ignored and the cost/quality trade-off is controlled by
      `atol` and `rtol`.

    Advice: for unconditional sampling or guided sampling with a small guidance
    scale, use 'singlestep' with `order` == 3 (DPM-Solver or DPM-Solver++); for
    guided sampling with a large guidance scale, use 'multistep' with
    `algorithm_type="dpmsolver++"` and `order` == 2.

    Supported `skip_type` values:
    - 'logSNR': uniform logSNR spacing (recommended for low-resolution images).
    - 'time_uniform': uniform time spacing (recommended for high-resolution images).
    - 'time_quadratic': quadratic time spacing.

    Args:
        x: A pytorch tensor. The initial value at time `t_start` (e.g. a sample
            from the standard normal distribution when `t_start` == T).
        steps: An `int`. The total number of function evaluations (NFE).
        t_start: A `float`. Defaults to self.noise_schedule.T (usually 1.0).
        t_end: A `float`. Defaults to 1. / self.noise_schedule.total_N (e.g. 1e-3
            when total_N == 1000, the recommended value for discrete-time DPMs).
            For continuous-time DPMs, 1e-3 is recommended when `steps` <= 15 and
            1e-4 when `steps` > 15.
        order: An `int`. The order of DPM-Solver.
        skip_type: A `str`. Time-step spacing; see above.
        method: A `str`. Sampling method; see above.
        denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step
            (total NFE becomes `steps` + 1). This trick (from DDPM,
            https://arxiv.org/abs/2006.11239, and score_sde,
            https://arxiv.org/abs/2011.13456) can improve FID for low-resolution
            images but is unnecessary (and costs one extra NFE) at high resolution.
        lower_order_final: A `bool`. Whether to use lower-order solvers at the
            final steps. Only valid for `method` == 'multistep' and `steps` < 15;
            empirically the key to stable sampling with very few steps
            (especially steps <= 10), so `True` is recommended.
        solver_type: A `str`. Taylor expansion type, 'dpmsolver' (recommended)
            or 'taylor'.
        atol: A `float`. Absolute tolerance; valid when `method` == 'adaptive'.
        rtol: A `float`. Relative tolerance; valid when `method` == 'adaptive'.
        return_intermediate: A `bool`. When True, also return the list of xt at
            each step; when False, return only x0.
    Returns:
        x_end: A pytorch tensor. The approximated solution at time `t_end`
        (a tuple (x_end, intermediates) when `return_intermediate` is True).
    """
    # NOTE: t_T is the (larger) start time and t_0 the (smaller) end time.
    t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
    t_T = self.noise_schedule.T if t_start is None else t_start
    assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
    if return_intermediate:
        assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
    if self.correcting_xt_fn is not None:
        assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
    device = x.device
    intermediates = []
    with torch.no_grad():
        if method == 'adaptive':
            x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
        elif method == 'multistep':
            assert steps >= order
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            assert timesteps.shape[0] - 1 == steps
            # Init the initial values.
            step = 0
            t = timesteps[step]
            t_prev_list = [t]
            model_prev_list = [self.model_fn(x, t)]
            if self.correcting_xt_fn is not None:
                x = self.correcting_xt_fn(x, t, step)
            if return_intermediate:
                intermediates.append(x)
            # Init the first `order` values by lower order multistep DPM-Solver.
            for step in range(1, order):
                t = timesteps[step]
                x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
                t_prev_list.append(t)
                model_prev_list.append(self.model_fn(x, t))
            # Compute the remaining values by `order`-th order multistep DPM-Solver.
            for step in range(order, steps + 1):
                t = timesteps[step]
                # We only use lower order for steps < 10
                if lower_order_final and steps < 10:
                    step_order = min(order, steps + 1 - step)
                else:
                    step_order = order
                x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
                # Shift the rolling history of previous times / model outputs.
                for i in range(order - 1):
                    t_prev_list[i] = t_prev_list[i + 1]
                    model_prev_list[i] = model_prev_list[i + 1]
                t_prev_list[-1] = t
                # We do not need to evaluate the final model value.
                if step < steps:
                    model_prev_list[-1] = self.model_fn(x, t)
        elif method in ['singlestep', 'singlestep_fixed']:
            if method == 'singlestep':
                timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
            elif method == 'singlestep_fixed':
                K = steps // order
                orders = [order,] * K
                timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
            for step, order in enumerate(orders):
                s, t = timesteps_outer[step], timesteps_outer[step + 1]
                # Intermediate ratios r1/r2 locate the inner evaluation points in logSNR space.
                timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device)
                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
                h = lambda_inner[-1] - lambda_inner[0]
                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
                x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
        else:
            raise ValueError("Got wrong method {}".format(method))
        if denoise_to_zero:
            t = torch.ones((1,)).to(device) * t_0
            x = self.denoise_to_zero_fn(x, t)
            if self.correcting_xt_fn is not None:
                x = self.correcting_xt_fn(x, t, step + 1)
            if return_intermediate:
                intermediates.append(x)
        if return_intermediate:
            return x, intermediates
        else:
            return x
1248
+
1249
+
1250
+
1251
+ #############################################################
1252
+ # other utility functions
1253
+ #############################################################
1254
+
1255
def interpolate_fn(x, xp, yp):
    """
    Differentiable piecewise-linear function y = f(x) defined by keypoints (xp, yp).

    f is defined on the whole x-axis: outside the range of xp, the outermost
    segment is extended linearly. The implementation avoids data-dependent
    branching so it stays autograd-friendly.

    Args:
        x: PyTorch tensor with shape [N, C] (N = batch size; C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], the keypoint x-coordinates.
        yp: PyTorch tensor with shape [C, K], the keypoint y-coordinates.
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_kp = x.shape[0], xp.shape[1]
    # Sort each query together with its keypoints so the query's rank tells us
    # which segment it falls in.
    combined = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((batch, 1, 1))], dim=2)
    sorted_combined, sort_order = torch.sort(combined, dim=2)
    pos_of_x = torch.argmin(sort_order, dim=2)
    left_candidate = pos_of_x - 1
    # Clamp to the outermost segment when x falls before the first / after the
    # last keypoint (linear extrapolation).
    lower_idx = torch.where(
        torch.eq(pos_of_x, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(pos_of_x, num_kp), torch.tensor(num_kp - 2, device=x.device), left_candidate,
        ),
    )
    # Skip over x itself (which sits between the two bracketing keypoints in the
    # sorted order) when picking the upper endpoint.
    upper_idx = torch.where(torch.eq(lower_idx, left_candidate), lower_idx + 2, lower_idx + 1)
    x_lo = torch.gather(sorted_combined, dim=2, index=lower_idx.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(sorted_combined, dim=2, index=upper_idx.unsqueeze(2)).squeeze(2)
    # Index of the segment's left keypoint in the *original* (unsorted) xp/yp.
    seg_idx = torch.where(
        torch.eq(pos_of_x, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(pos_of_x, num_kp), torch.tensor(num_kp - 2, device=x.device), left_candidate,
        ),
    )
    yp_expanded = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(yp_expanded, dim=2, index=seg_idx.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(yp_expanded, dim=2, index=(seg_idx + 1).unsqueeze(2)).squeeze(2)
    # Standard linear interpolation on the selected segment.
    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
1295
+
1296
+
1297
def expand_dims(v, dims):
    """
    Append trailing singleton axes to `v` until it has `dims` total dimensions.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`, the desired total number of dimensions.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] (total dimension `dims`),
        suitable for broadcasting against a batch of samples.
    """
    trailing = (1,) * (dims - 1)
    return v.reshape(*v.shape, *trailing)
model/ldm/editing_unet.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .audioldm import UNetModel
5
+
6
class EditingUNet(nn.Module):
    """UNet wrapper for latent audio editing.

    Doubles the configured input channels so the noisy target latent and the
    clean source latent can be concatenated along the channel axis, and
    optionally bounds the output with Hardtanh for flow-matching training.
    """

    def __init__(self, unet_config, use_flow_matching=True, velocity_bound=4.0):
        super().__init__()
        base_channels = unet_config.in_channels
        cfg = dict(unet_config)
        # The UNet consumes [noisy_target ; source] stacked on channels.
        cfg['in_channels'] = base_channels * 2
        self.unet = UNetModel(**cfg)
        self.original_in_channels = base_channels

        self.use_flow_matching = use_flow_matching
        if not self.use_flow_matching:
            self.final_activation = None
            print("✅ EditingUNet configured for standard DDPM noise prediction.")
        else:
            # A bounded activation on the predicted velocity keeps training
            # stable and encodes the valid prior that velocities are finite.
            self.final_activation = nn.Hardtanh(min_val=-velocity_bound, max_val=velocity_bound)
            print(f"✅ EditingUNet configured with Hardtanh(bound={velocity_bound}) for stable Flow Matching.")

    def forward(self, noisy_target_latent, source_latent, context, timesteps, **kwargs):
        """Predict noise/velocity for `noisy_target_latent` conditioned on
        `source_latent`, `context` and `timesteps`.

        Raises:
            ValueError: if the batch sizes differ and the target batch is not
                exactly twice the source batch (the CFG duplication case).
        """
        target_batch = noisy_target_latent.shape[0]
        if target_batch != source_latent.shape[0]:
            # Classifier-free guidance doubles the target batch
            # (unconditional + conditional); mirror the source latent to match.
            if target_batch != 2 * source_latent.shape[0]:
                raise ValueError(f"Batch size mismatch: noisy_target_latent={noisy_target_latent.shape[0]}, source_latent={source_latent.shape[0]}")
            source_latent = source_latent.repeat(2, 1, 1, 1)

        # No dtype casting here — the trainer owns precision handling.
        stacked = torch.cat([noisy_target_latent, source_latent], dim=1)
        prediction = self.unet(
            x=stacked,
            timesteps=timesteps,
            context=context,
        )
        return prediction if self.final_activation is None else self.final_activation(prediction)
model/ldm/exp_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "unet": {
4
+ "image_size": 32,
5
+ "in_channels": 8,
6
+ "out_channels": 8,
7
+ "model_channels": 256,
8
+ "attention_resolutions": [4, 2, 1],
9
+ "num_res_blocks": 2,
10
+ "channel_mult": [1, 2, 4, 4],
11
+ "num_heads": 8,
12
+ "use_spatial_transformer": true,
13
+ "transformer_depth": 2,
14
+ "context_dim": 768,
15
+ "use_checkpoint": true,
16
+ "legacy": false
17
+ }
18
+
19
+ }
20
+ }
model/ldm/linear_attention_block.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from typing import Optional
4
+ import math
5
+
6
+ # --- These are your existing, correct components ---
7
+ from diffusers.models.attention_processor import AttnProcessor2_0
8
+ from .customer_attention_processor import Attention, CustomLiteLAProcessor2_0
9
+ from diffusers.models.normalization import RMSNorm
10
+ from .attention import GLUMBConv # Using GLUMBConv from your attention.py
11
+ #from diffusers.models.attention_processor import FusedAttnProcessor2_0
12
class EditingTransformerBlock(nn.Module):
    """Transformer block for audio editing.

    Applies, in order: self-attention over the audio token sequence (using the
    linear-attention processor), cross-attention from audio tokens to the text
    embedding (using standard scaled-dot-product attention), and a GLUMBConv
    feed-forward. Each sub-layer is modulated by AdaLN-single scale/shift
    terms derived from the timestep embedding `temb`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        text_embed_dim: int,
        mlp_ratio: float = 4.0,
        use_adaln_single: bool = True,
    ):
        super().__init__()
        self.use_adaln_single = use_adaln_single
        inner_dim = num_attention_heads * attention_head_dim

        # --- 1. Self-attention over the (concatenated) audio sequence ---
        # Uses the linear attention processor for efficiency on long sequences.
        self.norm_self = RMSNorm(dim, eps=1e-6)
        self.attn_self = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            out_dim=inner_dim,
            processor=CustomLiteLAProcessor2_0(),
        )

        # --- 2. Cross-attention: audio attends to the text embedding ---
        # Standard (softmax) attention is used here; it is typically more
        # stable and matters more for text alignment than linear attention.
        self.norm_cross = RMSNorm(dim, eps=1e-6)
        self.attn_cross = Attention(
            query_dim=dim,
            cross_attention_dim=text_embed_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            out_dim=inner_dim,
            processor=AttnProcessor2_0(),
        )

        # --- 3. Feed-forward ---
        self.norm_ff = RMSNorm(dim, eps=1e-6)
        self.ff = GLUMBConv(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
        )

        # --- 4. AdaLN-single conditioning ---
        # Six rows: (shift, scale) for self-attn, cross-attn and feed-forward.
        if use_adaln_single:
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor],
        temb: Optional[torch.FloatTensor],
        use_checkpointing: bool = False,
    ) -> torch.FloatTensor:
        """Run the block.

        Args:
            hidden_states: audio token sequence, shape [B, L, dim].
            encoder_hidden_states: text embedding for cross-attention, or None
                to skip the cross-attention sub-layer entirely.
            temb: timestep embedding [B, dim] for AdaLN modulation, or None.
            use_checkpointing: accepted for interface compatibility; not used
                by this implementation.
        Returns:
            The transformed sequence, same shape as `hidden_states`.
        """
        # AdaLN setup: each sub-layer gets its own (shift, scale) pair.
        if self.use_adaln_single and temb is not None:
            shift_self, scale_self, shift_cross, scale_cross, shift_ff, scale_ff = (
                (self.scale_shift_table[None] + temb[:, None, :]).chunk(6, dim=1)
            )
        else:
            # BUGFIX: the previous fallback used scale=1.0 / shift=0.0, which
            # with the `x * (1 + scale) + shift` modulation below doubled every
            # normalized activation. All-zero values make the modulation an
            # exact identity when no timestep conditioning is available.
            shift_self = scale_self = shift_cross = scale_cross = shift_ff = scale_ff = 0.0

        # --- 1. Self-attention (linear attention) ---
        residual = hidden_states
        norm_h = self.norm_self(hidden_states)
        norm_h = norm_h * (1 + scale_self) + shift_self
        # CustomLiteLAProcessor2_0 returns a (output, aux) tuple.
        attn_output, _ = self.attn_self(norm_h)
        hidden_states = attn_output + residual

        # --- 2. Cross-attention (standard attention) ---
        if encoder_hidden_states is not None:
            residual = hidden_states
            norm_h = self.norm_cross(hidden_states)
            norm_h = norm_h * (1 + scale_cross) + shift_cross
            # Cross-attention also returns a (output, attention_weights) tuple.
            attn_output, _ = self.attn_cross(
                hidden_states=norm_h,
                encoder_hidden_states=encoder_hidden_states,
            )
            hidden_states = attn_output + residual

        # --- 3. Feed-forward ---
        residual = hidden_states
        norm_h = self.norm_ff(hidden_states)
        norm_h = norm_h * (1 + scale_ff) + shift_ff
        ff_output = self.ff(norm_h)
        hidden_states = ff_output + residual

        return hidden_states
117
+
118
+
119
+
120
+
121
+ # class EditingTransformerBlock(nn.Module):
122
+ # """
123
+ # A CORRECTED, fully linear attention transformer block for editing tasks.
124
+ # It combines self-attention and cross-attention into a single, EFFICIENT
125
+ # linear self-attention operation on a concatenated sequence.
126
+ # """
127
+ # def __init__(
128
+ # self,
129
+ # dim,
130
+ # num_attention_heads,
131
+ # attention_head_dim,
132
+ # mlp_ratio=4.0,
133
+ # use_adaln_single=True,
134
+ # ):
135
+ # super().__init__()
136
+ # self.use_adaln_single = use_adaln_single
137
+ # self.norm1 = RMSNorm(dim, elementwise_affine=False, eps=1e-6)
138
+
139
+ # # THE CRITICAL FIX: We use ONE attention block, initialized
140
+ # # with the LINEAR attention processor.
141
+ # self.attn = Attention(
142
+ # query_dim=dim,
143
+ # heads=num_attention_heads,
144
+ # dim_head=attention_head_dim,
145
+ # out_dim=dim,
146
+ # bias=True,
147
+ # processor=CustomLiteLAProcessor2_0(), # <--- THIS IS THE FIX
148
+ # )
149
+
150
+ # self.norm2 = RMSNorm(dim, elementwise_affine=False, eps=1e-6)
151
+ # self.ff = GLUMBConv(
152
+ # in_features=dim,
153
+ # hidden_features=int(dim * mlp_ratio),
154
+ # use_bias=(True, True, False),
155
+ # norm=(None, None, None),
156
+ # act=("silu", "silu", None),
157
+ # )
158
+
159
+ # if use_adaln_single:
160
+ # # This is simpler than the original 6-way split if we apply it once
161
+ # self.scale_shift_table = nn.Parameter(torch.randn(4, dim) / dim**0.5)
162
+
163
+ # def forward(
164
+ # self,
165
+ # hidden_states: torch.FloatTensor,
166
+ # encoder_hidden_states: Optional[torch.FloatTensor] = None,
167
+ # temb: Optional[torch.FloatTensor] = None,
168
+ # use_checkpointing: bool = False,
169
+ # ):
170
+ # hidden_states_len = hidden_states.shape[1]
171
+ # N = hidden_states.shape[0]
172
+ # # AdaLN-Single conditioning
173
+ # if self.use_adaln_single and temb is not None:
174
+ # shift_msa, scale_msa, shift_mlp, scale_mlp = (
175
+ # (self.scale_shift_table[None] + temb[:, None, :])
176
+ # .chunk(4, dim=1)
177
+ # )
178
+
179
+ # norm_hidden_states = self.norm1(hidden_states)
180
+ # if self.use_adaln_single and temb is not None:
181
+ # norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
182
+
183
+ # # --- UNIFIED ATTENTION LOGIC ---
184
+ # # The CustomLiteLAProcessor2_0 will treat this as one long sequence
185
+ # # for its Q, K, V projections. This is where self- and cross-attention merge.
186
+ # attn_input = torch.cat([norm_hidden_states, encoder_hidden_states], dim=1)
187
+
188
+ # # Define the forward pass for checkpointing
189
+ # def attn_forward(x):
190
+ # attn_output, _ = self.attn(hidden_states=x)
191
+ # return attn_output
192
+
193
+ # if use_checkpointing:
194
+ # attn_output_combined = torch.utils.checkpoint.checkpoint(attn_forward, attn_input, use_reentrant=False)
195
+ # else:
196
+ # attn_output_combined, _ = self.attn(hidden_states=attn_input)
197
+
198
+ # # Slice the output to get only the processed audio part
199
+ # attn_output = attn_output_combined[:, :hidden_states_len, :]
200
+ # # --- END UNIFIED ATTENTION ---
201
+
202
+ # hidden_states = hidden_states + attn_output
203
+
204
+ # # Feed-forward part
205
+ # norm_ff_states = self.norm2(hidden_states)
206
+ # if self.use_adaln_single and temb is not None:
207
+ # norm_ff_states = norm_ff_states * (1 + scale_mlp) + shift_mlp
208
+
209
+ # ff_output = self.ff(norm_ff_states)
210
+
211
+ # hidden_states = hidden_states + ff_output
212
+
213
+ # return hidden_states
214
+
215
class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep embedding (Transformer/DDPM style).

    Maps a batch of scalar timesteps [B] to embeddings [B, dim] by pairing
    cosines and sines over a geometric frequency ladder; an odd `dim` is
    zero-padded by one column.
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim          # output embedding width
        self.max_period = max_period  # longest sinusoid period

    def forward(self, t):
        half_dim = self.dim // 2
        # Geometric frequency ladder: max_period^(-k/half_dim) for k in [0, half_dim).
        exponent = -math.log(self.max_period) * torch.arange(
            start=0, end=half_dim, dtype=torch.float32
        ) / half_dim
        freqs = torch.exp(exponent).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if self.dim % 2 == 1:
            # Pad a zero column so the output width is exactly `dim`.
            zero_pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, zero_pad], dim=-1)
        return embedding
model/ldm/transformer.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: model/ldm/transformer.py
2
+
3
+ import torch
4
+ from torch import nn
5
+ import math
6
+ from .linear_attention_block import EditingTransformerBlock
7
+ from diffusers.models.normalization import RMSNorm
8
+
9
class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep embedding (Transformer/DDPM style).

    Produces [B, dim] embeddings from scalar timesteps [B]: cosines then sines
    over a geometric frequency ladder, with one zero column appended when
    `dim` is odd.
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim          # output embedding width
        self.max_period = max_period  # longest sinusoid period

    def forward(self, t):
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/max_period.
        exponent = -math.log(self.max_period) * torch.arange(
            start=0, end=half_dim, dtype=torch.float32
        ) / half_dim
        freqs = torch.exp(exponent).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if self.dim % 2 == 1:
            # Zero-pad the final column for odd embedding widths.
            zero_pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, zero_pad], dim=-1)
        return embedding
25
+
26
+
27
class EditingTransformer(nn.Module):
    """Latent-editing transformer.

    Projects both the noisy target latent and the clean source latent into a
    shared token space, concatenates them along the sequence axis, runs the
    joint sequence through a stack of EditingTransformerBlocks (text
    conditioning enters via cross-attention, the timestep via AdaLN), and
    projects the target half of the sequence back to latent channels. Long
    sequences are processed in overlapping, Hann-windowed chunks.
    """
    def __init__(
        self,
        num_layers=12,
        inner_dim=512,
        num_heads=8,
        attention_head_dim=64,
        dcae_latent_channels=8,
        text_embed_dim=768,
        mlp_ratio=4.0,
    ):
        super().__init__()
        self.inner_dim = inner_dim

        # proj_in is shared: it embeds both the noisy target latent and the
        # clean source latent into the token space.
        self.proj_in = nn.Linear(dcae_latent_channels, inner_dim)

        # Sinusoidal timestep embedding followed by a SiLU MLP.
        self.time_embed = TimestepEmbedding(inner_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(inner_dim, inner_dim * 4),
            nn.SiLU(),
            nn.Linear(inner_dim * 4, inner_dim),
        )

        # The blocks operate on the concatenated (target + source) sequence,
        # so their width stays inner_dim — only the sequence length doubles.
        self.transformer_blocks = nn.ModuleList([
            EditingTransformerBlock(
                dim=inner_dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                text_embed_dim=text_embed_dim,
                mlp_ratio=mlp_ratio,
            ) for _ in range(num_layers)
        ])

        # Final output projection back to latent channels.
        self.norm_out = RMSNorm(inner_dim, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, dcae_latent_channels)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Xavier-uniform weights and zero biases for every Linear layer.
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, noisy_target_latent, source_latent, encoder_hidden_states, timestep, use_checkpointing=False):
        """Predict the output latent via chunked overlap-add processing.

        Args:
            noisy_target_latent: latent tensor [B, C, H, W].
            source_latent: clean source latent, same shape as the target —
                the code flattens both to H*W tokens with the same count.
            encoder_hidden_states: text embedding passed to each block's
                cross-attention.
            timestep: scalar timesteps [B] for the AdaLN conditioning.
            use_checkpointing: forwarded to each block per chunk.
        Returns:
            Output latent of shape [B, dcae_latent_channels, H, W].
        """
        b, c, h, w = noisy_target_latent.shape
        num_target_tokens = h * w

        # 1. Flatten both latents to token sequences via the shared projection.
        hidden_states = self.proj_in(noisy_target_latent.permute(0, 2, 3, 1).reshape(b, num_target_tokens, c))
        source_states = self.proj_in(source_latent.permute(0, 2, 3, 1).reshape(b, num_target_tokens, c))

        # Concatenate along the sequence axis — the transformer analogue of
        # channel-concatenation conditioning in a U-Net.
        input_sequence = torch.cat([hidden_states, source_states], dim=1)
        full_seq_len = input_sequence.shape[1]

        # Timestep embedding.
        t_emb = self.time_mlp(self.time_embed(timestep).to(input_sequence.dtype))

        # --- CHUNKED PROCESSING ---

        # 2. Chunking parameters.
        CHUNK_SIZE = 1024
        OVERLAP = CHUNK_SIZE // 4

        # Accumulators for windowed overlap-add and its normalization.
        output_sequence = torch.zeros_like(input_sequence)
        overlap_count = torch.zeros_like(input_sequence)

        # Hann window to taper chunk edges before overlap-add.
        # NOTE(review): torch.hann_window is 0 at the window edges, so sequence
        # positions covered by only ONE chunk edge (e.g. the very first and
        # last tokens) accumulate ~0 weight and are divided by ~1e-8 below,
        # effectively zeroing them — confirm this is intended.
        # NOTE(review): if full_seq_len < CHUNK_SIZE, `processed_chunk * window`
        # mixes a length-(full_seq_len) chunk with a length-CHUNK_SIZE window
        # and will raise a shape error — presumably inputs always have at least
        # 1024 tokens; verify against callers.
        window = torch.hann_window(CHUNK_SIZE, device=input_sequence.device).view(1, -1, 1)

        # 3. Process each chunk through the full block stack.
        start = 0
        while start < full_seq_len:
            end = min(start + CHUNK_SIZE, full_seq_len)
            # If the final chunk would be short, slide it back so it still
            # spans a full CHUNK_SIZE (re-processing some overlap).
            if end - start < CHUNK_SIZE and start > 0:
                start = full_seq_len - CHUNK_SIZE
                end = full_seq_len

            current_chunk = input_sequence[:, start:end, :]

            # Run this chunk through every transformer block.
            processed_chunk = current_chunk
            for block in self.transformer_blocks:
                # use_checkpointing still applies per chunk.
                processed_chunk = block(
                    hidden_states=processed_chunk,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=t_emb,
                    use_checkpointing=use_checkpointing
                )

            # 4. Windowed accumulation into the output buffers.
            output_sequence[:, start:end, :] += processed_chunk * window
            overlap_count[:, start:end, :] += window

            if end == full_seq_len:
                break
            start += (CHUNK_SIZE - OVERLAP)

        # 5. Normalize by the accumulated window weights (epsilon guards
        # against division by zero where coverage is ~0).
        final_processed_sequence = output_sequence / (overlap_count + 1e-8)

        # --- END CHUNKED PROCESSING ---

        # 6. Keep only the target half of the joint sequence.
        output_hidden_states = final_processed_sequence[:, :num_target_tokens, :]

        # 7. Project back to the latent space and restore [B, C, H, W] layout.
        output_hidden_states = self.norm_out(output_hidden_states)
        output_latent_flat = self.proj_out(output_hidden_states)
        output_latent = output_latent_flat.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()

        return output_latent
model/scheduler.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: model/scheduler.py
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
class LinearNoiseScheduler:
    """DDPM-style diffusion scheduler with a linear beta schedule.

    Precomputes the forward-process (noising) and reverse-process
    (denoising) coefficients for a fixed number of training timesteps and
    provides three samplers:

    * ``step``                 -- ancestral DDPM sampling,
    * ``ddim_step``            -- deterministic DDIM sampling (eta == 0),
    * ``dpm_solver_multistep`` -- a second-order DPM-Solver-style update.
    """

    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        """
        Args:
            num_timesteps: number of diffusion steps used at training time.
            beta_start: variance of the first (least noisy) step.
            beta_end: variance of the last (most noisy) step.
        """
        self.num_timesteps = num_timesteps

        # Linear beta schedule: beta_t grows linearly from beta_start to beta_end.
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)

        # alpha_t = 1 - beta_t and the cumulative product alpha_bar_t.
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Coefficients for the forward (noising) process q(x_t | x_0).
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Coefficients for the reverse (denoising) process.
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1 (left-pad).
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        # Variance of the true posterior q(x_{t-1} | x_t, x_0); it is exactly 0 at t = 0.
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

        # Default inference schedule: every training timestep, in reverse order.
        self.timesteps = torch.arange(0, num_timesteps).flip(0)

    def set_timesteps(self, num_inference_steps, device=None):
        """Select the discrete timesteps used for the sampling loop.

        Produces ``num_inference_steps`` timesteps evenly spaced over
        ``[num_timesteps - 1, 0]`` in descending order.

        NOTE(review): ``torch.linspace`` with an integer dtype truncates the
        intermediate float values, so very large ``num_inference_steps``
        may produce duplicated timesteps -- confirm acceptable for callers.
        """
        device_to_use = device if device is not None else self.betas.device
        self.timesteps = torch.linspace(
            self.num_timesteps - 1, 0, num_inference_steps,
            dtype=torch.long, device=device_to_use,
        )

    def to(self, device):
        """Move every precomputed scheduler tensor to ``device``; returns self for chaining."""
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
        self.alphas_cumprod_prev = self.alphas_cumprod_prev.to(device)
        self.posterior_variance = self.posterior_variance.to(device)
        return self

    def add_noise(self, original_samples, noise, timesteps):
        """Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.

        Args:
            original_samples: clean samples x_0, assumed 4-D (B, C, H, W) --
                the coefficients are reshaped to (-1, 1, 1, 1) for broadcasting.
            noise: Gaussian noise with the same shape as ``original_samples``.
            timesteps: 1-D batch of per-sample timestep indices.
        """
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod.to(timesteps.device)[timesteps].view(-1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod.to(timesteps.device)[timesteps].view(-1, 1, 1, 1)

        return sqrt_alphas_cumprod_t * original_samples + sqrt_one_minus_alphas_cumprod_t * noise

    def step(self, model_output, timestep, sample):
        """One ancestral DDPM sampling step: x_t -> x_{t-1}.

        Args:
            model_output: predicted noise eps_theta(x_t, t).
            timestep: current timestep t (int or 0-dim tensor).
            sample: current latent x_t.

        Returns:
            x_{t-1}, or the clamped x_0 prediction when t == 0.
        """
        t = timestep
        alpha_t = self.alphas[t]
        alpha_bar_t = self.alphas_cumprod[t]
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alphas_cumprod[t]

        # Reconstruct x_0 from the noise prediction and clamp to the data range.
        pred_original_sample = (sample - sqrt_one_minus_alpha_bar_t * model_output) / torch.sqrt(alpha_bar_t)
        pred_original_sample = torch.clamp(pred_original_sample, -1., 1.)

        if t == 0:
            # Final step: the denoised estimate itself is the output.
            return pred_original_sample

        alpha_bar_t_prev = self.alphas_cumprod_prev[t]
        posterior_variance_t = self.posterior_variance[t]

        # Posterior mean of q(x_{t-1} | x_t, x_0):
        #   mu = sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * x_t
        #      + sqrt(alpha_bar_{t-1}) * beta_t / (1 - alpha_bar_t)        * x_0
        pred_sample_direction = torch.sqrt(alpha_bar_t_prev) * self.betas[t] / (1. - alpha_bar_t)
        prev_sample_mean = torch.sqrt(alpha_t) * (1. - alpha_bar_t_prev) / (1. - alpha_bar_t) * sample + pred_sample_direction * pred_original_sample

        # t > 0 is guaranteed here (t == 0 returned above), so noise is always added.
        noise = torch.randn_like(model_output)
        return prev_sample_mean + torch.sqrt(posterior_variance_t) * noise

    def ddim_step(self, model_output, timestep, sample, eta=0.0, prev_timestep=None):
        """DDIM sampling step (Song et al.).

        ``eta == 0.0`` gives fully deterministic DDIM; ``eta == 1.0``
        recovers DDPM-like stochastic behavior. When ``prev_timestep`` is
        ``None`` this is treated as the final step and the clamped x_0
        prediction is returned directly.
        """
        if prev_timestep is None:
            # Final step: return the x_0 prediction.
            alpha_bar_t = self.alphas_cumprod[timestep]
            pred_original_sample = (sample - torch.sqrt(1 - alpha_bar_t) * model_output) / torch.sqrt(alpha_bar_t)
            pred_original_sample = torch.clamp(pred_original_sample, -1.0, 1.0)
            return pred_original_sample

        t = timestep
        prev_t = prev_timestep

        alpha_bar_t = self.alphas_cumprod[t]
        alpha_bar_prev = self.alphas_cumprod[prev_t]

        # 1. Predicted original sample x_0.
        pred_original_sample = (sample - torch.sqrt(1 - alpha_bar_t) * model_output) / torch.sqrt(alpha_bar_t)
        pred_original_sample = torch.clamp(pred_original_sample, -1.0, 1.0)

        # 2. Noise variance sigma_t (only effective when eta > 0).
        sigma_t = eta * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev))

        # 3. "Direction pointing to x_t" term.
        pred_sample_direction = torch.sqrt(1 - alpha_bar_prev - sigma_t**2) * model_output

        # 4. Deterministic part of x_{t-1}.
        prev_sample = torch.sqrt(alpha_bar_prev) * pred_original_sample + pred_sample_direction

        # 5. Stochastic part (skipped entirely for pure DDIM).
        if eta > 0:
            noise = torch.randn_like(model_output)
            prev_sample = prev_sample + sigma_t * noise

        return prev_sample

    def dpm_solver_multistep(self, model_output, timestep, sample, order=2, prev_timestep=None, prev_model_output=None):
        """Second-order DPM-Solver-style update.

        Falls back to a first-order (DDIM-like) update when ``order == 1``
        or when no previous noise prediction is available. When
        ``prev_timestep`` is ``None`` this is the final step and the
        clamped x_0 prediction is returned.
        """
        if prev_timestep is None:
            # Final step: return the x_0 prediction.
            alpha_bar_t = self.alphas_cumprod[timestep]
            pred_original_sample = (sample - torch.sqrt(1 - alpha_bar_t) * model_output) / torch.sqrt(alpha_bar_t)
            return torch.clamp(pred_original_sample, -1.0, 1.0)

        t = timestep
        prev_t = prev_timestep

        alpha_bar_t = self.alphas_cumprod[t]
        # Guard against a negative index; alpha_bar_{-1} is defined as 1.
        alpha_bar_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.alphas_cumprod.new_tensor(1.0)

        pred_original_sample = (sample - torch.sqrt(1 - alpha_bar_t) * model_output) / torch.sqrt(alpha_bar_t)
        pred_original_sample = torch.clamp(pred_original_sample, -1.0, 1.0)

        if order == 1 or prev_model_output is None:
            # First-order update (equivalent to a deterministic DDIM step).
            prev_sample = torch.sqrt(alpha_bar_prev) * pred_original_sample + torch.sqrt(1 - alpha_bar_prev) * model_output
        else:
            # Second-order update: linearly extrapolate the noise prediction
            # in log-SNR (lambda) space using the previous model output.
            lambda_t = 0.5 * torch.log(alpha_bar_t / (1 - alpha_bar_t))
            lambda_prev = 0.5 * torch.log(alpha_bar_prev / (1 - alpha_bar_prev))
            h = lambda_prev - lambda_t

            prev_sample = (
                torch.sqrt(alpha_bar_prev) * pred_original_sample +
                torch.sqrt(1 - alpha_bar_prev) * (
                    model_output + h * (model_output - prev_model_output) / 2
                )
            )

        return prev_sample
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchaudio>=2.0.0
3
+ transformers>=4.30.0
4
+ gradio>=4.0.0
5
+ matplotlib>=3.5.0
6
+ numpy>=1.21.0
7
+ tqdm>=4.64.0
8
+ Pillow>=9.0.0
9
+ huggingface_hub>=0.16.0