# Utilities for Falcon Vision
# Model loading and image preprocessing without tokenizer dependency

import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from PIL import Image
from typing import Union, List
import os

from .model import AMOE
from .configs import AMOEArgs, amoe_configs
from .image_processor import AMOEImageProcessor


def load_amoe_model(
    checkpoint_path: str,
    config_name: str = "18-layers-distillation",
    device: Union[str, torch.device] = "cuda",
    dtype: torch.dtype | None = None,
    **kwargs,
) -> tuple[AMOE, AMOEImageProcessor]:
    """
    Load an AMOE model from a checkpoint.

    Args:
        checkpoint_path: Path to the model checkpoint
        config_name: Name of the model configuration
        device: Device to load the model on
        dtype: Optional dtype to cast model weights to (e.g. torch.bfloat16)

    Returns:
        Tuple of (model, image_processor)
    """
    # Get configuration
    if config_name in amoe_configs:
        args = amoe_configs[config_name]
    else:
        raise ValueError(f"Unknown config: {config_name}. Available: {list(amoe_configs.keys())}")

    # Create model
    model = AMOE(args)

    # Standard PyTorch checkpoint
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(state_dict)

    if dtype is None:
        model = model.to(device=device)
    else:
        model = model.to(device=device, dtype=dtype)
    model.eval()

    # Create image processor
    image_processor = AMOEImageProcessor(patch_size=args.spatial_patch_size, **kwargs)

    return model, image_processor


# def convert_torchtitan_checkpoint(
#     torchtitan_ckpt_path: str,
#     output_path: str,
#     config_name: str = "0.25B-1B-a-tall-se-24l16e-route-distillation",
# ):
#     """
#     Convert a torchtitan checkpoint to standalone format.
#
#     This handles the key mapping differences between the torchtitan
#     DistillPerceptionTransformerMultiTeacher and FalconVisionEncoder.
#     """
#     # Load torchtitan checkpoint
#     if os.path.isdir(torchtitan_ckpt_path):
#         from torch.distributed.checkpoint import load as dcp_load
#         config = omni_falcon_perception_configs[config_name]
#         config.max_seq_len = 2048
#         config.seq_len = 2304 + 5
#         config.vocab_size = 65536
#         config.eos_id = 31999
#         config.dtype = torch.bfloat16
#         config.use_grouped_mm = False
#         config.use_flex_attn = True
#         config.attn_mask_type = "distill_mask"
#         config.img_start_id = 31998
#         config.img_end_id = 31997
#         config.img_id = 31996
#         config.eager = True
#         config.n_storage_tokens = 4
#         config.img_row_sep_id = 31995
#         config.vid_start_id = 31994
#         config.vid_end_id = 31993
#         config.frame_sep_id = 31992
#         config.image_mask_token_id = 31991
#         config.image_cls_token_id = 31990
#         config.image_reg_1_token_id = 31989
#         config.image_reg_2_token_id = 31988
#         config.image_reg_3_token_id = 31987
#         config.image_reg_4_token_id = 31986
#         config.cls_weight = 0
#         config.patch_weight = 0
#         config.storage_weight = 0
#         config.pairwise_distance_weight = 0
#         config.pairwise_cosine_weight = 0
#         config.pairwise_distance_patch_weight = 0
#         config.pairwise_cosine_patch_weight = 0
#         config.high_res_distillation_weight = 0
#         config.teachers = ("siglip2", "dinov3")
#         config.teachers_dim = (1152, 1024)
#         config.optimizable_teachers = ("siglip2", "dinov3")
#         config.average_patch_loss = False
#         config.weighted_patch_loss = False
#         config.jitter_rope = False
#         config.use_phis = False
#         config.use_pixel_head = True
#
#         # Load model
#         model = DistillPerceptionTransformerMultiTeacher(config).to("cuda")
#         state_dict = model.state_dict()
#         state_dict.pop('freqs_cis', None)
#         keys = list(state_dict.keys())
#         for k in keys:
#             if "coord" in k:
#                 state_dict.pop(k, None)
#             if "size" in k:
#                 state_dict.pop(k, None)
#             if "proj_segm" in k:
#                 state_dict.pop(k, None)
#             if "itok_upsampler" in k:
#                 state_dict.pop(k, None)
#             if "rope_upsampler" in k:
#                 state_dict.pop(k, None)
#
#         dcp_load(state_dict, checkpoint_id=torchtitan_ckpt_path)
#     else:
#         state_dict = torch.load(torchtitan_ckpt_path, map_location="cpu", weights_only=False)
#         if "model" in state_dict:
#             state_dict = state_dict["model"]
#
#     # Key mapping from torchtitan to standalone
#     key_map = {
#         "tok_embeddings": None,  # Remove text embeddings
#         "output": None,  # Remove text output
#         "pixel_mlp": None,  # Remove pixel head
#         "proj_segm": None,  # Remove segmentation head
#         "itok_upsampler": None,  # Remove upsampler
#         "coord_encoder": None,  # Remove coordinate heads
#         "coord_decoder": None,
#         "size_encoder": None,
#         "size_decoder": None,
#         "phis_statistics": None,  # Remove PHIs statistics
#         "rope_upsampler": None,  # Remove RoPE upsampler
#     }
#
#     new_state_dict = {}
#     for k, v in state_dict.items():
#         # Skip keys that should be removed
#         skip = False
#         for prefix in key_map.keys():
#             if k.startswith(prefix) or k.startswith(f"model.{prefix}"):
#                 skip = True
#                 break
#         if skip:
#             continue
#
#         # Remove "model." prefix if present
#         new_key = k[6:] if k.startswith("model.") else k
#         print(new_key)
#         new_state_dict[new_key] = v
#
#     # Save converted checkpoint
#     torch.save(new_state_dict, output_path)
#     print(f"Saved converted checkpoint to {output_path}")


# Feature dimension constants
FEATURE_DIM_DICT = {
    "dinov3": 1024,
    "siglip2": 1152,
    "amoe": 768,  # Model dimension
}

PATCH_SIZE = 16
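

# Example usage of load_amoe_model (illustrative sketch only).
# The checkpoint path and image file below are placeholders, and the
# calls to the image processor and model are assumptions about the
# AMOEImageProcessor / AMOE interfaces, not documented behavior:
#
#     model, image_processor = load_amoe_model(
#         "/path/to/amoe_checkpoint.pt",         # hypothetical path
#         config_name="18-layers-distillation",
#         device="cuda",
#         dtype=torch.bfloat16,
#     )
#     image = Image.open("example.jpg")          # hypothetical input image
#     pixel_values = image_processor(image)      # assumed __call__ interface
#     with torch.no_grad():
#         features = model(pixel_values)         # assumed forward signature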