Instructions to use ViTeX-Bench/ViTeX-Edit-14B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use ViTeX-Bench/ViTeX-Edit-14B with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("ViTeX-Bench/ViTeX-Edit-14B", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import pywt | |
| import numpy as np | |
| from tqdm import tqdm | |
def split_dwt(z_tensor_cpu, wavelet_name, dwt_level):
    """Decompose every leading-dim slice of a tensor with a multilevel 2-D DWT.

    Each slice ``z_tensor_cpu[i]`` is transformed with ``pywt.wavedec2``
    (``mode='symmetric'``) over its last two axes; any other axes are carried
    through unchanged.

    Args:
        z_tensor_cpu: Tensor whose first dimension is iterated over (e.g. a
            batch of latents). Any device/dtype is accepted: it is detached
            and copied to CPU float32 before the transform, so tensors that
            require grad (or are bfloat16, which NumPy cannot hold) work.
        wavelet_name: Wavelet name accepted by PyWavelets (e.g. ``"db1"``).
        dwt_level: Number of decomposition levels passed to ``wavedec2``.

    Returns:
        ``(c_low, c_high)``. ``c_low`` is a CPU tensor that stacks the
        approximation coefficients of all slices along dim 0. ``c_high`` is a
        list with one entry per slice, each being ``coeffs[1:]`` exactly as
        returned by ``pywt.wavedec2``.
    """
    # .detach() is required: Tensor.numpy() refuses tensors that require grad.
    z_np = z_tensor_cpu.detach().cpu().float().numpy()
    all_clow_np = []
    all_chigh_list = []
    for z_slice in z_np:
        coeffs = pywt.wavedec2(z_slice, wavelet_name, level=dwt_level, mode='symmetric', axes=(-2, -1))
        all_clow_np.append(coeffs[0])
        all_chigh_list.append(coeffs[1:])
    all_clow_tensor = torch.from_numpy(np.stack(all_clow_np, axis=0))
    return all_clow_tensor, all_chigh_list
def reconstruct_dwt(c_low_tensor_cpu, c_high_coeffs, wavelet_name, original_shape):
    """Invert ``split_dwt`` for a single sample.

    Args:
        c_low_tensor_cpu: Approximation (low-frequency) coefficients for one
            sample. They may be given with or without a leading batch
            dimension of size 1. Any device/dtype; the values are detached and
            copied to CPU float32.
        c_high_coeffs: The detail coefficients of that same sample, i.e. one
            entry of the list returned by ``split_dwt`` (``coeffs[1:]`` of
            ``pywt.wavedec2``). A list or tuple is accepted.
        wavelet_name: Wavelet name for ``pywt.waverec2``; it should match the
            one used for decomposition.
        original_shape: ``(H, W)`` size to crop the last two axes to, since
            the symmetric boundary extension can make the reconstruction
            larger than the original.

    Returns:
        CPU tensor holding the reconstruction. A leading batch dimension of
        size 1 is added when the input carried one, or when the
        reconstruction is 3-D.
    """
    target_h, target_w = original_shape
    clow_np = c_low_tensor_cpu.detach().cpu().float().numpy()

    # Per-sample rank, taken from the detail arrays: (C, h, w) for 4-D
    # latents, (C, T, h, w) for 5-D ones. None if there are no detail bands.
    sample_ndim = next(
        (np.ndim(d) for level in c_high_coeffs for d in level if d is not None),
        None,
    )
    if sample_ndim is None:
        # No detail bands to compare against: keep the historical 4-D rule.
        has_batch = clow_np.ndim == 4 and clow_np.shape[0] == 1
    else:
        # Exactly one extra leading axis of size 1 is a batch dim and must go,
        # otherwise waverec2 rejects the rank mismatch with the detail arrays.
        has_batch = clow_np.ndim == sample_ndim + 1 and clow_np.shape[0] == 1
    if has_batch:
        clow_np = clow_np[0]

    coeffs_combined = [clow_np, *c_high_coeffs]
    z_recon_np = pywt.waverec2(coeffs_combined, wavelet_name, mode='symmetric', axes=(-2, -1))
    if z_recon_np.shape[-2] != target_h or z_recon_np.shape[-1] != target_w:
        z_recon_np = z_recon_np[..., :target_h, :target_w]

    z_recon_tensor = torch.from_numpy(z_recon_np)
    if has_batch or z_recon_tensor.ndim == 3:
        z_recon_tensor = z_recon_tensor.unsqueeze(0)
    return z_recon_tensor
def _fit_elite_gaussian(elite_db, prev_sigma_sq):
    """Refit the diagonal search Gaussian to the current elite set.

    Args:
        elite_db: Non-empty list of result dicts, each holding a sampled
            low-frequency coefficient tensor under ``"c_low"``.
        prev_sigma_sq: Variance to keep when there is only one elite, because
            a single sample has no spread to estimate.

    Returns:
        ``(mu, sigma_sq)``: the per-coefficient mean and unbiased variance
        (plus ``1e-7``) of the flattened elites.
    """
    elites_flat = torch.stack([entry["c_low"].reshape(-1) for entry in elite_db])
    mu = elites_flat.mean(dim=0)
    if len(elite_db) > 1:
        # The additive epsilon keeps the variance strictly positive.
        sigma_sq = torch.var(elites_flat, dim=0, unbiased=True) + 1e-7
    else:
        sigma_sq = prev_sigma_sq
    return mu, sigma_sq


def ses_search(
    base_latents,
    objective_reward_fn,
    total_eval_budget=30,
    popsize=10,
    k_elites=5,
    wavelet_name="db1",
    dwt_level=4,
):
    """Search the low-frequency DWT band of ``base_latents`` to maximize a reward.

    The latents are decomposed with ``split_dwt``. The detail (high-frequency)
    bands stay fixed, and only the approximation band is searched.

    Each generation:
      1. Draws ``popsize`` candidates from a diagonal Gaussian, which starts
         as N(0, I).
      2. Rebuilds full latents from each candidate with ``reconstruct_dwt``.
      3. Scores them with ``objective_reward_fn``.

    The Gaussian is then refit to the ``k_elites`` best candidates seen so far
    across all generations, in the style of the cross-entropy method.

    Args:
        base_latents: Latents whose leading dimension has size 1. The DWT acts
            on the last two axes.
        objective_reward_fn: Callable that receives reconstructed latents (on
            ``base_latents.device``, with ``base_latents.dtype``) and returns
            a score convertible with ``float()``. Higher is better; NaN is
            treated as ``-inf``.
        total_eval_budget: Maximum number of calls to ``objective_reward_fn``.
        popsize: Number of candidates sampled per generation.
        k_elites: Number of best candidates used to refit the distribution.
        wavelet_name: PyWavelets wavelet name.
        dwt_level: Number of DWT decomposition levels.

    Returns:
        Latents rebuilt from the best-scoring candidate, on
        ``base_latents.device`` with ``base_latents.dtype``. If no candidate
        scores above ``-inf`` (e.g. a zero budget), the unmodified
        approximation band of ``base_latents`` is reconstructed instead.

    Raises:
        ValueError: If the leading dimension of ``base_latents`` is not 1.
    """
    import math

    if base_latents.shape[0] != 1:
        raise ValueError(
            "ses_search expects a single sample (leading dimension of size 1), "
            f"got shape {tuple(base_latents.shape)}"
        )

    latent_h, latent_w = base_latents.shape[-2], base_latents.shape[-1]
    c_low_init, c_high_fixed_batch = split_dwt(base_latents, wavelet_name, dwt_level)
    c_high_fixed = c_high_fixed_batch[0]
    c_low_shape = c_low_init.shape[1:]

    # Diagonal Gaussian over the flattened approximation band of ONE sample.
    mu = torch.zeros_like(c_low_init[0].reshape(-1))
    sigma_sq = torch.ones_like(mu)

    # [0:1] keeps the batch dim so the fallback matches the sampled candidates' shape.
    best_overall = {"score": -math.inf, "c_low": c_low_init[0:1]}
    eval_count = 0
    elite_db = []
    n_generations = (total_eval_budget // popsize) + 5

    # Context manager guarantees the bar is closed even if the reward function raises.
    with tqdm(total=total_eval_budget, desc="[SES] Searching", unit="img") as pbar:
        for _ in range(n_generations):
            if eval_count >= total_eval_budget:
                break
            std = torch.sqrt(torch.clamp(sigma_sq, min=1e-9))
            z_noise = torch.randn(popsize, mu.shape[0])
            samples = (mu + z_noise * std).view(popsize, *c_low_shape)

            batch_results = []
            for i in range(popsize):
                if eval_count >= total_eval_budget:
                    break
                c_low_sample = samples[i].unsqueeze(0)
                z_recon = reconstruct_dwt(c_low_sample, c_high_fixed, wavelet_name, (latent_h, latent_w))
                z_recon = z_recon.to(base_latents.device, dtype=base_latents.dtype)
                score = float(objective_reward_fn(z_recon))
                if math.isnan(score):
                    # NaN would make the elite sort order undefined.
                    score = -math.inf
                res = {"score": score, "c_low": c_low_sample.cpu()}
                batch_results.append(res)
                if score > best_overall["score"]:
                    best_overall = res
                eval_count += 1
                pbar.update(1)

            if not batch_results:
                break
            # Elites persist across generations: keep the top-k of everything seen.
            elite_db.extend(batch_results)
            elite_db.sort(key=lambda entry: entry["score"], reverse=True)
            elite_db = elite_db[:k_elites]
            mu, sigma_sq = _fit_elite_gaussian(elite_db, sigma_sq)

    final_latents = reconstruct_dwt(best_overall["c_low"], c_high_fixed, wavelet_name, (latent_h, latent_w))
    return final_latents.to(base_latents.device, dtype=base_latents.dtype)