Spaces:
Sleeping
Sleeping
| import torch | |
def sum_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_sum: torch.Tensor,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
    segments: tuple = None,
):
    """Gaussian log-likelihood of the (optionally segmented) signal sum.

    Channel 0 of ``x`` is rescaled with ``v / 2 + 0.5`` (maps [-1, 1] to
    [0, 1]) and summed along dim 1, either over the whole sequence or only
    over the given ``segments``. The sum is scored against ``target_sum``
    with a Gaussian log-density of standard deviation ``sigma``.

    Args:
        x: Tensor indexed as ``x[:, :, 0]``, i.e. assumed to be laid out as
            [B, L, C] (no [B, C, L] detection here) -- TODO confirm with callers.
        t: Timestep tensor; not used, kept for interface compatibility.
        target_sum: Desired sum; must broadcast against a [B] tensor.
        sigma: Standard deviation of the Gaussian; 0 is treated as 1.
        gradient_scale: Not used, kept for interface compatibility.
        segments: Optional iterable of ``(start_idx, end_idx)`` half-open
            ranges along dim 1. When truthy, only those ranges are summed.

    Returns:
        Scalar tensor: the Gaussian log-density averaged over the batch.
    """
    # Rescale channel 0 from [-1, 1] to [0, 1]; shape [B, L].
    signal = x[:, :, 0] / 2 + 0.5
    if segments:
        # Slice along L *before* reducing: each segment contributes its own
        # [B] partial sum, and the partial sums are added together.
        current_sum = sum(
            signal[:, start_idx:end_idx].sum(dim=1)
            for start_idx, end_idx in segments
        )
    else:
        current_sum = signal.sum(dim=1)

    # sigma == 0 would make the density degenerate; fall back to unit std.
    std = sigma if sigma != 0 else 1.0
    pred_std = torch.ones_like(current_sum) * std

    log_prob = -0.5 * torch.log(2 * torch.pi * pred_std**2) - (
        target_sum - current_sum
    ) ** 2 / (2 * pred_std**2)
    return log_prob.mean()
def peak_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    peak_points: list,
    window_size: int = 5,
    alpha_1: float = 1.2,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,  # added gradient-scaling parameter (currently unused)
):
    """Quadratic log-penalty encouraging each point to stand out from its neighbors.

    Channel 0 of ``x`` is rescaled with ``v / 2 + 0.5``. For every index in
    ``peak_points``, the mean of the *other* samples in a window of
    ``window_size // 2`` on each side is computed. The batch-averaged
    difference ``alpha_1 * local_mean - peak`` is then penalized with
    ``-diff**2 / (2 * sigma**2)``. Penalties are summed over points.

    Args:
        x: Tensor treated as [B, C, L] when ``x.shape[1] < x.shape[2]``,
            otherwise as [B, L, C]; only channel 0 is used.
        t: Timestep tensor; not used, kept for interface compatibility.
        peak_points: Integer indices along L. Negative indices count from
            the end, as in Python indexing.
        window_size: Full window width; ``window_size // 2`` samples are
            taken on each side of the point (clipped at the edges).
        alpha_1: Multiplier applied to the neighbor mean.
        sigma: Penalty scale; 0 is treated as 1.
        gradient_scale: Not used, kept for interface compatibility.

    Returns:
        0-dim tensor with the summed penalty (0 when ``peak_points`` is empty).

    Raises:
        IndexError: if a point lies outside ``[-L, L)``.
        ValueError: if a point's window contains no neighbors
            (e.g. ``window_size <= 1``).
    """
    # Same layout heuristic as the other guidance functions.
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    length = signal.shape[1]
    half_window = window_size // 2
    # Guard against division by zero, matching sum_guidance's sigma handling.
    std = sigma if sigma != 0 else 1.0
    # Start from a tensor so an empty peak_points still returns a tensor.
    log_prob = signal.new_zeros(())

    for x_coord in peak_points:
        if not -length <= x_coord < length:
            raise IndexError(
                f"peak point {x_coord} out of range for length {length}"
            )
        # Normalize negatives so the window is built around the same
        # sample that signal[:, x_coord] reads.
        if x_coord < 0:
            x_coord += length
        start_idx = max(0, x_coord - half_window)
        end_idx = min(length, x_coord + half_window + 1)
        n_neighbors = end_idx - start_idx - 1
        if n_neighbors <= 0:
            raise ValueError(
                f"window around peak point {x_coord} has no neighbors "
                f"(window_size={window_size}, length={length})"
            )
        peak = signal[:, x_coord]
        # Local mean excludes the peak point itself.
        local_mean = (signal[:, start_idx:end_idx].sum(dim=1) - peak) / n_neighbors
        local_diff = (local_mean * alpha_1 - peak).mean()
        log_prob = log_prob - local_diff**2 / (2 * std**2)
    return log_prob
def bar_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    bar_regions: list,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
):
    """Gaussian-kernel score for matching the mean level of index ranges.

    Channel 0 of ``x`` is rescaled with ``v / 2 + 0.5``. For each
    ``(start_idx, end_idx, target_value)`` region, the per-sample mean over
    ``[start_idx, end_idx)`` is compared with ``target_value`` through
    ``exp(-0.5 * (mean - target)**2 / sigma**2)``. The result is
    batch-averaged and summed over regions.

    Note: each term is a kernel value in (0, 1], not a log-probability,
    unlike the log-space penalties used elsewhere in this module.

    Args:
        x: Tensor treated as [B, C, L] when ``x.shape[1] < x.shape[2]``,
            otherwise as [B, L, C]; only channel 0 is used.
        t: Timestep tensor; not used, kept for interface compatibility.
        bar_regions: Iterable of ``(start_idx, end_idx, target_value)``.
        sigma: Kernel width; non-positive values are treated as 1.
        gradient_scale: Not used, kept for interface compatibility.

    Returns:
        0-dim tensor with the summed score (0 when ``bar_regions`` is empty).

    Raises:
        ValueError: if a region selects no samples (its mean would be NaN).
    """
    # Same layout heuristic as the other guidance functions.
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    std = sigma if sigma > 0 else 1.0
    # Tensor accumulator so an empty region list still yields a tensor.
    score = signal.new_zeros(())
    for start_idx, end_idx, target_value in bar_regions:
        region = signal[:, start_idx:end_idx]
        if region.shape[1] == 0:
            raise ValueError(
                f"bar region [{start_idx}, {end_idx}) selects no samples"
            )
        region_mean = region.mean(dim=1)
        score = score + torch.exp(
            -0.5 * (region_mean - target_value) ** 2 / std**2
        ).mean()
    return score
def frequency_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_freq: float,
    freq_weight: float = 1.0,
    gradient_scale: float = 1.0,
):
    """Exponentiated spectral magnitude around ``target_freq``.

    Channel 0 of ``x`` is used as-is (no ``/ 2 + 0.5`` rescaling). Its
    real FFT along the length axis is weighted by a Gaussian window
    centered on ``target_freq``, with width ``0.1 / gradient_scale`` in
    cycles per sample. The function returns
    ``exp(freq_weight * mean(weighted magnitude))``.

    Args:
        x: Tensor treated as [B, C, L] when ``x.shape[1] < x.shape[2]``,
            otherwise as [B, L, C].
        t: Timestep tensor; not used.
        target_freq: Center of the frequency window (sample spacing 1.0).
        freq_weight: Multiplier inside the final exponential.
        gradient_scale: Larger values narrow the frequency window.

    Returns:
        0-dim tensor.
    """
    channels_first = x.shape[1] < x.shape[2]
    seq = x[:, 0] if channels_first else x[:, :, 0]

    spectrum = torch.fft.rfft(seq, dim=1)
    bin_freqs = torch.fft.rfftfreq(seq.shape[1], d=1.0)

    # Narrower Gaussian window (via gradient_scale) emphasizes the target bin.
    bandwidth = 0.1 / gradient_scale
    weights = torch.exp(-((bin_freqs - target_freq) ** 2) / (2 * bandwidth**2))
    weights = weights.to(x.device)

    weighted = spectrum.abs() * weights.unsqueeze(0)
    return torch.exp(freq_weight * weighted.mean())
def get_time_dependent_weights(t, num_timesteps):
    """Return exponentially decaying guidance weights for timesteps ``t``.

    The weight is ``exp(-5 * t / num_timesteps)``. It equals 1 at ``t = 0``
    and falls to ``exp(-5)`` at ``t = num_timesteps``, so smaller timestep
    values receive larger weights.
    """
    normalized_t = t.float().div(num_timesteps)
    return normalized_t.mul(-5).exp()