# FLowInOne_demo / utils.py
# (hosting-page residue, not part of the module: "oedevs's picture", "upload file", "0fd26a8")
"""
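Training utilities: logging setup, model/optimizer/scheduler factories, EMA and
checkpoint state, sampling to disk, image-embedding extraction, and visualization helpers.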
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import einops
from tqdm import tqdm
from torchvision.utils import save_image, make_grid
from torchvision.transforms import ToTensor
from absl import logging
from PIL import Image
def set_logger(log_level='info', fname=None):
    """Configure the absl logger's format, verbosity and optional file output.

    Args:
        log_level: verbosity forwarded to absl's ``logging.set_verbosity``.
        fname: when not ``None``, a file handler writing to this path is
            attached to the absl logger with the same format.
    """
    # The module-level ``logging`` is absl's; the stdlib module is needed for
    # the Formatter / FileHandler classes.
    import logging as std_logging

    fmt = std_logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
    logging.get_absl_handler().setFormatter(fmt)
    logging.set_verbosity(log_level)
    if fname is None:
        return
    file_handler = std_logging.FileHandler(fname)
    file_handler.setFormatter(fmt)
    logging.get_absl_logger().addHandler(file_handler)
def dct2str(dct):
    """Return ``str()`` of *dct* with every value formatted to 6 significant digits."""
    formatted = {}
    for key, value in dct.items():
        formatted[key] = format(value, '.6g')
    return str(formatted)
def get_nnet(name, **kwargs):
    """Build the network registered under *name*.

    Only ``'dimr'`` is supported; it is constructed from ``kwargs["model_args"]``.

    Raises:
        NotImplementedError: for any other *name*.
    """
    if name != 'dimr':
        raise NotImplementedError(name)
    # Imported lazily so the model package is only needed when requested.
    from libs.model.dimr import MRModel
    model_args = kwargs["model_args"]
    return MRModel(model_args)
def get_optimizer(params, name, adamw_impl=None, **kwargs):
    """Instantiate an optimizer by name.

    Args:
        params: parameters / parameter groups handed to the optimizer.
        name: ``'adam'``, ``'adamw'`` or ``'adafactor'``.
        adamw_impl: for ``'adamw'`` only; ``'torch'``/``'adamw'`` selects
            ``torch.optim.AdamW``, ``'bitsandbytes'``/``'adamw8bit'`` selects
            ``bitsandbytes.optim.AdamW8bit``. Case-insensitive; a falsy value
            means ``'bitsandbytes'``.
        **kwargs: forwarded to the optimizer constructor.

    Raises:
        ValueError: unknown *adamw_impl*.
        NotImplementedError: unknown *name*.
    """
    # Each branch only picks the class; imports stay lazy so optional
    # back-ends are not required unless selected.
    if name == 'adam':
        from torch.optim import Adam as optimizer_cls
    elif name == 'adamw':
        impl = (adamw_impl or 'bitsandbytes').lower()
        if impl in ('torch', 'adamw'):
            from torch.optim import AdamW as optimizer_cls
        elif impl in ('bitsandbytes', 'adamw8bit'):
            from bitsandbytes.optim import AdamW8bit as optimizer_cls
        else:
            raise ValueError(f'Unsupported AdamW implementation: {impl}')
    elif name == 'adafactor':
        from torch.optim import Adafactor as optimizer_cls
    else:
        raise NotImplementedError(name)
    return optimizer_cls(params, **kwargs)
def customized_lr_scheduler(optimizer, warmup_steps=-1):
    """Return a ``LambdaLR`` with linear warmup.

    The learning-rate multiplier grows as ``step / warmup_steps`` and is capped
    at 1. With ``warmup_steps <= 0`` the multiplier is constantly 1.
    """
    from torch.optim.lr_scheduler import LambdaLR

    def lr_factor(step):
        if warmup_steps <= 0:
            return 1
        return min(step / warmup_steps, 1)

    return LambdaLR(optimizer, lr_factor)
def get_lr_scheduler(optimizer, name, **kwargs):
    """Create a learning-rate scheduler by name.

    ``'customized'`` delegates to ``customized_lr_scheduler``; ``'cosine'``
    builds ``torch.optim.lr_scheduler.CosineAnnealingLR``. ``kwargs`` are
    forwarded unchanged.

    Raises:
        NotImplementedError: for any other *name*.
    """
    if name == 'cosine':
        from torch.optim.lr_scheduler import CosineAnnealingLR
        return CosineAnnealingLR(optimizer, **kwargs)
    if name == 'customized':
        return customized_lr_scheduler(optimizer, **kwargs)
    raise NotImplementedError(name)
def ema(model_dest: nn.Module, model_src: nn.Module, rate):
    """Update *model_dest* parameters in place as an EMA of *model_src*.

    For every named parameter of *model_dest*:
    ``dest = rate * dest + (1 - rate) * src``. Parameters are matched by name;
    a name missing from *model_src* raises ``KeyError``.
    """
    src_params = {name: param for name, param in model_src.named_parameters()}
    for name, dest_param in model_dest.named_parameters():
        src_param = src_params[name]
        # The two models must not share parameter objects.
        assert src_param is not dest_param
        contribution = (1 - rate) * src_param.data
        dest_param.data.mul_(rate).add_(contribution)
class TrainState(object):
    """Container for the objects that make up a resumable training run.

    Holds ``optimizer``, ``lr_scheduler``, ``step``, ``nnet`` and ``nnet_ema``.
    ``save`` / ``load`` iterate over ``__dict__``: every attribute other than
    ``step`` that is not ``None`` must provide ``state_dict()`` and
    ``load_state_dict()``.
    """

    def __init__(self, optimizer, lr_scheduler, step, nnet=None, nnet_ema=None):
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.step = step
        self.nnet = nnet
        self.nnet_ema = nnet_ema

    def ema_update(self, rate=0.9999):
        """Fold ``nnet`` into ``nnet_ema`` with decay *rate* (no-op without an EMA model)."""
        if self.nnet_ema is not None:
            ema(self.nnet_ema, self.nnet, rate)

    def save(self, path):
        """Write the state to the directory *path*, replacing any previous one.

        Files are first written to ``<path>.tmp_<timestamp>`` and then moved
        into place; an existing checkpoint is parked at ``<path>.backup`` during
        the swap and restored if the swap fails. The whole procedure is
        attempted up to two times, 60 seconds apart.

        Raises:
            Exception: whatever the last failed attempt raised.
        """
        import shutil
        import time

        max_retries = 2
        retry_delay = 60  # seconds
        backup_path = path + '.backup'
        for attempt in range(max_retries):
            temp_path = path + f'.tmp_{int(time.time())}'
            try:
                if os.path.exists(path):
                    # A checkpoint directory without step.pth is treated as a
                    # leftover from an interrupted write and discarded.
                    try:
                        if not os.path.exists(os.path.join(path, 'step.pth')):
                            logging.warning(f'Incomplete checkpoint detected at {path}, removing...')
                            shutil.rmtree(path)
                    except Exception as e:
                        logging.warning(f'Error checking checkpoint integrity: {e}')
                if os.path.exists(temp_path):
                    shutil.rmtree(temp_path)
                if os.path.exists(backup_path):
                    shutil.rmtree(backup_path)
                os.makedirs(temp_path, exist_ok=True)
                torch.save(self.step, os.path.join(temp_path, 'step.pth'))
                for key, val in self.__dict__.items():
                    if key != 'step' and val is not None:
                        torch.save(val.state_dict(), os.path.join(temp_path, f'{key}.pth'))
                if os.path.exists(path):
                    shutil.move(path, backup_path)
                    try:
                        shutil.move(temp_path, path)
                    except Exception:
                        # Roll back: drop any partial copy, then put the old
                        # checkpoint back where it was.
                        if os.path.exists(path):
                            shutil.rmtree(path, ignore_errors=True)
                        if os.path.exists(backup_path):
                            shutil.move(backup_path, path)
                        raise
                    # The new checkpoint is in place; failing to delete the
                    # backup must not undo it (moving the backup onto an
                    # existing directory would nest it inside the new one).
                    try:
                        shutil.rmtree(backup_path)
                    except Exception as e:
                        logging.warning(f'Could not remove backup {backup_path}: {e}')
                else:
                    shutil.move(temp_path, path)
                logging.info(f'Successfully saved checkpoint to {path}')
                return
            except Exception as e:
                logging.warning(f'Save attempt {attempt + 1}/{max_retries} failed: {e}')
                for tmp in [temp_path, backup_path]:
                    if os.path.exists(tmp):
                        try:
                            shutil.rmtree(tmp)
                        except Exception as cleanup_err:
                            # Best-effort cleanup; report instead of hiding it.
                            logging.warning(f'Could not remove {tmp}: {cleanup_err}')
                if attempt < max_retries - 1:
                    logging.info(f'Retrying in {retry_delay} seconds...')
                    time.sleep(retry_delay)
                else:
                    logging.error(f'Failed to save checkpoint after {max_retries} attempts')
                    raise

    def load(self, path):
        """Restore ``step`` and every non-``None`` component from directory *path*."""
        logging.info(f'load from {path}')
        self.step = torch.load(os.path.join(path, 'step.pth'))
        for key, val in self.__dict__.items():
            if key != 'step' and val is not None:
                val.load_state_dict(torch.load(os.path.join(path, f'{key}.pth'), map_location='cpu'))

    def resume(self, ckpt_root, step=None):
        """Load checkpoint ``<ckpt_root>/<step>.ckpt``.

        When *step* is ``None`` the highest step among entries named exactly
        ``<digits>.ckpt`` is used. Leftover ``*.ckpt.tmp_*`` / ``*.ckpt.backup``
        directories and non-numeric names are ignored. Returns without loading
        when *ckpt_root* does not exist or holds no such entry.
        """
        if not os.path.exists(ckpt_root):
            return
        if step is None:
            steps = []
            for entry in os.listdir(ckpt_root):
                stem, ext = os.path.splitext(entry)
                if ext == '.ckpt' and stem.isdecimal():
                    steps.append(int(stem))
            if not steps:
                return
            step = max(steps)
        ckpt_path = os.path.join(ckpt_root, f'{step}.ckpt')
        logging.info(f'resume from {ckpt_path}')
        self.load(ckpt_path)

    def to(self, device):
        """Move every ``nn.Module`` attribute to *device*."""
        for key, val in self.__dict__.items():
            if isinstance(val, nn.Module):
                val.to(device)
def trainable_parameters(nnet):
    """Split ``nnet.named_parameters()`` into weight-decay and no-decay groups.

    Parameters whose name ends with ``.nodecay_weight`` or ``.nodecay_bias``
    go into the second group, which carries ``weight_decay=0.0``. The group
    sizes are printed.

    Returns:
        A two-element list of optimizer parameter-group dicts.
    """
    nodecay_suffixes = (".nodecay_weight", ".nodecay_bias")
    with_decay, without_decay = [], []
    for name, param in nnet.named_parameters():
        bucket = without_decay if name.endswith(nodecay_suffixes) else with_decay
        bucket.append(param)
    print("params_decay", len(with_decay))
    print("params_nodecay", len(without_decay))
    return [
        {'params': with_decay},
        {'params': without_decay, 'weight_decay': 0.0},
    ]
def initialize_train_state(config, device):
    """Build the network, its EMA copy, optimizer and scheduler as a ``TrainState``.

    If ``config.pretrained_path`` is set, the checkpoint is loaded on CPU and
    only entries whose key and shape match the model are applied
    (``strict=False``); a report of mismatched and missing keys is printed.
    When the model has an ``adapter`` attribute, its first layer is then
    re-initialized (Xavier-uniform weight, zero bias). Loading errors are
    printed and re-raised.

    The returned state starts at step 0, has the EMA weights copied from the
    network (``ema_update(0)``), and is moved to *device*.
    """
    nnet = get_nnet(**config.nnet)

    def print_head(keys, limit):
        # Print the first *limit* keys and summarize the remainder.
        for key in keys[:limit]:
            print(f" • {key}")
        if len(keys) > limit:
            print(f" ... and {len(keys)-limit} more")

    pretrained_path = getattr(config, 'pretrained_path', None)
    if not pretrained_path:
        print("⚠ No pretrained path set; training from scratch (random init)")
    else:
        try:
            print(f"Loading pretrained weights from {pretrained_path}...")
            pretrained_dict = torch.load(pretrained_path, map_location='cpu')
            model_dict = nnet.state_dict()

            matched_dict = {}
            size_mismatch_keys = []
            for key, tensor in pretrained_dict.items():
                if key not in model_dict:
                    continue
                if tensor.shape == model_dict[key].shape:
                    matched_dict[key] = tensor
                    continue
                size_mismatch_keys.append(key)
                print(f" ⚠ Size mismatch: {key}")
                print(f" pretrained: {tensor.shape}, current model: {model_dict[key].shape}")
            missing_keys = [key for key in model_dict.keys() if key not in pretrained_dict]

            nnet.load_state_dict(matched_dict, strict=False)

            banner = '=' * 60
            print(f"\n{banner}")
            print(f"Pretrained weight loading report:")
            print(f"{banner}")
            print(f"✓ Successfully loaded parameters: {len(matched_dict)}")
            if size_mismatch_keys:
                print(f"\n⚠ Size mismatch ({len(size_mismatch_keys)} keys) - skipped, using random init:")
                print_head(size_mismatch_keys, 10)
            if missing_keys:
                print(f"\n⚠ Missing keys ({len(missing_keys)}):")
                adapter_keys = [key for key in missing_keys if "adapter" in key]
                other_missing = [key for key in missing_keys if "adapter" not in key]
                if adapter_keys:
                    print(f" - Adapter-related keys ({len(adapter_keys)}): random init")
                    print_head(adapter_keys, 5)
                if other_missing:
                    print(f" - Other missing keys ({len(other_missing)}): default init")
                    print_head(other_missing, 5)
            print(f"{banner}\n")

            if hasattr(nnet, 'adapter'):
                first_layer = nnet.adapter[0]
                nn.init.xavier_uniform_(first_layer.weight)
                if hasattr(first_layer, 'bias') and first_layer.bias is not None:
                    nn.init.zeros_(first_layer.bias)
                print("✓ Adapter layer initialized (Xavier uniform)")
        except FileNotFoundError:
            print(f"\n❌ Error: pretrained weights file not found '{pretrained_path}'")
            print("Check the path, or comment out config.pretrained_path to train from scratch")
            raise
        except Exception as e:
            print(f"\n❌ Error loading pretrained weights: {str(e)}")
            import traceback
            traceback.print_exc()
            raise

    # The EMA copy is a second instance of the same architecture, kept in eval mode.
    nnet_ema = get_nnet(**config.nnet)
    nnet_ema.eval()

    optimizer = get_optimizer(trainable_parameters(nnet), **config.optimizer)
    lr_scheduler = get_lr_scheduler(optimizer, **config.lr_scheduler)
    train_state = TrainState(
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        step=0,
        nnet=nnet,
        nnet_ema=nnet_ema,
    )
    # rate=0 makes the EMA parameters an exact copy of the network's.
    train_state.ema_update(0)
    train_state.to(device)
    return train_state
def amortize(n_samples, batch_size):
    """Split *n_samples* into a list of batch sizes.

    Returns as many full ``batch_size`` entries as fit, followed by the
    remainder if it is non-zero.
    """
    full_batches, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full_batches
    if remainder != 0:
        sizes.append(remainder)
    return sizes
def sample2dir_with_gt(accelerator, path, n_samples, mini_batch_size, sample_fn, unpreprocess_fn=None, config=None, target_size=256):
    """Save generated images next to their input and ground truth.

    Each sample is written to ``<path>/<idx>.png`` as a one-row grid in the
    order: input, generated, GT.

    Args:
        accelerator: accelerate.Accelerator instance (uses ``num_processes``,
            ``is_main_process`` and ``gather``).
        path: output directory; created if missing.
        n_samples: total number of samples to save.
        mini_batch_size: per-process batch size passed to *sample_fn*.
        sample_fn: called as ``sample_fn(mini_batch_size, config=config)`` and
            must return ``(generated_samples, gt_images, input_images)``.
        unpreprocess_fn: inverse preprocessing applied to all three tensors;
            ``None`` leaves them unchanged.
        config: config object forwarded to *sample_fn*.
        target_size: side length every image is resized to before gridding.
    """
    os.makedirs(path, exist_ok=True)
    if unpreprocess_fn is None:
        # The signature advertises None as valid, so fall back to identity
        # instead of failing with "'NoneType' object is not callable".
        def unpreprocess_fn(x):
            return x

    idx = 0
    batch_size = mini_batch_size * accelerator.num_processes
    for _batch_size in tqdm(amortize(n_samples, batch_size), disable=not accelerator.is_main_process, desc='sample2dir_with_gt'):
        samples, gt_images, input_images = sample_fn(mini_batch_size, config=config)
        # Every process always produces a full mini-batch; the gathered result
        # is truncated to the size of the current (possibly partial) batch.
        samples = accelerator.gather(unpreprocess_fn(samples).contiguous())[:_batch_size]
        gt_images = accelerator.gather(unpreprocess_fn(gt_images).contiguous())[:_batch_size]
        input_images = accelerator.gather(unpreprocess_fn(input_images).contiguous())[:_batch_size]
        if not accelerator.is_main_process:
            continue
        for input_img, sample, gt in zip(input_images, samples, gt_images):
            triplet = torch.stack(
                [resize_tensor_image(img, target_size) for img in (input_img, sample, gt)],
                dim=0,
            )
            concatenated = make_grid(triplet, nrow=3, padding=2, pad_value=1.0)
            save_image(concatenated, os.path.join(path, f"{idx}.png"))
            idx += 1
# Global cache to avoid repeated tokenizer.encode calls.
# Keys are the tuples built by _get_tokenizer_cache_key; values are the
# encoded input_ids stored by _get_or_encode_tokenizer.
_tokenizer_cache = {}
# Lock guarding _tokenizer_cache; created on first use inside
# _get_or_encode_tokenizer rather than at import time.
_tokenizer_cache_lock = None
def _get_tokenizer_cache_key(vl_chat_processor, question):
"""Build cache key for tokenizer output."""
# Unique key from question and processor settings
cache_key = (
question,
vl_chat_processor.sft_format,
vl_chat_processor.system_prompt,
id(vl_chat_processor.tokenizer) # tokenizer object id for uniqueness
)
return cache_key
def _get_or_encode_tokenizer(vl_chat_processor, question, device):
    """Return the token ids for the image+question prompt, encoding at most once per key.

    The prompt is a single user turn (``<image_placeholder>`` followed by
    *question*) plus an empty assistant turn, rendered with the processor's
    SFT template and encoded with its tokenizer. Results are memoized in the
    module-level ``_tokenizer_cache``. *device* is accepted but not used here.
    """
    global _tokenizer_cache, _tokenizer_cache_lock
    import threading

    if _tokenizer_cache_lock is None:
        _tokenizer_cache_lock = threading.Lock()

    key = _get_tokenizer_cache_key(vl_chat_processor, question)
    with _tokenizer_cache_lock:
        try:
            return _tokenizer_cache[key]
        except KeyError:
            pass

    # Encoding happens outside the lock; the result is stored afterwards.
    conversation = [
        {"role": "<|User|>", "content": f"<image_placeholder>\n{question}"},
        {"role": "<|Assistant|>", "content": ""},
    ]
    prompt = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt=vl_chat_processor.system_prompt,
    )
    encoded = vl_chat_processor.tokenizer.encode(prompt)
    with _tokenizer_cache_lock:
        _tokenizer_cache[key] = encoded
    return encoded
def get_input_image_embeddings_and_masks(
    batch_input_images,
    vl_chat_processor,
    vl_gpt,
    device,
    question="",
    num_image_tokens=576,
    output_tokens=None,
    accelerator=None,
    cached_input_ids=None
):
    """
    Batch-process input images and obtain token embeddings and masks.

    Args:
        batch_input_images: One of:
            - torch.Tensor: preprocessed tensor [batch_size, 3, H, W] (WebDataset mode)
            - list of str: image paths (filesystem mode, backward compatible)
        vl_chat_processor: Janus VLChatProcessor instance
        vl_gpt: Janus MultiModalityCausalLM. If wrapped by accelerator.prepare(), unwrap first:
            vl_gpt = accelerator.unwrap_model(vl_gpt) if hasattr(accelerator, 'unwrap_model') else vl_gpt
        device: torch.device
        question: optional text prompt (default empty)
        num_image_tokens: tokens per image (default 576)
        output_tokens: if set, keep first N tokens (default None = all)
        accelerator: optional, for rank in error logs
        cached_input_ids: optional pre-encoded input_ids on CPU; skips tokenizer.encode if set

    Returns:
        batch_embeddings: [batch_size, output_tokens or num_image_tokens, hidden_dim] on device
        batch_attention_masks: [batch_size, output_tokens or num_image_tokens] on device
    """
    # --- 1. Obtain pixel values on the target device -----------------------
    if isinstance(batch_input_images, torch.Tensor):
        pixel_values = batch_input_images
        if pixel_values.device != device:
            pixel_values = pixel_values.to(device, non_blocking=True)
        batch_size = pixel_values.shape[0]
    else:
        import concurrent.futures

        def rank_tag():
            # Evaluated only when a warning is actually printed.
            return f"[Rank {accelerator.process_index}] " if accelerator is not None else ""

        def black_placeholder():
            return Image.new('RGB', (384, 384), color='black')

        def load_image(image_input):
            """Load one image from a path; any failure yields a black placeholder."""
            if not isinstance(image_input, str):
                print(f"{rank_tag()}Warning: unsupported type {type(image_input)}")
                return black_placeholder()
            try:
                pil_img = Image.open(image_input)
                pil_img.load()
                return pil_img.convert('RGB')
            except Exception as e:
                print(f"{rank_tag()}Warning: failed to load input image {image_input}: {e}")
                return black_placeholder()

        pil_images = []
        if len(batch_input_images) > 0:
            worker_count = min(len(batch_input_images), os.cpu_count() or 1)
            if worker_count > 1:
                with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as pool:
                    pil_images = list(pool.map(load_image, batch_input_images))
            else:
                pil_images = [load_image(item) for item in batch_input_images]
        processed = vl_chat_processor.image_processor(pil_images, return_tensors="pt")
        pixel_values = processed.pixel_values.to(device, non_blocking=True)  # [batch_size, 3, H, W]
        batch_size = len(pil_images)

    # --- 2. Build token ids and masks (the same prompt for every sample) ---
    input_ids = (
        cached_input_ids
        if cached_input_ids is not None
        else _get_or_encode_tokenizer(vl_chat_processor, question, device)
    )
    batched_input_ids = torch.tensor([input_ids] * batch_size, dtype=torch.long, device=device)
    images_seq_mask = batched_input_ids == vl_chat_processor.image_id
    images_emb_mask = torch.ones((batch_size, 1, num_image_tokens), dtype=torch.bool, device=device)

    # --- 3. Run the model's embedding preparation without gradients --------
    with torch.no_grad():
        inputs_embeds = vl_gpt.prepare_inputs_embeds(
            input_ids=batched_input_ids,
            pixel_values=pixel_values.unsqueeze(1),
            images_seq_mask=images_seq_mask,
            images_emb_mask=images_emb_mask
        )
    inputs_embeds = inputs_embeds.detach().float()

    # --- 4. Keep only image-token embeddings, fitted to num_image_tokens ---
    if inputs_embeds.shape[1] == num_image_tokens:
        # The sequence length already equals the expected count; use as is.
        batch_embeddings = inputs_embeds  # [batch_size, num_image_tokens, hidden_dim]
    else:
        per_sample = []
        for i in range(batch_size):
            image_embeddings = inputs_embeds[i][images_seq_mask[i]]  # [actual_tokens, hidden_dim]
            deficit = num_image_tokens - image_embeddings.shape[0]
            if deficit < 0:
                image_embeddings = image_embeddings[:num_image_tokens]
            elif deficit > 0:
                padding = torch.zeros(
                    (deficit, image_embeddings.shape[1]),
                    device=device,
                    dtype=image_embeddings.dtype
                )
                image_embeddings = torch.cat([image_embeddings, padding], dim=0)
            # deficit == 0: the extracted rows are used unchanged.
            per_sample.append(image_embeddings)
        batch_embeddings = torch.stack(per_sample, dim=0)  # [batch_size, num_image_tokens, hidden_dim]

    batch_attention_masks = torch.ones(
        (batch_size, num_image_tokens),
        device=device,
        dtype=torch.long
    )  # [batch_size, num_image_tokens]

    # --- 5. Optional truncation to the first output_tokens tokens ----------
    if output_tokens is not None and output_tokens < num_image_tokens:
        batch_embeddings = batch_embeddings[:, :output_tokens, :]
        batch_attention_masks = batch_attention_masks[:, :output_tokens]
    return batch_embeddings, batch_attention_masks
# Visualization helpers
def resize_tensor_image(img, target_size, device=None):
    """Return *img* ([C, H, W]) as a ``target_size`` x ``target_size`` image.

    The tensor is first moved to *device* when one is given and differs from
    its current device. Images already at the target size are returned without
    interpolation; others are resized bilinearly (``align_corners=False``).
    """
    if device is not None and img.device != device:
        img = img.to(device)
    already_sized = img.shape[1] == target_size and img.shape[2] == target_size
    if already_sized:
        return img
    # F.interpolate needs a batch dimension; add it and strip it again.
    batched = img.unsqueeze(0)
    resized = F.interpolate(
        batched,
        size=(target_size, target_size),
        mode='bilinear',
        align_corners=False,
    )
    return resized.squeeze(0)
def build_cross_atten_mask_from_batch_type(batch_type, batch_size, device):
    """Build cross-attention bool mask from a list of type bytes.

    Elements equal to ``b"t2i"`` map to ``True`` (skip cross-attention);
    ``None`` elements map to ``False``.
    If *batch_type* is ``None``, returns an all-False tensor of length *batch_size*.
    """
    if batch_type is None:
        return torch.zeros(batch_size, dtype=torch.bool, device=device)
    flags = []
    for type_tag in batch_type:
        flags.append(False if type_tag is None else type_tag == b"t2i")
    return torch.tensor(flags, dtype=torch.bool, device=device)
def build_cross_atten_mask_from_paths(image_paths, device):
    """Return a bool tensor indicating which images are t2i (skip cross-attention).

    Images whose filename (the basename of the path) starts with ``t2i_`` map to ``True``.
    """
    is_t2i = []
    for image_path in image_paths:
        filename = os.path.basename(image_path)
        is_t2i.append(filename.startswith("t2i_"))
    return torch.tensor(is_t2i, dtype=torch.bool, device=device)
def load_images_as_latents(image_paths, resolution, autoencoder, device):
    """Load images from *image_paths*, center-crop, and encode through autoencoder.

    Each image is converted to RGB, cropped with ``data.center_crop_arr`` to
    *resolution*, scaled from [0, 255] to [-1, 1] as float32 and moved to
    *device*. The stacked batch is encoded with
    ``autoencoder(..., fn='encode_moments')`` and the detached moments are
    passed to ``autoencoder.sample``.

    Returns the sampled latent tensor of shape ``[N, 4, H, W]``.
    """
    from data import center_crop_arr

    tensors = []
    for path in image_paths:
        # Close the file handle deterministically; convert() returns an
        # independent in-memory copy, so it stays usable after the block.
        with Image.open(path) as raw:
            pil = raw.convert("RGB")
        arr = center_crop_arr(pil, image_size=resolution)
        arr = (arr / 127.5 - 1.0).astype(np.float32)
        chw = einops.rearrange(arr, 'h w c -> c h w')
        tensors.append(torch.from_numpy(chw).to(device))
    stacked = torch.stack(tensors, dim=0)
    moments = autoencoder(stacked, fn='encode_moments').detach()
    return autoencoder.sample(moments)
def save_vis_grid_and_log(
    samples_unpreprocessed,
    input_images_pil,
    gt_images_pil,
    sample_dir,
    step,
    wandb_module,
    device,
    samples_per_group=10,
    target_size=256,
):
    """Build [input | gt | generated] image grids, save to disk, and log to wandb.

    Samples are split into groups of *samples_per_group*; each group becomes
    one grid with three columns saved as ``<sample_dir>/<step>_<group>.png``
    (group numbering starts at 1). A missing input or GT image is replaced by
    a black tile. All grids are logged together under the key ``'samples'``.

    Args:
        samples_unpreprocessed: list/tensor of generated images [C, H, W] in [0, 1].
        input_images_pil: list of PIL input images.
        gt_images_pil: list of PIL ground-truth images (may be empty).
        sample_dir: directory to save grid images; created if missing.
        step: current training step (used in filenames and wandb step).
        wandb_module: the wandb module (passed to avoid importing it here).
        device: torch device for tensor operations.
        samples_per_group: number of rows per saved grid file.
        target_size: spatial size each image is resized to before gridding.
    """
    total = len(samples_unpreprocessed)
    num_groups = (total + samples_per_group - 1) // samples_per_group
    if num_groups <= 0:
        return  # nothing to save or log

    os.makedirs(sample_dir, exist_ok=True)
    to_tensor = ToTensor()

    def pil_cell(pil_images, index):
        # Tile for a PIL image list, or a black tile when the list is too short.
        if index < len(pil_images):
            return resize_tensor_image(to_tensor(pil_images[index]), target_size, device)
        return torch.zeros((3, target_size, target_size), device=device)

    wandb_images = []
    for group_id in range(num_groups):
        start = group_id * samples_per_group
        end = min(start + samples_per_group, total)
        # start < end always holds here, so every group has at least one row.
        group_imgs = []
        for i in range(start, end):
            group_imgs.append(pil_cell(input_images_pil, i))
            group_imgs.append(pil_cell(gt_images_pil, i))
            group_imgs.append(resize_tensor_image(samples_unpreprocessed[i], target_size, device))
        grid = make_grid(torch.stack(group_imgs, dim=0), nrow=3, padding=2, pad_value=1.0)
        save_image(grid, os.path.join(sample_dir, f'{step}_{group_id + 1}.png'))
        wandb_images.append(wandb_module.Image(grid))

    if wandb_images:
        wandb_module.log({'samples': wandb_images}, step=step)
def clean_stale_ckpt_files(ckpt_root, max_age_seconds=3600):
    """Remove stale temporary/backup checkpoint directories in *ckpt_root*.

    An entry is a candidate when its name contains ``.tmp_`` or ends with
    ``.backup``. Candidates that are directories and whose modification time
    is more than *max_age_seconds* in the past are deleted; regular files are
    left alone. Errors are logged as warnings and never raised.

    Args:
        ckpt_root: directory that holds the checkpoints.
        max_age_seconds: minimum age before a candidate is removed
            (default: one hour).
    """
    import shutil
    import time as _time

    try:
        for item in os.listdir(ckpt_root):
            if '.tmp_' not in item and not item.endswith('.backup'):
                continue
            tmp_path = os.path.join(ckpt_root, item)
            try:
                age = _time.time() - os.path.getmtime(tmp_path)
                if os.path.isdir(tmp_path) and age > max_age_seconds:
                    logging.info(f'Removing stale temporary directory: {tmp_path}')
                    shutil.rmtree(tmp_path)
            except Exception as e:
                # One bad entry must not stop the rest of the sweep.
                logging.warning(f'Error cleaning temporary file {tmp_path}: {e}')
    except Exception as e:
        logging.warning(f'Error scanning checkpoint directory: {e}')