# File: TSEditor/models/Tiffusion/control copy.py
# Repository page header: "PeterYu's picture"
# Commit message: update
# Revision: 2875fe6
import torch
def sum_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_sum: torch.Tensor,
    sigma: float = 1.0,
    gradient_scale: float = 1.0,
    segments: tuple = None
):
    """Score how closely the summed first channel of ``x`` matches ``target_sum``.

    ``x`` is indexed as ``x[:, :, 0]``, i.e. the first channel is taken from
    the last axis (the alternative channel-first branch is disabled in this
    function). The selected values are mapped with ``v / 2 + 0.5`` before
    summation, and the result is the mean Gaussian log-density of
    ``target_sum`` around that sum.

    Args:
        x: Input tensor; must support ``x[:, :, 0]``.
        t: Unused; kept so all guidance functions share a signature.
        target_sum: Desired sum; must broadcast against a tensor whose
            length is ``x.shape[0]``.
        sigma: Standard deviation of the Gaussian. ``0`` is treated as ``1``.
        gradient_scale: Unused; kept for signature compatibility.
        segments: Optional iterable of ``(start_idx, end_idx)`` half-open
            ranges along axis 1. When given (and non-empty) only the values
            inside these ranges contribute to the sum; otherwise the whole
            axis is summed.

    Returns:
        A 0-d tensor holding the mean log-density over the first axis.
    """
    # Affine map v -> v / 2 + 0.5 (sends [-1, 1] onto [0, 1]).
    signal = x[:, :, 0] / 2 + 0.5

    if segments:
        # Slice the un-reduced signal: the previous code sliced the already
        # summed 1-D tensor with three indices, which raised IndexError.
        current_sum = sum(
            signal[:, start_idx:end_idx].sum(dim=1)
            for start_idx, end_idx in segments
        )
    else:
        current_sum = signal.sum(dim=1)

    # A zero sigma would make the density degenerate; fall back to unit std.
    pred_std = torch.ones_like(current_sum)
    if sigma != 0:
        pred_std = pred_std * sigma

    variance = pred_std ** 2
    log_prob = (
        -0.5 * torch.log(2 * torch.pi * variance)
        - (target_sum - current_sum) ** 2 / (2 * variance)
    )
    return log_prob.mean()
def peak_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    peak_points: list,
    window_size: int = 5,
    alpha_1: float = 1.2,
    sigma: float = 1.0,
    gradient_scale: float = 1.0
):
    """Score how well each requested position stands out from its neighbours.

    The first channel of ``x`` is selected (from axis 1 when
    ``x.shape[1] < x.shape[2]``, otherwise from the last axis) and mapped
    with ``v / 2 + 0.5``. For every index in ``peak_points`` the mean of the
    surrounding window, excluding the point itself, is scaled by ``alpha_1``
    and compared with the value at the point; the squared batch-mean
    difference, divided by ``2 * sigma**2``, is subtracted from the score.

    Args:
        x: 3-D input tensor.
        t: Unused; kept so all guidance functions share a signature.
        peak_points: Indices along the signal axis to evaluate.
        window_size: Width of the neighbourhood; ``window_size // 2`` points
            are taken on each side, clipped to the signal bounds.
        alpha_1: Factor applied to the neighbour mean before comparison.
        sigma: Scale of the quadratic penalty. ``0`` is treated as ``1``,
            matching the fallback used by the other guidance functions.
        gradient_scale: Unused; kept for signature compatibility.

    Returns:
        A 0-d tensor with the accumulated score (``0`` when no point
        contributes).
    """
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    # Avoid dividing by zero when sigma == 0.
    variance = sigma ** 2 if sigma != 0 else 1.0
    half_window = window_size // 2
    length = signal.shape[1]

    # Start from a tensor (not the int 0) so the return value is always a
    # tensor attached to x's graph, even when peak_points is empty.
    log_prob = signal.sum() * 0.0

    for x_coord in peak_points:
        start_idx = max(0, x_coord - half_window)
        end_idx = min(length, x_coord + half_window + 1)

        # Number of window samples other than the peak itself.
        neighbour_count = end_idx - start_idx - 1
        if neighbour_count <= 0:
            # No neighbours: the local mean is undefined (would be 0/0).
            continue

        peak_value = signal[:, x_coord]
        window_total = signal[:, start_idx:end_idx].sum(dim=1)
        local_mean = (window_total - peak_value) / neighbour_count
        local_diff = (local_mean * alpha_1 - peak_value).mean()

        log_prob = log_prob - local_diff ** 2 / (2 * variance)

    return log_prob.mean()
def bar_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    bar_regions: list,
    sigma: float = 1.0,
    gradient_scale: float = 1.0
):
    """Score how closely the mean of each region matches its target value.

    The first channel of ``x`` is selected (from axis 1 when
    ``x.shape[1] < x.shape[2]``, otherwise from the last axis) and mapped
    with ``v / 2 + 0.5``. Each region contributes
    ``exp(-0.5 * (region_mean - target)**2 / sigma**2)`` averaged over the
    first axis, so a perfectly matched region adds ``1`` to the score.

    Args:
        x: 3-D input tensor.
        t: Unused; kept so all guidance functions share a signature.
        bar_regions: Iterable of ``(start_idx, end_idx, target_value)``
            triples; the range is half-open along the signal axis.
        sigma: Width of the Gaussian kernel. Values ``<= 0`` are treated
            as ``1``.
        gradient_scale: Unused; kept for signature compatibility.

    Returns:
        A 0-d tensor with the summed score (``0`` when no region
        contributes).
    """
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]
    signal = signal / 2 + 0.5

    # Loop-invariant: non-positive sigma falls back to unit width.
    variance = sigma ** 2 if sigma > 0 else 1.0

    # Start from a tensor (not the int 0) so callers always get a tensor
    # attached to x's graph, even for an empty bar_regions.
    score = signal.sum() * 0.0

    for start_idx, end_idx, target_value in bar_regions:
        region = signal[:, start_idx:end_idx]
        if region.shape[1] == 0:
            # The mean of an empty slice is NaN; skip it instead.
            continue
        region_mean = region.mean(dim=1)
        # Exponential (un-logged) kernel, as in the original implementation.
        score = score + torch.exp(
            -0.5 * (region_mean - target_value) ** 2 / variance
        ).mean()

    return score
def frequency_guidance(
    x: torch.Tensor,
    t: torch.Tensor,
    target_freq: float,
    freq_weight: float = 1.0,
    gradient_scale: float = 1.0
):
    """Score the spectral magnitude of ``x`` near ``target_freq``.

    The first channel of ``x`` is selected (from axis 1 when
    ``x.shape[1] < x.shape[2]``, otherwise from the last axis). Its real FFT
    magnitude is weighted by a Gaussian window centred on ``target_freq``
    with standard deviation ``0.1 / gradient_scale``; the result is
    ``exp(freq_weight * mean(weighted magnitude))``.

    Args:
        x: 3-D input tensor.
        t: Unused; kept so all guidance functions share a signature.
        target_freq: Centre of the window, in the units returned by
            ``torch.fft.rfftfreq`` with sample spacing ``1.0``.
        freq_weight: Multiplier applied inside the exponential.
        gradient_scale: Larger values narrow the window. Must be non-zero.

    Returns:
        A 0-d tensor with the score.
    """
    if x.shape[1] < x.shape[2]:
        signal = x[:, 0]
    else:
        signal = x[:, :, 0]

    fft_signal = torch.fft.rfft(signal, dim=1)

    # Build the frequency grid directly on x's device so the window does
    # not have to be created on the CPU and copied over on every call.
    freqs = torch.fft.rfftfreq(signal.shape[1], d=1.0, device=x.device)

    window_std = 0.1 / gradient_scale
    freq_window = torch.exp(-((freqs - target_freq) ** 2) / (2 * window_std ** 2))

    magnitude = torch.abs(fft_signal) * freq_window[None, :]
    return torch.exp(freq_weight * magnitude.mean())
def get_time_dependent_weights(t, num_timesteps, decay_rate=5.0):
    """Return a control-signal weight that decays exponentially with ``t``.

    The weight is ``exp(-decay_rate * t / num_timesteps)``, so smaller
    timestep values receive larger weights (``1`` at ``t == 0``).

    Args:
        t: Timestep(s); a tensor or a Python number.
        num_timesteps: Divisor used to normalise ``t``. Must be non-zero.
        decay_rate: Steepness of the decay. Defaults to ``5.0``, the value
            that was previously hard-coded.

    Returns:
        A float tensor with the same shape as ``t``.
    """
    # as_tensor lets plain ints/floats through; tensors pass unchanged.
    progress = torch.as_tensor(t).float() / num_timesteps
    return torch.exp(-decay_rate * progress)