| import torch |
| import pywt |
| import numpy as np |
| from tqdm import tqdm |
|
|
|
|
def split_dwt(z_tensor_cpu, wavelet_name, dwt_level):
    """Split every item along dim 0 into low/high-frequency 2-D DWT parts.

    Each slice ``z_tensor_cpu[i]`` is decomposed independently with
    ``pywt.wavedec2`` over its last two axes using ``mode='symmetric'``.

    Args:
        z_tensor_cpu: Tensor whose dim 0 is iterated item by item; presumably
            latents shaped (batch, channels, height, width) -- TODO confirm
            against callers. It is detached, moved to CPU and cast to
            float32 here, so any device/dtype is accepted.
        wavelet_name: Wavelet identifier forwarded to ``pywt.wavedec2``.
        dwt_level: Decomposition level forwarded to ``pywt.wavedec2``.

    Returns:
        A tuple ``(all_clow_tensor, all_chigh_list)``.
        ``all_clow_tensor`` stacks the first entry of each item's
        ``wavedec2`` result (the approximation band) along a new leading
        dim. ``all_chigh_list`` holds, per item, the remaining entries of
        that result (the detail coefficients), still as NumPy data.

    Raises:
        ValueError: Raised by ``np.stack`` when dim 0 is empty.
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    z_tensor_cpu = z_tensor_cpu.detach().cpu().float()

    all_clow_np = []
    all_chigh_list = []
    for z_item in z_tensor_cpu:
        coeffs = pywt.wavedec2(
            z_item.numpy(),
            wavelet_name,
            level=dwt_level,
            mode='symmetric',
            axes=(-2, -1),
        )
        # coeffs[0] is the low-frequency band; everything after it is detail.
        all_clow_np.append(coeffs[0])
        all_chigh_list.append(coeffs[1:])

    all_clow_tensor = torch.from_numpy(np.stack(all_clow_np, axis=0))
    return all_clow_tensor, all_chigh_list
|
|
|
|
def reconstruct_dwt(c_low_tensor_cpu, c_high_coeffs, wavelet_name, original_shape):
    """Rebuild a tensor from low-frequency and high-frequency DWT coefficients.

    Args:
        c_low_tensor_cpu: Tensor of approximation coefficients. It is
            detached, moved to CPU and cast to float32 here. A 4-D tensor
            whose leading dim has size 1 has that dim dropped before the
            inverse transform.
        c_high_coeffs: Sequence (list or tuple) of detail-coefficient
            entries, in the order ``pywt.waverec2`` expects after the
            approximation band, e.g. the ``coeffs[1:]`` part of a
            ``pywt.wavedec2`` result.
        wavelet_name: Wavelet identifier forwarded to ``pywt.waverec2``.
        original_shape: ``(height, width)`` of the signal before the
            decomposition; the output is cropped to it on the last two axes.

    Returns:
        The reconstructed tensor on CPU. A 3-D reconstruction gets a leading
        dim of size 1 added; other ranks are returned as they are.
    """
    target_h, target_w = original_shape

    # detach() first: Tensor.numpy() refuses tensors that require grad.
    clow_np = c_low_tensor_cpu.detach().cpu().float().numpy()
    if clow_np.ndim == 4 and clow_np.shape[0] == 1:
        clow_np = clow_np[0]

    # Unpacking (instead of list + list) also accepts tuple coefficients.
    coeffs_combined = [clow_np, *c_high_coeffs]
    z_recon_np = pywt.waverec2(
        coeffs_combined, wavelet_name, mode='symmetric', axes=(-2, -1)
    )

    # The inverse transform may come back larger than the original signal;
    # trim the trailing rows/columns so the spatial size matches again.
    if z_recon_np.shape[-2] != target_h or z_recon_np.shape[-1] != target_w:
        z_recon_np = z_recon_np[..., :target_h, :target_w]

    z_recon_tensor = torch.from_numpy(z_recon_np)
    if z_recon_tensor.ndim == 3:
        z_recon_tensor = z_recon_tensor.unsqueeze(0)
    return z_recon_tensor
|
|
|
|
def ses_search(
    base_latents,
    objective_reward_fn,
    total_eval_budget=30,
    popsize=10,
    k_elites=5,
    wavelet_name="db1",
    dwt_level=4,
):
    """Search the low-frequency DWT band of a latent for a higher reward.

    The high-frequency coefficients of ``base_latents`` are kept fixed while
    low-frequency coefficients are sampled from a diagonal Gaussian (mean
    initialised to zeros, variance to ones). Each sample is turned back into
    a latent, scored with ``objective_reward_fn``, and the best ``k_elites``
    samples seen so far are used to refit the mean and variance for the next
    generation.

    Args:
        base_latents: Latent tensor to start from. Only the detail
            coefficients of its first item along dim 0 are used, and the
            flat sample vector is reshaped to ``c_low_init.shape[1:]``, so
            dim 0 must have size 1.
        objective_reward_fn: Callable taking a reconstructed latent (on the
            device and in the dtype of ``base_latents``) and returning a
            score; higher is better. Scores must support ``>`` and sorting.
        total_eval_budget: Maximum number of ``objective_reward_fn`` calls.
        popsize: Number of samples drawn per generation.
        k_elites: Number of top-scoring samples retained across generations.
        wavelet_name: Wavelet identifier forwarded to the DWT helpers.
        dwt_level: Decomposition level forwarded to ``split_dwt``.

    Returns:
        The latent reconstructed from the best-scoring low-frequency sample,
        on the device and in the dtype of ``base_latents``. If no sample was
        evaluated or none scored above ``-inf``, the low-frequency band of
        ``base_latents`` is used instead.
    """
    latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1]
    c_low_init, c_high_fixed_batch = split_dwt(base_latents, wavelet_name, dwt_level)
    c_high_fixed = c_high_fixed_batch[0]
    c_low_shape = c_low_init.shape[1:]

    # Diagonal Gaussian over the flattened low-frequency coefficients.
    mu = torch.zeros_like(c_low_init.view(-1).cpu())
    sigma_sq = torch.ones_like(mu)

    best_overall = {"score": -float('inf'), "c_low": c_low_init[0]}
    eval_count = 0
    elite_db = []
    # Upper bound on generations; the budget checks below end the loop early.
    n_generations = (total_eval_budget // popsize) + 5

    # Context manager so the bar is closed even if scoring raises.
    with tqdm(total=total_eval_budget, desc="[SES] Searching", unit="img") as pbar:
        for _ in range(n_generations):
            if eval_count >= total_eval_budget:
                break

            std = torch.sqrt(torch.clamp(sigma_sq, min=1e-9))
            z_noise = torch.randn(popsize, mu.shape[0])
            samples = (mu + z_noise * std).view(popsize, *c_low_shape)

            # A full population is always drawn; only as many samples as the
            # remaining budget allows are actually evaluated.
            remaining = total_eval_budget - eval_count
            batch_results = []
            for sample in samples[:remaining]:
                c_low_sample = sample.unsqueeze(0)
                z_recon = reconstruct_dwt(
                    c_low_sample, c_high_fixed, wavelet_name, (latent_h, latent_w)
                )
                z_recon = z_recon.to(base_latents.device, dtype=base_latents.dtype)

                score = objective_reward_fn(z_recon)
                res = {"score": score, "c_low": c_low_sample.cpu()}
                batch_results.append(res)
                if score > best_overall['score']:
                    best_overall = res

                eval_count += 1
                pbar.update(1)

            if not batch_results:
                break

            # Elites persist across generations: merge, rank, truncate.
            elite_db.extend(batch_results)
            elite_db.sort(key=lambda x: x['score'], reverse=True)
            elite_db = elite_db[:k_elites]

            elites_flat = torch.stack([x['c_low'].view(-1) for x in elite_db])
            mu = torch.mean(elites_flat, dim=0)
            # The unbiased variance needs at least two elites; with a single
            # elite the previous variance is carried over unchanged.
            if len(elite_db) > 1:
                sigma_sq = torch.var(elites_flat, dim=0, unbiased=True) + 1e-7

    final_latents = reconstruct_dwt(
        best_overall['c_low'], c_high_fixed, wavelet_name, (latent_h, latent_w)
    )
    return final_latents.to(base_latents.device, dtype=base_latents.dtype)
|
|