# Utilities for Falcon Vision
# Model loading and image preprocessing without tokenizer dependency
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from PIL import Image
from typing import Union, List
import os
from .model import AMOE
from .configs import AMOEArgs, amoe_configs
from .image_processor import AMOEImageProcessor


def load_amoe_model(
checkpoint_path: str,
config_name: str = "18-layers-distillation",
device: Union[str, torch.device] = "cuda",
dtype: torch.dtype | None = None,
**kwargs,
) -> tuple[AMOE, AMOEImageProcessor]:
"""
    Load an AMOE model from a checkpoint.

    Args:
        checkpoint_path: Path to the model checkpoint
        config_name: Name of the model configuration
        device: Device to load the model on
        dtype: Optional dtype to cast model weights to (e.g. torch.bfloat16)
        **kwargs: Extra keyword arguments forwarded to AMOEImageProcessor

    Returns:
        Tuple of (model, image_processor)
"""
# Get configuration
if config_name in amoe_configs:
args = amoe_configs[config_name]
else:
raise ValueError(f"Unknown config: {config_name}. Available: {list(amoe_configs.keys())}")
# Create model
model = AMOE(args)
# Standard PyTorch checkpoint
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
model.load_state_dict(state_dict)
if dtype is None:
model = model.to(device=device)
else:
model = model.to(device=device, dtype=dtype)
model.eval()
# Create image processor
image_processor = AMOEImageProcessor(patch_size=args.spatial_patch_size, **kwargs)
return model, image_processor
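
# Illustrative convenience sketch (not part of the original API): how the loader
# and image processor above might be combined for a single image. The
# AMOEImageProcessor call convention and the model's forward signature are
# assumptions; adapt them to the actual interfaces.
def example_encode_image(
    checkpoint_path: str,
    image_path: str,
    device: str = "cuda",
) -> torch.Tensor:
    """Load an AMOE model and encode one image (hypothetical usage sketch)."""
    model, image_processor = load_amoe_model(
        checkpoint_path,
        config_name="18-layers-distillation",
        device=device,
        dtype=torch.bfloat16,
    )
    image = Image.open(image_path).convert("RGB")
    pixel_values = image_processor(image)  # assumed to return a torch.Tensor
    with torch.no_grad():
        # The forward call is assumed to accept the preprocessed pixel tensor directly.
        features = model(pixel_values.to(device=device, dtype=torch.bfloat16))
    return features
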
# def convert_torchtitan_checkpoint(
# torchtitan_ckpt_path: str,
# output_path: str,
# config_name: str = "0.25B-1B-a-tall-se-24l16e-route-distillation",
# ):
# """
# Convert a torchtitan checkpoint to standalone format.
#
# This handles the key mapping differences between the torchtitan
# DistillPerceptionTransformerMultiTeacher and FalconVisionEncoder.
# """
# # Load torchtitan checkpoint
# if os.path.isdir(torchtitan_ckpt_path):
# from torch.distributed.checkpoint import load as dcp_load
# config = omni_falcon_perception_configs[config_name]
# config.max_seq_len = 2048
# config.seq_len = 2304 + 5
# config.vocab_size = 65536
# config.eos_id = 31999
# config.dtype = torch.bfloat16
# config.use_grouped_mm = False
# config.use_flex_attn = True
# config.attn_mask_type = "distill_mask"
# config.img_start_id = 31998
# config.img_end_id = 31997
# config.img_id = 31996
# config.eager = True
# config.n_storage_tokens = 4
# config.img_row_sep_id = 31995
# config.vid_start_id = 31994
# config.vid_end_id = 31993
# config.frame_sep_id = 31992
# config.image_mask_token_id = 31991
# config.image_cls_token_id = 31990
# config.image_reg_1_token_id = 31989
# config.image_reg_2_token_id = 31988
# config.image_reg_3_token_id = 31987
# config.image_reg_4_token_id = 31986
# config.cls_weight = 0
# config.patch_weight = 0
# config.storage_weight = 0
# config.pairwise_distance_weight = 0
# config.pairwise_cosine_weight = 0
# config.pairwise_distance_patch_weight = 0
# config.pairwise_cosine_patch_weight = 0
# config.high_res_distillation_weight = 0
# config.teachers = ("siglip2", "dinov3")
# config.teachers_dim = (1152, 1024)
# config.optimizable_teachers = ("siglip2", "dinov3")
# config.average_patch_loss = False
# config.weighted_patch_loss = False
# config.jitter_rope = False
# config.use_phis = False
# config.use_pixel_head = True
#
# # Load model
# model = DistillPerceptionTransformerMultiTeacher(config).to("cuda")
# state_dict = model.state_dict()
# state_dict.pop('freqs_cis', None)
# keys = list(state_dict.keys())
# for k in keys:
# if "coord" in k:
# state_dict.pop(k, None)
# if "size" in k:
# state_dict.pop(k, None)
# if "proj_segm" in k:
# state_dict.pop(k, None)
# if "itok_upsampler" in k:
# state_dict.pop(k, None)
# if "rope_upsampler" in k:
# state_dict.pop(k, None)
#
# dcp_load(state_dict, checkpoint_id=torchtitan_ckpt_path)
# else:
# state_dict = torch.load(torchtitan_ckpt_path, map_location="cpu", weights_only=False)
# if "model" in state_dict:
# state_dict = state_dict["model"]
#
# # Key mapping from torchtitan to standalone
# key_map = {
# "tok_embeddings": None, # Remove text embeddings
# "output": None, # Remove text output
# "pixel_mlp": None, # Remove pixel head
# "proj_segm": None, # Remove segmentation head
# "itok_upsampler": None, # Remove upsampler
# "coord_encoder": None, # Remove coordinate heads
# "coord_decoder": None,
# "size_encoder": None,
# "size_decoder": None,
# "phis_statistics": None, # Remove PHIs statistics
# "rope_upsampler": None, # Remove RoPE upsampler
# }
#
# new_state_dict = {}
# for k, v in state_dict.items():
# # Skip keys that should be removed
# skip = False
# for prefix in key_map.keys():
# if k.startswith(prefix) or k.startswith(f"model.{prefix}"):
# skip = True
# break
# if skip:
# continue
#
# # Remove "model." prefix if present
# new_key = k[6:] if k.startswith("model.") else k
# print(new_key)
# new_state_dict[new_key] = v
#
# # Save converted checkpoint
# torch.save(new_state_dict, output_path)
# print(f"Saved converted checkpoint to {output_path}")
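
# The commented-out converter above reduces to two generic steps: drop keys whose
# prefixes belong to heads the standalone model does not define, then strip the
# leading "model." namespace. A minimal, dependency-free sketch of that key-mapping
# step follows; the prefix list mirrors key_map above, and the function is
# illustrative, not a supported entry point.
def _strip_wrapper_keys_example(
    state_dict: dict,
    drop_prefixes: tuple = (
        "tok_embeddings", "output", "pixel_mlp", "proj_segm", "itok_upsampler",
        "coord_encoder", "coord_decoder", "size_encoder", "size_decoder",
        "phis_statistics", "rope_upsampler",
    ),
) -> dict:
    """Filter a training-time checkpoint down to standalone encoder keys (sketch)."""
    new_state_dict = {}
    for key, value in state_dict.items():
        # Skip keys for embeddings/heads that the standalone model does not have.
        if any(key.startswith(p) or key.startswith(f"model.{p}") for p in drop_prefixes):
            continue
        # Remove the "model." prefix used by the training-time wrapper, if present.
        new_key = key[len("model."):] if key.startswith("model.") else key
        new_state_dict[new_key] = value
    return new_state_dict
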
# Feature dimension constants
FEATURE_DIM_DICT = {
"dinov3": 1024,
"siglip2": 1152,
"amoe": 768, # Model dimension
}
PATCH_SIZE = 16
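
# Small illustrative helper (an assumption, not part of the original API): the
# expected patch count and feature width for a given backbone and image size,
# assuming a plain non-overlapping PATCH_SIZE x PATCH_SIZE grid. Extra special
# tokens (e.g. storage/register tokens) are not counted here.
def expected_feature_shape(model_name: str, height: int, width: int) -> tuple[int, int]:
    """Return (num_patches, feature_dim) for an image of the given size (sketch)."""
    if model_name not in FEATURE_DIM_DICT:
        raise ValueError(f"Unknown model: {model_name}. Available: {list(FEATURE_DIM_DICT.keys())}")
    grid_h, grid_w = height // PATCH_SIZE, width // PATCH_SIZE
    return grid_h * grid_w, FEATURE_DIM_DICT[model_name]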