Spaces:
Sleeping
Sleeping
| import torch | |
def sum_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_sum: torch.Tensor,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
    segments: tuple = None
):
    """Return the batch-averaged Gaussian log-density of a target sum.

    Channel 0 of ``x`` is mapped through ``v / 2 + 0.5`` and summed along
    dimension 1 for every batch element.  That sum is scored against
    ``target_sum`` under a normal distribution with standard deviation
    ``sigma``.

    Args:
        x: Tensor indexed as ``x[:, :, 0]``, so it needs at least three
            dimensions (presumably (batch, length, channels) -- confirm
            with the caller).
        t: Unused by this function.
        target_sum: Value(s) the sum is compared against; must broadcast
            against a tensor of shape ``(x.shape[0],)``.
        sigma: Standard deviation of the Gaussian.  ``0`` is treated as
            ``1``.
        gradient_scale: Unused by this function.
        segments: Optional iterable of ``(start, end)`` index pairs along
            dimension 1.  When non-empty, only those half-open ranges
            contribute to the sum; a position covered by several ranges
            is counted once per range.

    Returns:
        A 0-dim tensor holding the log-density averaged over the batch.
    """
    signal = x[:, :, 0] / 2 + 0.5

    if segments:
        # Slice the per-position signal, not the already reduced 1-D sum:
        # indexing the reduced tensor with three indices cannot work.
        current_sum = sum(
            signal[:, start_idx:end_idx].sum(dim=1)
            for start_idx, end_idx in segments
        )
    else:
        current_sum = signal.sum(dim=1)

    # A zero standard deviation would divide by zero below; fall back to 1.
    if sigma == 0:
        pred_std = torch.ones_like(current_sum)
    else:
        pred_std = torch.ones_like(current_sum) * sigma

    variance = pred_std ** 2
    log_prob = (
        -0.5 * torch.log(2 * torch.pi * variance)
        - (target_sum - current_sum) ** 2 / (2 * variance)
    )
    return log_prob.mean()
def peak_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    peak_points: list,
    window_size: int = 5,
    alpha_1: float = 1.2,
    sigma: float = 1.0,
    gradient_scale: float = 1.0  # gradient scaling parameter (currently unused)
):
    """Score how closely each listed position matches ``alpha_1`` x its neighbours.

    Channel 0 of ``x`` is mapped through ``v / 2 + 0.5``.  For every index in
    ``peak_points`` the mean of the surrounding window (excluding the index
    itself) is multiplied by ``alpha_1`` and compared with the value at the
    index; the batch-averaged difference is penalised quadratically.

    Args:
        x: 3-D tensor.  If ``x.shape[1] < x.shape[2]`` channel 0 is read as
            ``x[:, 0]``, otherwise as ``x[:, :, 0]``.
        t: Unused by this function.
        peak_points: Integer positions along the signal axis.
        window_size: Width of the neighbourhood; ``window_size // 2``
            positions are taken on each side, clipped to the signal bounds.
        alpha_1: Factor applied to the neighbour mean before comparison.
        sigma: Scale of the quadratic penalty.  ``0`` is treated as ``1``,
            matching the other guidance functions in this file.
        gradient_scale: Unused by this function.

    Returns:
        A 0-dim tensor: the sum over peaks of
        ``-(mean_diff ** 2) / (2 * sigma ** 2)``.  It is zero when
        ``peak_points`` is empty.
    """
    # Whichever of dims 1 / 2 is smaller is used as the channel axis.
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    std = 1.0 if sigma == 0 else sigma
    half_window = window_size // 2
    length = signal.shape[1]

    # Start from a tensor so an empty peak list still returns a tensor.
    log_prob = signal.new_zeros(())
    for x_coord in peak_points:
        start_idx = max(0, x_coord - half_window)
        end_idx = min(length, x_coord + half_window + 1)
        neighbour_count = end_idx - start_idx - 1
        if neighbour_count <= 0:
            # No neighbours to compare against; dividing by zero here would
            # turn the whole score into NaN.
            continue
        peak_value = signal[:, x_coord]
        window_total = signal[:, start_idx:end_idx].sum(dim=1)
        local_mean = (window_total - peak_value) / neighbour_count
        local_diff = (local_mean * alpha_1 - peak_value).mean()
        log_prob = log_prob - local_diff ** 2 / (2 * std ** 2)
    return log_prob
def bar_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    bar_regions: list,
    sigma: float = 1.0,
    gradient_scale: float = 1.0
):
    """Score how closely the mean of each region matches its target value.

    Channel 0 of ``x`` is mapped through ``v / 2 + 0.5``.  For every
    ``(start, end, target)`` triple the mean over ``[start, end)`` is
    compared with ``target`` through an unnormalised Gaussian kernel
    ``exp(-0.5 * (mean - target)**2 / sigma**2)``, averaged over the batch.
    The per-region scores are added up, so the result lies between 0 and
    the number of scored regions.

    Args:
        x: 3-D tensor.  If ``x.shape[1] < x.shape[2]`` channel 0 is read as
            ``x[:, 0]``, otherwise as ``x[:, :, 0]``.
        t: Unused by this function.
        bar_regions: Iterable of ``(start_idx, end_idx, target_value)``.
        sigma: Kernel width.  Non-positive values are treated as ``1``.
        gradient_scale: Unused by this function.

    Returns:
        A 0-dim tensor with the summed score.  It is zero when
        ``bar_regions`` is empty.
    """
    # Whichever of dims 1 / 2 is smaller is used as the channel axis.
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    std = sigma if sigma > 0 else 1.0

    # Start from a tensor so the return type does not depend on whether any
    # region was supplied.
    total = signal.new_zeros(())
    for start_idx, end_idx, target_value in bar_regions:
        region = signal[:, start_idx:end_idx]
        if region.shape[1] == 0:
            # The mean of an empty slice is NaN and would poison the total.
            continue
        region_mean = region.mean(dim=1)
        score = torch.exp(-0.5 * (region_mean - target_value) ** 2 / std ** 2)
        total = total + score.mean()
    return total
def frequency_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_freq: float,
    freq_weight: float = 1.0,
    gradient_scale: float = 1.0,
    bandwidth: float = 0.1
):
    """Reward spectral magnitude of channel 0 near ``target_freq``.

    The real FFT of channel 0 is taken along dimension 1, its magnitude is
    weighted by a Gaussian window centred on ``target_freq``, and the mean
    of the weighted magnitudes is exponentiated.  Unlike the other guidance
    functions in this file, the signal is not rescaled before use.

    Args:
        x: 3-D tensor.  If ``x.shape[1] < x.shape[2]`` channel 0 is read as
            ``x[:, 0]``, otherwise as ``x[:, :, 0]``.
        t: Unused by this function.
        target_freq: Centre of the window, in cycles per sample (the
            frequency grid is built with sample spacing ``d=1.0``).
        freq_weight: Multiplier applied to the mean weighted magnitude
            before exponentiation.
        gradient_scale: Divides the window width, so larger values give a
            narrower window.  Must be non-zero.
        bandwidth: Window width (Gaussian standard deviation) at
            ``gradient_scale == 1``.

    Returns:
        A 0-dim tensor ``exp(freq_weight * mean(|FFT| * window))``.
    """
    # Whichever of dims 1 / 2 is smaller is used as the channel axis.
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]

    spectrum = torch.fft.rfft(signal, dim=1)

    # Build the frequency grid directly on the signal's device instead of
    # creating it on the CPU and copying it over afterwards.
    freqs = torch.fft.rfftfreq(signal.shape[1], d=1.0, device=signal.device)
    width = bandwidth / gradient_scale
    freq_window = torch.exp(-((freqs - target_freq) ** 2) / (2 * width ** 2))

    magnitude = torch.abs(spectrum) * freq_window[None, :]
    return torch.exp(freq_weight * magnitude.mean())
def get_time_dependent_weights(t, num_timesteps):
    """Return a weight for the control signal that decays with the timestep.

    The weight is ``exp(-5 * t / num_timesteps)``: it equals 1 at ``t == 0``
    and falls to ``exp(-5)`` at ``t == num_timesteps``, so earlier (smaller)
    timesteps receive larger weights.

    Args:
        t: Timestep(s), as a tensor or a plain Python number.
        num_timesteps: Total number of timesteps used to normalise ``t``.

    Returns:
        A float tensor with the same shape as ``t`` (0-dim for a scalar).
    """
    # as_tensor leaves tensor inputs untouched and lets plain numbers through.
    progress = torch.as_tensor(t).float() / num_timesteps
    # Smaller timesteps get larger weights.
    return torch.exp(-5 * progress)