Instructions to use ViTeX-Bench/ViTeX-Edit-14B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use ViTeX-Bench/ViTeX-Edit-14B with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```

```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("ViTeX-Bench/ViTeX-Edit-14B", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import os, torch | |
| from tqdm import tqdm | |
| from accelerate import Accelerator | |
| from .training_module import DiffusionTrainingModule | |
| from .logger import ModelLogger | |
def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args = None,
):
    """Train ``model`` on ``dataset`` with AdamW and a constant LR schedule.

    Training stops once ``model_logger.num_steps`` reaches
    ``num_epochs * len(dataloader)``, so a logger whose counter was restored
    from a previous run only trains the remaining steps. A plain-text log is
    appended to ``<model_logger.output_path>/training_log.txt`` by the main
    process, and the progress bar is shown on the local main process only.

    Args:
        accelerator: Accelerate ``Accelerator`` used to prepare the model,
            optimizer, dataloader and scheduler.
        dataset: Map-style dataset. The DataLoader uses its default batch size
            with ``collate_fn=lambda x: x[0]``, so each step receives a single
            raw sample. If the dataset has a truthy ``load_from_cache``
            attribute, samples are passed as ``model({}, inputs=data)``.
        model: Training module exposing ``trainable_modules()`` and ``pipe``.
        model_logger: Step/checkpoint bookkeeping (``on_step_end``,
            ``on_epoch_end``, ``on_training_end``, ``num_steps``,
            ``output_path``).
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_workers: Number of DataLoader worker processes.
        save_steps: Forwarded to ``model_logger``. When ``None``,
            ``model_logger.on_epoch_end`` is called after every epoch instead.
        num_epochs: Number of passes over the data.
        args: Optional namespace. When given, its ``learning_rate``,
            ``weight_decay``, ``dataset_num_workers``, ``save_steps`` and
            ``num_epochs`` override the keyword arguments above.
    """
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)

    # Exclude the VAE from DeepSpeed ZeRO-3 wrapping to avoid compatibility
    # issues: it is taken out of the module tree for `prepare` and re-attached
    # afterwards as a plain attribute so pipeline code can still use pipe.vae.
    vae_module = _detach_vae(model)
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
    _reattach_vae(model, vae_module, accelerator.device)
    initialize_deepspeed_gradient_checkpointing(accelerator)

    steps_per_epoch = len(dataloader)
    # Compared against model_logger.num_steps, which is assumed to advance once
    # per on_step_end call (TODO confirm in ModelLogger).
    total_target = num_epochs * steps_per_epoch
    # Loop-invariant; datasets without the attribute take the plain path.
    load_from_cache = getattr(dataset, "load_from_cache", False)
    show_progress = accelerator.is_local_main_process

    log_file = _open_training_log(
        accelerator, model_logger.output_path, num_epochs, learning_rate, steps_per_epoch
    )
    try:
        for epoch_id in range(num_epochs):
            # Nothing left to train: stop before opening a progress bar or
            # saving a duplicate end-of-epoch checkpoint for an empty epoch.
            if model_logger.num_steps >= total_target:
                break
            with tqdm(
                total=total_target,
                initial=model_logger.num_steps,
                desc=f"Epoch {epoch_id+1}/{num_epochs}",
                disable=not show_progress,
            ) as progress:
                for data in dataloader:
                    # A resumed counter can hit the target mid-epoch.
                    if model_logger.num_steps >= total_target:
                        break
                    with accelerator.accumulate(model):
                        optimizer.zero_grad()
                        if load_from_cache:
                            loss = model({}, inputs=data)
                        else:
                            loss = model(data)
                        accelerator.backward(loss)
                        optimizer.step()
                        model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                        scheduler.step()
                    loss_val = loss.item()
                    progress.update(1)
                    progress.set_postfix(loss=f"{loss_val:.4f}")
                    step = model_logger.num_steps
                    # Log every one of the first few steps, then every 10th.
                    if log_file is not None and (step % 10 == 0 or step <= 5):
                        log_file.write(f"epoch={epoch_id+1} step={step} loss={loss_val:.6f}\n")
                        log_file.flush()
            if save_steps is None:
                model_logger.on_epoch_end(accelerator, model, epoch_id)
                if log_file is not None:
                    log_file.write(f"Epoch {epoch_id+1} completed. Checkpoint saved.\n")
                    log_file.flush()
        model_logger.on_training_end(accelerator, model, save_steps)
    finally:
        # Close the log even if training raises, so the file handle is released.
        if log_file is not None:
            log_file.close()


def _detach_vae(model):
    """Remove ``model.pipe.vae`` from the module tree and return it (or ``None``).

    ``pop`` with a default tolerates a VAE that is already stored as a plain,
    non-module attribute (e.g. from an earlier call on the same model).
    """
    vae = getattr(model.pipe, "vae", None)
    if vae is not None:
        model.pipe._modules.pop("vae", None)
    return vae


def _reattach_vae(model, vae, device):
    """Move a detached VAE to ``device`` and set it back on the model's pipeline.

    ``object.__setattr__`` bypasses ``nn.Module.__setattr__``, which would
    otherwise register the VAE as a submodule again.
    """
    if vae is None:
        return
    vae.to(device)
    # accelerator.prepare may return a wrapper that exposes the original model as `.module`.
    pipe = model.module.pipe if hasattr(model, "module") else model.pipe
    object.__setattr__(pipe, "vae", vae)


def _open_training_log(accelerator, output_path, num_epochs, learning_rate, steps_per_epoch):
    """Open ``training_log.txt`` for appending on the main process and write a header.

    Returns the open file on the main process and ``None`` everywhere else.
    """
    if not accelerator.is_main_process:
        return None
    os.makedirs(output_path, exist_ok=True)
    log_file = open(os.path.join(output_path, "training_log.txt"), "a", encoding="utf-8")
    log_file.write(f"Training started. Epochs: {num_epochs}, LR: {learning_rate}, Steps/epoch: {steps_per_epoch}\n")
    log_file.flush()
    return log_file
def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
    """Run ``model`` over ``dataset`` without gradients and save every output.

    Each process writes its results to
    ``<model_logger.output_path>/<process_index>/<i>.pth``. Here ``i`` is the
    sample's position in that process's dataloader as prepared by
    ``accelerator.prepare``.

    Args:
        accelerator: Accelerate ``Accelerator`` used to prepare model and data.
        dataset: Map-style dataset, iterated in order (``shuffle=False``). The
            collate function unwraps the single sample of each batch.
        model: Module whose forward output is saved with ``torch.save``.
        model_logger: Only ``output_path`` is used.
        num_workers: Number of DataLoader worker processes.
        args: Optional namespace. When given, ``args.dataset_num_workers``
            overrides ``num_workers``.
    """
    if args is not None:
        num_workers = args.dataset_num_workers
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    model, dataloader = accelerator.prepare(model, dataloader)

    # The output folder depends only on the process rank, so create it once
    # instead of once per sample.
    folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
    os.makedirs(folder, exist_ok=True)

    for data_id, data in enumerate(tqdm(dataloader, disable=not accelerator.is_local_main_process)):
        with accelerator.accumulate(model), torch.no_grad():
            processed = model(data)
            torch.save(processed, os.path.join(folder, f"{data_id}.pth"))
def initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator):
    """Configure DeepSpeed activation checkpointing from accelerate's DeepSpeed config.

    Does nothing when ``accelerator.state`` has no (or a ``None``)
    ``deepspeed_plugin``. When the plugin's config has no
    ``activation_checkpointing`` section, a notice is printed by every process
    and nothing is configured. Otherwise the section's
    ``partition_activations``, ``cpu_checkpointing`` and
    ``contiguous_memory_optimization`` flags (each defaulting to ``False``) are
    passed to ``deepspeed.checkpointing.configure``.
    """
    plugin = getattr(accelerator.state, "deepspeed_plugin", None)
    if plugin is None:
        return

    ds_config = plugin.deepspeed_config
    if "activation_checkpointing" not in ds_config:
        print("Do not find activation_checkpointing config in deepspeed config, skip initializing deepspeed gradient checkpointing.")
        return

    # Imported lazily so that non-DeepSpeed runs never need the package.
    import deepspeed

    checkpoint_cfg = ds_config["activation_checkpointing"]
    deepspeed.checkpointing.configure(
        mpu_=None,
        partition_activations=checkpoint_cfg.get("partition_activations", False),
        checkpoint_in_cpu=checkpoint_cfg.get("cpu_checkpointing", False),
        contiguous_checkpointing=checkpoint_cfg.get("contiguous_memory_optimization", False),
    )