# Utilities for Falcon Vision
# Model loading and image preprocessing without tokenizer dependency

import os
from typing import List, Union

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn

from .configs import AMOEArgs, amoe_configs
from .image_processor import AMOEImageProcessor
from .model import AMOE


def load_amoe_model(
    checkpoint_path: str,
    config_name: str = "18-layers-distillation",
    device: Union[str, torch.device] = "cuda",
    dtype: torch.dtype | None = None,
    **kwargs,
) -> tuple[AMOE, AMOEImageProcessor]:
"""
Load a AMOE model from a checkpoint.
Args:
checkpoint_path: Path to the model checkpoint
config_name: Name of the model configuration
device: Device to load the model on
dtype: Optional dtype to cast model weights to (e.g. torch.bfloat16)
Returns:
Tuple of (model, image_processor)
"""
    # Get configuration
    if config_name in amoe_configs:
        args = amoe_configs[config_name]
    else:
        raise ValueError(f"Unknown config: {config_name}. Available: {list(amoe_configs.keys())}")

    # Create model
    model = AMOE(args)

    # Standard PyTorch checkpoint
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(state_dict)

    if dtype is None:
        model = model.to(device=device)
    else:
        model = model.to(device=device, dtype=dtype)
    model.eval()

    # Create image processor
    image_processor = AMOEImageProcessor(patch_size=args.spatial_patch_size, **kwargs)
    return model, image_processor
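
# Example usage (illustrative sketch): the checkpoint path below is hypothetical and the
# package name in the import is assumed; downstream preprocessing / forward calls depend on
# the AMOEImageProcessor and AMOE APIs defined in this repo.
#
#     import torch
#     from falcon_vision.utils import load_amoe_model
#
#     model, image_processor = load_amoe_model(
#         "/path/to/amoe_checkpoint.pt",
#         config_name="18-layers-distillation",
#         device="cuda",
#         dtype=torch.bfloat16,
#     )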


# def convert_torchtitan_checkpoint(
#     torchtitan_ckpt_path: str,
#     output_path: str,
#     config_name: str = "0.25B-1B-a-tall-se-24l16e-route-distillation",
# ):
#     """
#     Convert a torchtitan checkpoint to standalone format.
#
#     This handles the key mapping differences between the torchtitan
#     DistillPerceptionTransformerMultiTeacher and FalconVisionEncoder.
#     """
#     # Load torchtitan checkpoint
#     if os.path.isdir(torchtitan_ckpt_path):
#         from torch.distributed.checkpoint import load as dcp_load
#
#         config = omni_falcon_perception_configs[config_name]
#         config.max_seq_len = 2048
#         config.seq_len = 2304 + 5
#         config.vocab_size = 65536
#         config.eos_id = 31999
#         config.dtype = torch.bfloat16
#         config.use_grouped_mm = False
#         config.use_flex_attn = True
#         config.attn_mask_type = "distill_mask"
#         config.img_start_id = 31998
#         config.img_end_id = 31997
#         config.img_id = 31996
#         config.eager = True
#         config.n_storage_tokens = 4
#         config.img_row_sep_id = 31995
#         config.vid_start_id = 31994
#         config.vid_end_id = 31993
#         config.frame_sep_id = 31992
#         config.image_mask_token_id = 31991
#         config.image_cls_token_id = 31990
#         config.image_reg_1_token_id = 31989
#         config.image_reg_2_token_id = 31988
#         config.image_reg_3_token_id = 31987
#         config.image_reg_4_token_id = 31986
#         config.cls_weight = 0
#         config.patch_weight = 0
#         config.storage_weight = 0
#         config.pairwise_distance_weight = 0
#         config.pairwise_cosine_weight = 0
#         config.pairwise_distance_patch_weight = 0
#         config.pairwise_cosine_patch_weight = 0
#         config.high_res_distillation_weight = 0
#         config.teachers = ("siglip2", "dinov3")
#         config.teachers_dim = (1152, 1024)
#         config.optimizable_teachers = ("siglip2", "dinov3")
#         config.average_patch_loss = False
#         config.weighted_patch_loss = False
#         config.jitter_rope = False
#         config.use_phis = False
#         config.use_pixel_head = True
#
#         # Load model
#         model = DistillPerceptionTransformerMultiTeacher(config).to("cuda")
#         state_dict = model.state_dict()
#         state_dict.pop('freqs_cis', None)
#         keys = list(state_dict.keys())
#         for k in keys:
#             if "coord" in k:
#                 state_dict.pop(k, None)
#             if "size" in k:
#                 state_dict.pop(k, None)
#             if "proj_segm" in k:
#                 state_dict.pop(k, None)
#             if "itok_upsampler" in k:
#                 state_dict.pop(k, None)
#             if "rope_upsampler" in k:
#                 state_dict.pop(k, None)
#
#         dcp_load(state_dict, checkpoint_id=torchtitan_ckpt_path)
#     else:
#         state_dict = torch.load(torchtitan_ckpt_path, map_location="cpu", weights_only=False)
#         if "model" in state_dict:
#             state_dict = state_dict["model"]
#
#     # Key mapping from torchtitan to standalone
#     key_map = {
#         "tok_embeddings": None,  # Remove text embeddings
#         "output": None,  # Remove text output
#         "pixel_mlp": None,  # Remove pixel head
#         "proj_segm": None,  # Remove segmentation head
#         "itok_upsampler": None,  # Remove upsampler
#         "coord_encoder": None,  # Remove coordinate heads
#         "coord_decoder": None,
#         "size_encoder": None,
#         "size_decoder": None,
#         "phis_statistics": None,  # Remove PHIs statistics
#         "rope_upsampler": None,  # Remove RoPE upsampler
#     }
#
#     new_state_dict = {}
#     for k, v in state_dict.items():
#         # Skip keys that should be removed
#         skip = False
#         for prefix in key_map.keys():
#             if k.startswith(prefix) or k.startswith(f"model.{prefix}"):
#                 skip = True
#                 break
#         if skip:
#             continue
#
#         # Remove "model." prefix if present
#         new_key = k[6:] if k.startswith("model.") else k
#         print(new_key)
#         new_state_dict[new_key] = v
#
#     # Save converted checkpoint
#     torch.save(new_state_dict, output_path)
#     print(f"Saved converted checkpoint to {output_path}")


# Feature dimension constants
FEATURE_DIM_DICT = {
    "dinov3": 1024,
    "siglip2": 1152,
    "amoe": 768,  # Model dimension
}

PATCH_SIZE = 16
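
# Illustrative sketch (assumption, not part of this module's API): how these constants might
# be used to size a projection from the AMOE hidden dimension to a teacher's feature
# dimension, and to count patches for a given resolution. The names `proj_to_dinov3`,
# `height`, `width`, and `num_patches` are hypothetical.
#
#     from torch import nn
#
#     proj_to_dinov3 = nn.Linear(FEATURE_DIM_DICT["amoe"], FEATURE_DIM_DICT["dinov3"])  # 768 -> 1024
#
#     height, width = 224, 336
#     num_patches = (height // PATCH_SIZE) * (width // PATCH_SIZE)  # 14 * 21 = 294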