Update app.py
Browse files
app.py
CHANGED
|
@@ -54,16 +54,21 @@ else:
|
|
| 54 |
print('noise prediction')
|
| 55 |
scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
noise = torch.
|
| 61 |
-
timesteps = torch.randint(
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
|
|
|
| 67 |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
| 68 |
"""
|
| 69 |
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
|
@@ -112,6 +117,7 @@ def sample_diffusion(mixture, timbre, ddim_steps=50, eta=0, seed=2023, guidance_
|
|
| 112 |
|
| 113 |
@spaces.GPU
|
| 114 |
def tse(gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale):
|
|
|
|
| 115 |
with torch.no_grad():
|
| 116 |
# mixture, _ = librosa.load(gt_file_input, sr=sample_rate)
|
| 117 |
mixture, sr = torchaudio.load(gt_file_input)
|
|
|
|
| 54 |
print('noise prediction')
|
| 55 |
scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
|
| 56 |
|
| 57 |
+
@spaces.GPU
|
| 58 |
+
def reset_scheduler_dtype():
|
| 59 |
+
latents = torch.randn((1, 128, 128), device="cuda")
|
| 60 |
+
noise = torch.randn_like(latents)
|
| 61 |
+
timesteps = torch.randint(
|
| 62 |
+
0,
|
| 63 |
+
scheduler.config.num_train_timesteps,
|
| 64 |
+
(latents.shape[0],),
|
| 65 |
+
device=latents.device
|
| 66 |
+
)
|
| 67 |
+
_ = scheduler.add_noise(latents, noise, timesteps)
|
| 68 |
+
return "Scheduler dtype reset completed."
|
| 69 |
|
| 70 |
|
| 71 |
+
@spaces.GPU
|
| 72 |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
| 73 |
"""
|
| 74 |
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
|
|
|
| 117 |
|
| 118 |
@spaces.GPU
|
| 119 |
def tse(gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale):
|
| 120 |
+
reset_scheduler_dtype()
|
| 121 |
with torch.no_grad():
|
| 122 |
# mixture, _ = librosa.load(gt_file_input, sr=sample_rate)
|
| 123 |
mixture, sr = torchaudio.load(gt_file_input)
|