Spaces:
Sleeping
Sleeping
primepake
committed on
Commit
·
9f4fc9f
1
Parent(s):
ca7dd21
add Immiscible training code
Browse files
speech/cosyvoice/flow/flow_matching.py
CHANGED
|
@@ -32,6 +32,8 @@ class ConditionalCFM(BASECFM):
|
|
| 32 |
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
| 33 |
# Just change the architecture of the estimator here
|
| 34 |
self.estimator = estimator
|
|
|
|
|
|
|
| 35 |
|
| 36 |
@torch.inference_mode()
|
| 37 |
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
|
|
@@ -169,14 +171,73 @@ class ConditionalCFM(BASECFM):
|
|
| 169 |
y: conditional flow
|
| 170 |
shape: (batch_size, n_feats, mel_timesteps)
|
| 171 |
"""
|
| 172 |
-
b,
|
| 173 |
|
| 174 |
# random timestep
|
| 175 |
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 176 |
if self.t_scheduler == 'cosine':
|
| 177 |
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
| 178 |
-
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
| 182 |
u = x1 - (1 - self.sigma_min) * z
|
|
@@ -187,13 +248,7 @@ class ConditionalCFM(BASECFM):
|
|
| 187 |
mu = mu * cfg_mask.view(-1, 1, 1)
|
| 188 |
spks = spks * cfg_mask.view(-1, 1)
|
| 189 |
cond = cond * cfg_mask.view(-1, 1, 1)
|
| 190 |
-
|
| 191 |
-
# print('mask shape: ', mask.shape)
|
| 192 |
-
# print('mu shape: ', mu.shape)
|
| 193 |
-
# print('t shape: ', t.shape)
|
| 194 |
-
# print('spks shape: ', spks.shape)
|
| 195 |
-
# print('cond shape: ', cond.shape)
|
| 196 |
-
# print('streaming: ', streaming)
|
| 197 |
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
|
| 198 |
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
| 199 |
return loss, y
|
|
|
|
| 32 |
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
| 33 |
# Just change the architecture of the estimator here
|
| 34 |
self.estimator = estimator
|
| 35 |
+
self.use_immiscible = cfm_params.use_immiscible
|
| 36 |
+
self.immiscible_k = cfm_params.immiscible_k
|
| 37 |
|
| 38 |
@torch.inference_mode()
|
| 39 |
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
|
|
|
|
| 171 |
y: conditional flow
|
| 172 |
shape: (batch_size, n_feats, mel_timesteps)
|
| 173 |
"""
|
| 174 |
+
b, d, T = mu.shape
|
| 175 |
|
| 176 |
# random timestep
|
| 177 |
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 178 |
if self.t_scheduler == 'cosine':
|
| 179 |
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
print(f"\n=== Immiscible Diffusion Debug ===")
|
| 183 |
+
print(f"x1 shape: {x1.shape}")
|
| 184 |
+
print(f"mu shape: {mu.shape}")
|
| 185 |
+
print(f"t shape: {t.shape}")
|
| 186 |
+
print(f"Device: {x1.device}")
|
| 187 |
+
print(f"Dtype: {x1.dtype}")
|
| 188 |
+
|
| 189 |
+
# Apply immiscible diffusion with KNN
|
| 190 |
+
if self.use_immiscible:
|
| 191 |
+
k = getattr(self, 'immiscible_k', 4)
|
| 192 |
+
|
| 193 |
+
# Generate k noise samples for each data point
|
| 194 |
+
z_candidates = torch.randn(b, k, d, T, device=x1.device, dtype=x1.dtype)
|
| 195 |
+
|
| 196 |
+
print(f"z_candidates shape: {z_candidates.shape}")
|
| 197 |
+
print(f"z_candidates stats - mean: {z_candidates.mean():.4f}, std: {z_candidates.std():.4f}")
|
| 198 |
+
|
| 199 |
+
# Flatten for distance computation
|
| 200 |
+
x1_flat = x1.flatten(start_dim=1).to(torch.float16)
|
| 201 |
+
z_candidates_flat = z_candidates.flatten(start_dim=2).to(torch.float16)
|
| 202 |
+
|
| 203 |
+
print(f"x1_flat shape: {x1_flat.shape}")
|
| 204 |
+
print(f"z_candidates_flat shape: {z_candidates_flat.shape}")
|
| 205 |
+
|
| 206 |
+
# Calculate distances
|
| 207 |
+
distances = torch.norm(x1_flat.unsqueeze(1) - z_candidates_flat, dim=2)
|
| 208 |
+
|
| 209 |
+
print(f"distances shape: {distances.shape}")
|
| 210 |
+
print(f"distances stats - mean: {distances.mean():.4f}, std: {distances.std():.4f}")
|
| 211 |
+
print(f"distances min: {distances.min():.4f}, max: {distances.max():.4f}")
|
| 212 |
+
|
| 213 |
+
# Pick the nearest noise for each data point
|
| 214 |
+
min_distances, min_indices = torch.min(distances, dim=1)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
print(f"min_indices: {min_indices[:10]}") # First 10 indices
|
| 218 |
+
print(f"min_distances stats - mean: {min_distances.mean():.4f}, std: {min_distances.std():.4f}")
|
| 219 |
+
|
| 220 |
+
# Gather the selected noise samples
|
| 221 |
+
z = torch.gather(
|
| 222 |
+
z_candidates,
|
| 223 |
+
1,
|
| 224 |
+
min_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(-1, 1, d, T)
|
| 225 |
+
)[:, 0, :, :]
|
| 226 |
+
|
| 227 |
+
print(f"Selected z shape: {z.shape}")
|
| 228 |
+
print(f"Selected z stats - mean: {z.mean():.4f}, std: {z.std():.4f}")
|
| 229 |
+
|
| 230 |
+
# Calculate distance reduction
|
| 231 |
+
with torch.no_grad():
|
| 232 |
+
orig_distance = distances[:, 0].mean()
|
| 233 |
+
selected_distance = min_distances.mean()
|
| 234 |
+
reduction_rate = (orig_distance - selected_distance) / orig_distance
|
| 235 |
+
print(f"Distance reduction: {reduction_rate:.3%}")
|
| 236 |
+
print(f"Original distance: {orig_distance:.4f}")
|
| 237 |
+
print(f"Selected distance: {selected_distance:.4f}")
|
| 238 |
+
else:
|
| 239 |
+
# sample noise p(x_0)
|
| 240 |
+
z = torch.randn_like(x1)
|
| 241 |
|
| 242 |
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
| 243 |
u = x1 - (1 - self.sigma_min) * z
|
|
|
|
| 248 |
mu = mu * cfg_mask.view(-1, 1, 1)
|
| 249 |
spks = spks * cfg_mask.view(-1, 1)
|
| 250 |
cond = cond * cfg_mask.view(-1, 1, 1)
|
| 251 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
|
| 253 |
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
| 254 |
return loss, y
|