""" This file contains some tools """ import torch import torch.nn as nn import torch.nn.functional as F import os import numpy as np import einops from tqdm import tqdm from torchvision.utils import save_image, make_grid from torchvision.transforms import ToTensor from absl import logging from PIL import Image def set_logger(log_level='info', fname=None): import logging as _logging handler = logging.get_absl_handler() formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s') handler.setFormatter(formatter) logging.set_verbosity(log_level) if fname is not None: handler = _logging.FileHandler(fname) handler.setFormatter(formatter) logging.get_absl_logger().addHandler(handler) def dct2str(dct): return str({k: f'{v:.6g}' for k, v in dct.items()}) def get_nnet(name, **kwargs): if name == 'dimr': from libs.model.dimr import MRModel return MRModel(kwargs["model_args"]) else: raise NotImplementedError(name) def get_optimizer(params, name, adamw_impl=None, **kwargs): if name == 'adam': from torch.optim import Adam return Adam(params, **kwargs) elif name == 'adamw': impl = (adamw_impl or 'bitsandbytes').lower() if impl in ('torch', 'adamw'): from torch.optim import AdamW return AdamW(params, **kwargs) elif impl in ('bitsandbytes', 'adamw8bit'): from bitsandbytes.optim import AdamW8bit return AdamW8bit(params, **kwargs) else: raise ValueError(f'Unsupported AdamW implementation: {impl}') elif name == 'adafactor': from torch.optim import Adafactor return Adafactor(params, **kwargs) else: raise NotImplementedError(name) def customized_lr_scheduler(optimizer, warmup_steps=-1): from torch.optim.lr_scheduler import LambdaLR def fn(step): if warmup_steps > 0: return min(step / warmup_steps, 1) else: return 1 return LambdaLR(optimizer, fn) def get_lr_scheduler(optimizer, name, **kwargs): if name == 'customized': return customized_lr_scheduler(optimizer, **kwargs) elif name == 'cosine': from torch.optim.lr_scheduler import CosineAnnealingLR return 
CosineAnnealingLR(optimizer, **kwargs) else: raise NotImplementedError(name) def ema(model_dest: nn.Module, model_src: nn.Module, rate): param_dict_src = dict(model_src.named_parameters()) for p_name, p_dest in model_dest.named_parameters(): p_src = param_dict_src[p_name] assert p_src is not p_dest p_dest.data.mul_(rate).add_((1 - rate) * p_src.data) class TrainState(object): def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None): self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.step = step self.nnet = nnet self.nnet_ema = nnet_ema def ema_update(self, rate=0.9999): if self.nnet_ema is not None: ema(self.nnet_ema, self.nnet, rate) def save(self, path): import shutil import time import logging max_retries = 2 retry_delay = 60 # s for attempt in range(max_retries): temp_path = path + f'.tmp_{int(time.time())}' backup_path = path + '.backup' try: if os.path.exists(path): try: if not os.path.exists(os.path.join(path, 'step.pth')): logging.warning(f'Incomplete checkpoint detected at {path}, removing...') shutil.rmtree(path) except Exception as e: logging.warning(f'Error checking checkpoint integrity: {e}') if os.path.exists(temp_path): shutil.rmtree(temp_path) if os.path.exists(backup_path): shutil.rmtree(backup_path) os.makedirs(temp_path, exist_ok=True) torch.save(self.step, os.path.join(temp_path, 'step.pth')) for key, val in self.__dict__.items(): if key != 'step' and val is not None: torch.save(val.state_dict(), os.path.join(temp_path, f'{key}.pth')) if os.path.exists(path): shutil.move(path, backup_path) try: shutil.move(temp_path, path) shutil.rmtree(backup_path) except Exception as e: if os.path.exists(backup_path): shutil.move(backup_path, path) raise else: shutil.move(temp_path, path) logging.info(f'Successfully saved checkpoint to {path}') return except Exception as e: logging.warning(f'Save attempt {attempt + 1}/{max_retries} failed: {e}') for tmp in [temp_path, backup_path]: if os.path.exists(tmp): try: 
shutil.rmtree(tmp) except: pass if attempt < max_retries - 1: logging.info(f'Retrying in {retry_delay} seconds...') time.sleep(retry_delay) else: logging.error(f'Failed to save checkpoint after {max_retries} attempts') raise def load(self, path): logging.info(f'load from {path}') self.step = torch.load(os.path.join(path, 'step.pth')) for key, val in self.__dict__.items(): if key != 'step' and val is not None: val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu')) def resume(self, ckpt_root, step=None): if not os.path.exists(ckpt_root): return if step is None: ckpts = list(filter(lambda x: '.ckpt' in x, os.listdir(ckpt_root))) if not ckpts: return steps = map(lambda x: int(x.split(".")[0]), ckpts) step = max(steps) ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt') logging.info(f'resume from {ckpt_path}') self.load(ckpt_path) def to(self, device): for key, val in self.__dict__.items(): if isinstance(val, nn.Module): val.to(device) def trainable_parameters(nnet): params_decay = [] params_nodecay = [] for name, param in nnet.named_parameters(): if name.endswith(".nodecay_weight") or name.endswith(".nodecay_bias"): params_nodecay.append(param) else: params_decay.append(param) print("params_decay", len(params_decay)) print("params_nodecay", len(params_nodecay)) params = [ {'params': params_decay}, {'params': params_nodecay, 'weight_decay': 0.0} ] return params def initialize_train_state(config, device): nnet = get_nnet(**config.nnet) if hasattr(config, 'pretrained_path') and config.pretrained_path: try: print(f"Loading pretrained weights from {config.pretrained_path}...") pretrained_dict = torch.load(config.pretrained_path, map_location='cpu') model_dict = nnet.state_dict() matched_dict = {} size_mismatch_keys = [] missing_keys = [] for k, v in pretrained_dict.items(): if k in model_dict: if v.shape == model_dict[k].shape: matched_dict[k] = v else: size_mismatch_keys.append(k) print(f" ⚠ Size mismatch: {k}") print(f" pretrained: 
{v.shape}, current model: {model_dict[k].shape}") for k in model_dict.keys(): if k not in pretrained_dict: missing_keys.append(k) nnet.load_state_dict(matched_dict, strict=False) print(f"\n{'='*60}") print(f"Pretrained weight loading report:") print(f"{'='*60}") print(f"✓ Successfully loaded parameters: {len(matched_dict)}") if size_mismatch_keys: print(f"\n⚠ Size mismatch ({len(size_mismatch_keys)} keys) - skipped, using random init:") for key in size_mismatch_keys[:10]: print(f" • {key}") if len(size_mismatch_keys) > 10: print(f" ... and {len(size_mismatch_keys)-10} more") if missing_keys: print(f"\n⚠ Missing keys ({len(missing_keys)}):") adapter_keys = [k for k in missing_keys if "adapter" in k] other_missing = [k for k in missing_keys if "adapter" not in k] if adapter_keys: print(f" - Adapter-related keys ({len(adapter_keys)}): random init") for key in adapter_keys[:5]: print(f" • {key}") if len(adapter_keys) > 5: print(f" ... and {len(adapter_keys)-5} more") if other_missing: print(f" - Other missing keys ({len(other_missing)}): default init") for key in other_missing[:5]: print(f" • {key}") if len(other_missing) > 5: print(f" ... 
and {len(other_missing)-5} more") print(f"{'='*60}\n") if hasattr(nnet, 'adapter'): nn.init.xavier_uniform_(nnet.adapter[0].weight) if hasattr(nnet.adapter[0], 'bias') and nnet.adapter[0].bias is not None: nn.init.zeros_(nnet.adapter[0].bias) print("✓ Adapter layer initialized (Xavier uniform)") except FileNotFoundError: print(f"\n❌ Error: pretrained weights file not found '{config.pretrained_path}'") print("Check the path, or comment out config.pretrained_path to train from scratch") raise except Exception as e: print(f"\n❌ Error loading pretrained weights: {str(e)}") import traceback traceback.print_exc() raise else: print("⚠ No pretrained path set; training from scratch (random init)") nnet_ema = get_nnet(**config.nnet) nnet_ema.eval() optimizer = get_optimizer(trainable_parameters(nnet), **config.optimizer) lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler) train_state = TrainState(optimizer=optimizer, lr_scheduler=lr_scheduler, step=0, nnet=nnet, nnet_ema=nnet_ema) train_state.ema_update(0) train_state.to(device) return train_state def amortize(n_samples, batch_size): k = n_samples // batch_size r = n_samples % batch_size return k * [batch_size] if r == 0 else k * [batch_size] + [r] def sample2dir_with_gt(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, config=None): """ Save generated images, inputs, and ground-truth side by side (order: input, generated, GT). 
Args: accelerator: accelerate.Accelerator instance path: output directory n_samples: total number of samples mini_batch_size: per-process batch size sample_fn: sampling function returning (generated_samples, gt_images, input_images) unpreprocess_fn: inverse preprocessing function config: config object """ os.makedirs(path, exist_ok=True) idx = 0 batch_size = mini_batch_size * accelerator.num_processes for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir_with_gt'): samples, gt_images, input_images = sample_fn(mini_batch_size, config=config) samples = unpreprocess_fn(samples) gt_images = unpreprocess_fn(gt_images) input_images = unpreprocess_fn(input_images) samples = accelerator.gather(samples.contiguous())[:_batch_size] gt_images = accelerator.gather(gt_images.contiguous())[:_batch_size] input_images = accelerator.gather(input_images.contiguous())[:_batch_size] if accelerator.is_main_process: target_size = 256 for input_img, sample, gt in zip(input_images, samples, gt_images): if input_img.shape[1] != target_size or input_img.shape[2] != target_size: input_img = input_img.unsqueeze(0) input_img = F.interpolate(input_img, size=(target_size, target_size), mode='bilinear', align_corners=False) input_img = input_img.squeeze(0) if sample.shape[1] != target_size or sample.shape[2] != target_size: sample = sample.unsqueeze(0) sample = F.interpolate(sample, size=(target_size, target_size), mode='bilinear', align_corners=False) sample = sample.squeeze(0) if gt.shape[1] != target_size or gt.shape[2] != target_size: gt = gt.unsqueeze(0) gt = F.interpolate(gt, size=(target_size, target_size), mode='bilinear', align_corners=False) gt = gt.squeeze(0) images_triplet = torch.stack([input_img, sample, gt], dim=0) concatenated = make_grid(images_triplet, nrow=3, padding=2, pad_value=1.0) save_image(concatenated, os.path.join(path, f"{idx}.png")) idx += 1 # Global cache to avoid repeated tokenizer.encode calls 
import threading

_tokenizer_cache = {}
# Guards _tokenizer_cache. Created eagerly: a lazy "create on first use" lock can let two
# threads each install their own lock on the very first call.
_tokenizer_cache_lock = threading.Lock()


def _get_tokenizer_cache_key(vl_chat_processor, question):
    """Build the cache key for a tokenized prompt.

    The key combines the question, the processor's SFT format and system prompt, and the
    identity (``id``) of its tokenizer object.
    """
    return (
        question,
        vl_chat_processor.sft_format,
        vl_chat_processor.system_prompt,
        id(vl_chat_processor.tokenizer),
    )


def _get_or_encode_tokenizer(vl_chat_processor, question, device):
    """Return ``tokenizer.encode`` of the single-turn chat prompt for *question* (memoized).

    ``device`` is accepted for interface compatibility and is not used.
    """
    cache_key = _get_tokenizer_cache_key(vl_chat_processor, question)
    with _tokenizer_cache_lock:
        if cache_key in _tokenizer_cache:
            return _tokenizer_cache[cache_key]

    # Tokenize outside the lock so other threads are not serialized behind slow encodes.
    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=[
            {"role": "<|User|>", "content": f"\n{question}"},
            {"role": "<|Assistant|>", "content": ""},
        ],
        sft_format=vl_chat_processor.sft_format,
        system_prompt=vl_chat_processor.system_prompt,
    )
    input_ids = vl_chat_processor.tokenizer.encode(sft_format)

    with _tokenizer_cache_lock:
        # If another thread filled the slot meanwhile, keep (and return) its value.
        return _tokenizer_cache.setdefault(cache_key, input_ids)


def _load_pil_image(image_input, accelerator=None):
    """Open the image at path *image_input* as RGB.

    On failure, or for a non-``str`` input, print a warning and return a black
    384x384 RGB placeholder.
    """
    rank_info = f"[Rank {accelerator.process_index}] " if accelerator is not None else ""
    if not isinstance(image_input, str):
        print(f"{rank_info}Warning: unsupported type {type(image_input)}")
        return Image.new('RGB', (384, 384), color='black')
    try:
        # convert() returns a new, fully loaded image, so the file can be closed right away.
        with Image.open(image_input) as pil_img:
            return pil_img.convert('RGB')
    except Exception as e:
        print(f"{rank_info}Warning: failed to load input image {image_input}: {e}")
        return Image.new('RGB', (384, 384), color='black')


def _load_pil_images(image_inputs, accelerator=None):
    """Load every entry of *image_inputs* with a thread pool; result order matches the input."""
    if len(image_inputs) == 0:
        return []
    max_workers = min(len(image_inputs), os.cpu_count() or 1)
    if max_workers <= 1:
        return [_load_pil_image(item, accelerator) for item in image_inputs]
    import concurrent.futures
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(lambda item: _load_pil_image(item, accelerator), image_inputs))


def _select_image_token_embeddings(inputs_embeds, images_seq_mask, num_image_tokens):
    """Gather the image-token rows of *inputs_embeds* into a fixed-length tensor.

    Args:
        inputs_embeds: [B, S, D] sequence embeddings.
        images_seq_mask: [B, S] bool, True where the token is an image token.
        num_image_tokens: output length N.

    Returns:
        (embeddings [B, N, D], mask [B, N] long). Rows with more than N image tokens are
        truncated; rows with fewer are zero-padded, and the mask is 0 on that padding.
    """
    batch_size, seq_len, hidden_dim = inputs_embeds.shape
    counts = images_seq_mask.sum(dim=1)  # [B]

    if seq_len == num_image_tokens and bool(images_seq_mask.all()):
        # The whole sequence is image tokens: nothing to select.
        embeddings = inputs_embeds
    elif bool((counts == num_image_tokens).all()):
        # Boolean indexing flattens row-major, so equal per-row counts reshape cleanly.
        embeddings = inputs_embeds[images_seq_mask].reshape(batch_size, num_image_tokens, hidden_dim)
    else:
        rows = []
        for i in range(batch_size):
            row = inputs_embeds[i][images_seq_mask[i]][:num_image_tokens]
            if row.shape[0] < num_image_tokens:
                padding = row.new_zeros((num_image_tokens - row.shape[0], hidden_dim))
                row = torch.cat([row, padding], dim=0)
            rows.append(row)
        embeddings = torch.stack(rows, dim=0)

    positions = torch.arange(num_image_tokens, device=counts.device)
    masks = (positions.unsqueeze(0) < counts.unsqueeze(1)).long()
    return embeddings, masks


def get_input_image_embeddings_and_masks(
    batch_input_images,
    vl_chat_processor,
    vl_gpt,
    device,
    question="",
    num_image_tokens=576,
    output_tokens=None,
    accelerator=None,
    cached_input_ids=None,
):
    """
    Batch-process input images and obtain their image-token embeddings and masks.

    Args:
        batch_input_images: One of:
            - torch.Tensor: preprocessed tensor [batch_size, 3, H, W] (WebDataset mode)
            - list of str: image paths (filesystem mode, backward compatible)
        vl_chat_processor: Janus VLChatProcessor instance
        vl_gpt: Janus MultiModalityCausalLM. If wrapped by accelerator.prepare(), unwrap first:
            vl_gpt = accelerator.unwrap_model(vl_gpt) if hasattr(accelerator, 'unwrap_model') else vl_gpt
        device: torch.device
        question: optional text prompt (default empty)
        num_image_tokens: tokens per image (default 576)
        output_tokens: if set, keep first N tokens (default None = all)
        accelerator: optional, for rank in error logs
        cached_input_ids: optional pre-encoded input_ids (list of ints or 1-D tensor);
            skips tokenizer.encode if set

    Returns:
        batch_embeddings: float [batch_size, output_tokens or num_image_tokens, hidden_dim] on
            device. Samples with fewer image tokens than num_image_tokens are zero-padded.
        batch_attention_masks: long [batch_size, output_tokens or num_image_tokens] on device;
            1 for real image tokens, 0 for zero-padding.
    """
    if isinstance(batch_input_images, torch.Tensor):
        batched_pixel_values = batch_input_images.to(device, non_blocking=True)
        batch_size = batched_pixel_values.shape[0]
    else:
        all_pil_images = _load_pil_images(batch_input_images, accelerator)
        images_outputs = vl_chat_processor.image_processor(all_pil_images, return_tensors="pt")
        batched_pixel_values = images_outputs.pixel_values.to(device, non_blocking=True)  # [B, 3, H, W]
        batch_size = len(all_pil_images)

    if cached_input_ids is not None:
        input_ids = cached_input_ids
    else:
        input_ids = _get_or_encode_tokenizer(vl_chat_processor, question, device)
    # The same prompt is used for every sample; repeat() gives each row its own storage,
    # so neither the cache nor the caller's tensor can be mutated downstream.
    batched_input_ids = (
        torch.as_tensor(input_ids, dtype=torch.long, device=device).reshape(1, -1).repeat(batch_size, 1)
    )

    batched_images_seq_mask = batched_input_ids == vl_chat_processor.image_id
    batched_images_emb_mask = torch.ones((batch_size, 1, num_image_tokens), dtype=torch.bool, device=device)

    with torch.no_grad():
        inputs_embeds = vl_gpt.prepare_inputs_embeds(
            input_ids=batched_input_ids,
            pixel_values=batched_pixel_values.unsqueeze(1),  # one image per sample
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
        )
    inputs_embeds = inputs_embeds.detach().float()

    batch_embeddings, batch_attention_masks = _select_image_token_embeddings(
        inputs_embeds, batched_images_seq_mask, num_image_tokens
    )

    if output_tokens is not None and output_tokens < num_image_tokens:
        batch_embeddings = batch_embeddings[:, :output_tokens, :]
        batch_attention_masks = batch_attention_masks[:, :output_tokens]
    return batch_embeddings, batch_attention_masks


# Visualization helpers
def resize_tensor_image(img, target_size, device=None):
    """Bilinearly resize a [C, H, W] tensor to target_size x target_size.

    If *device* is given, the tensor is moved there first. The input is returned unchanged
    (same object) when it is already the target size.
    """
    if device is not None and img.device != device:
        img = img.to(device)
    if img.shape[1] != target_size or img.shape[2] != target_size:
        img = F.interpolate(
            img.unsqueeze(0),
            size=(target_size, target_size),
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
    return img


# Sample-type tags that mark text-to-image samples (bytes from WebDataset, or str).
_T2I_TAGS = (b"t2i", "t2i")


def build_cross_atten_mask_from_batch_type(batch_type, batch_size, device):
    """Build a cross-attention bool mask from a list of sample-type tags.

    Elements equal to ``b"t2i"`` (or ``"t2i"``) map to ``True`` (skip cross-attention);
    ``None`` and every other value map to ``False``. If *batch_type* is ``None``, returns an
    all-False tensor of length *batch_size*.
    """
    if batch_type is None:
        return torch.zeros(batch_size, dtype=torch.bool, device=device)
    return torch.tensor(
        [t is not None and t in _T2I_TAGS for t in batch_type],
        dtype=torch.bool,
        device=device,
    )


def build_cross_atten_mask_from_paths(image_paths, device):
    """Return a bool tensor indicating which images are t2i (skip cross-attention).

    Images whose filename starts with ``t2i_`` map to ``True``.
    """
    mask = [os.path.basename(p).startswith("t2i_") for p in image_paths]
    return torch.tensor(mask, dtype=torch.bool, device=device)


def load_images_as_latents(image_paths, resolution, autoencoder, device):
    """Load images from *image_paths*, center-crop, and encode through autoencoder.

    Pixels are scaled to [-1, 1] before encoding. Returns the sampled latent tensor of shape
    ``[N, 4, H, W]``.
    """
    from data import center_crop_arr

    tensors = []
    for path in image_paths:
        with Image.open(path) as img:  # close the file handle once decoded
            pil = img.convert("RGB")
        arr = center_crop_arr(pil, image_size=resolution)
        arr = (arr / 127.5 - 1.0).astype(np.float32)
        tensors.append(torch.from_numpy(einops.rearrange(arr, 'h w c -> c h w')))
    stacked = torch.stack(tensors, dim=0).to(device)  # one host->device transfer
    # The moments are detached anyway, so skip recording the encoder's autograd graph.
    with torch.no_grad():
        moments = autoencoder(stacked, fn='encode_moments').detach()
    return autoencoder.sample(moments)


def _vis_tile(pil_images, index, target_size, device):
    """Return ``pil_images[index]`` as a resized RGB tensor, or a black tile if it is absent."""
    if index < len(pil_images):
        # convert('RGB') keeps grayscale/RGBA inputs stackable with 3-channel tiles
        return resize_tensor_image(ToTensor()(pil_images[index].convert('RGB')), target_size, device)
    return torch.zeros((3, target_size, target_size), device=device)


def save_vis_grid_and_log(
    samples_unpreprocessed,
    input_images_pil,
    gt_images_pil,
    sample_dir,
    step,
    wandb_module,
    device,
    samples_per_group=10,
    target_size=256,
):
    """Build [input | gt | generated] image grids, save to disk, and log to wandb.

    Args:
        samples_unpreprocessed: list/tensor of generated images [C, H, W] in [0, 1].
        input_images_pil: list of PIL input images (missing entries become black tiles).
        gt_images_pil: list of PIL ground-truth images (may be empty; black tiles then).
        sample_dir: directory to save grid images (created if missing).
        step: current training step (used in filenames and wandb step).
        wandb_module: the wandb module (passed to avoid importing it here).
        device: torch device for tensor operations.
        samples_per_group: number of rows per saved grid file.
        target_size: spatial size each image is resized to before gridding.
    """
    if sample_dir:
        os.makedirs(sample_dir, exist_ok=True)
    total = len(samples_unpreprocessed)
    wandb_images = []
    for group_id, start in enumerate(range(0, total, samples_per_group)):
        end = min(start + samples_per_group, total)
        group_imgs = []
        for i in range(start, end):
            group_imgs.append(_vis_tile(input_images_pil, i, target_size, device))
            group_imgs.append(_vis_tile(gt_images_pil, i, target_size, device))
            group_imgs.append(resize_tensor_image(samples_unpreprocessed[i], target_size, device))
        grid = make_grid(torch.stack(group_imgs, dim=0), nrow=3, padding=2, pad_value=1.0)
        save_image(grid, os.path.join(sample_dir, f'{step}_{group_id + 1}.png'))
        wandb_images.append(wandb_module.Image(grid))
    if wandb_images:
        wandb_module.log({'samples': wandb_images}, step=step)


def clean_stale_ckpt_files(ckpt_root, max_age_seconds=3600):
    """Remove leftover ``*.tmp_*`` / ``*.backup`` directories in *ckpt_root*.

    Only directories whose mtime is older than *max_age_seconds* (default one hour) are
    removed. Errors are logged as warnings and never raised.
    """
    import shutil
    import time as _time

    try:
        for item in os.listdir(ckpt_root):
            if '.tmp_' not in item and not item.endswith('.backup'):
                continue
            tmp_path = os.path.join(ckpt_root, item)
            try:
                if os.path.isdir(tmp_path) and _time.time() - os.path.getmtime(tmp_path) > max_age_seconds:
                    logging.info(f'Removing stale temporary directory: {tmp_path}')
                    shutil.rmtree(tmp_path)
            except Exception as e:
                logging.warning(f'Error cleaning temporary file {tmp_path}: {e}')
    except Exception as e:
        logging.warning(f'Error scanning checkpoint directory: {e}')