|
|
import safetensors
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
|
|
|
from contextlib import contextmanager
|
|
|
from dataclasses import dataclass
|
|
|
from typing import Callable, List
|
|
|
|
|
|
from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights
|
|
|
|
|
|
|
|
|
@dataclass
class VisionBlock:
    """Weight container for one vision block: two layer norms, attention, MLP.

    The field names are the same names `_load_weights` uses when it indexes
    `model.vision["blocks"][i][...]`. NOTE(review): that loader operates on an
    `nn.Module` by subscript; whether this dataclass is still instantiated
    anywhere is not visible from this file.
    """

    # Loader fills the entry of this name from "<block prefix>.norm1.*".
    ln1: LayerNormWeights
    # Loader fills "qkv" / "proj" from "<block prefix>.attn.qkv.*" / ".attn.proj.*".
    attn: AttentionWeights
    # Loader fills the entry of this name from "<block prefix>.norm2.*".
    ln2: LayerNormWeights
    # Loader fills "fc1" / "fc2" from "<block prefix>.mlp.fc1.*" / ".mlp.fc2.*".
    mlp: MLPWeights
|
|
|
|
|
|
|
|
|
@dataclass
class VisionModel:
    """Weight container for the vision part of the model.

    Field names match the names `_load_weights` looks up on `model.vision`.
    """

    # Loaded from "vision_encoder.encoder.model.visual.patch_embed.linear.*".
    patch_emb: LinearWeights
    # Loaded from "vision_encoder.encoder.model.visual.pos_embed".
    pos_emb: torch.Tensor
    # One entry per "vision_encoder.encoder.model.visual.blocks.<i>" prefix.
    blocks: List[VisionBlock]
    # Loaded from "vision_encoder.encoder.model.visual.norm.*".
    post_ln: LayerNormWeights
    # Loaded from "vision_encoder.projection.mlp.fc1.*" / ".fc2.*".
    proj_mlp: MLPWeights
|
|
|
|
|
|
|
|
|
@dataclass
class TextBlock:
    """Weight container for one text block: a layer norm, attention, MLP.

    Field names match the names `_load_weights` uses when it indexes
    `model.text["blocks"][i][...]`.
    """

    # Loader fills the entry of this name from "<block prefix>.ln.*".
    ln: LayerNormWeights
    # Loader fills "qkv" / "proj" from "<block prefix>.mixer.Wqkv.*" / ".mixer.out_proj.*".
    attn: AttentionWeights
    # Loader fills "fc1" / "fc2" from "<block prefix>.mlp.fc1.*" / ".mlp.fc2.*".
    mlp: MLPWeights
|
|
|
|
|
|
|
|
|
@dataclass
class TextModel:
    """Weight container for the text part of the model.

    Field names match the names `_load_weights` looks up on `model.text`.
    """

    # Loaded from "text_model.transformer.embd.wte.weight".
    wte: torch.Tensor
    # One entry per "text_model.transformer.h.<i>" prefix.
    blocks: List[TextBlock]
    # Loaded from "text_model.lm_head.ln.*".
    post_ln: LayerNormWeights
    # Loaded from "text_model.lm_head.linear.*".
    lm_head: LinearWeights
|
|
|
|
|
|
|
|
|
@dataclass
class RegionModel:
    """Weight container for the region part of the model (coordinates and sizes).

    Field names match the names `_load_weights` looks up on `model.region`.
    """

    # Loader writes the transpose (`.T`) of
    # "region_model.coordinate_features.weight" into the attribute of this name.
    coord_features: torch.Tensor
    # Loaded from "region_model.coordinate_encoder.*".
    coord_encoder: LinearWeights
    # Loaded from "region_model.coordinate_decoder.fc1.*" / ".fc2.*".
    coord_decoder: MLPWeights
    # Loader writes the transpose (`.T`) of "region_model.size_features.weight"
    # into the attribute of this name.
    size_features: torch.Tensor
    # Loaded from "region_model.size_encoder.*".
    size_encoder: LinearWeights
    # Loaded from "region_model.size_decoder.fc1.*" / ".fc2.*".
    size_decoder: MLPWeights
|
|
|
|
|
|
|
|
|
@dataclass
class MoondreamModel:
    """Top-level weight container grouping the vision, text and region parts."""

    # Vision-side weights (patch embedding, blocks, projection MLP).
    vision: VisionModel
    # Text-side weights (token embedding, blocks, LM head).
    text: TextModel
    # Region-side weights (coordinate / size features, encoders, decoders).
    region: RegionModel
|
|
|
|
|
|
|
|
|
@contextmanager
def safetensors_open(safetensors_file: str):
    """
    Context manager that opens a safetensors file and yields a tensor getter.

    The yielded object is a callable ``get_tensor(name) -> torch.Tensor`` that
    reads one tensor from the opened file. The same callable also exposes a
    ``keys()`` attribute returning the tensor names stored in the file. Wrapping
    `safe_open` this way spares callers from silencing type errors on its
    return value.

    Args:
        safetensors_file: Path of the safetensors file to open
            (opened with ``framework="pt"``).

    Yields:
        The ``get_tensor`` callable described above. It is only usable while
        the ``with`` block is active, because the underlying handle is released
        when the block exits.
    """
    with safetensors.safe_open(safetensors_file, framework="pt") as handle:

        def get_tensor(name: str) -> torch.Tensor:
            return handle.get_tensor(name)

        def list_keys() -> List[str]:
            return handle.keys()

        # Hang the key listing off the getter so callers receive one object.
        setattr(get_tensor, "keys", list_keys)

        yield get_tensor
|
|
|
|
|
|
|
|
|
def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
|
|
|
"""Internal function to load weights using a tensor getter function."""
|
|
|
model = model.to(dtype=torch.float16)
|
|
|
|
|
|
|
|
|
model.vision["patch_emb"].weight.data.copy_(
|
|
|
get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.weight")
|
|
|
)
|
|
|
model.vision["patch_emb"].bias.data.copy_(
|
|
|
get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.bias")
|
|
|
)
|
|
|
model.vision.pos_emb.data.copy_(
|
|
|
get_tensor("vision_encoder.encoder.model.visual.pos_embed")
|
|
|
)
|
|
|
|
|
|
for i in range(len(model.vision["blocks"])):
|
|
|
prefix = f"vision_encoder.encoder.model.visual.blocks.{i}"
|
|
|
|
|
|
|
|
|
model.vision["blocks"][i]["ln1"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.norm1.weight")
|
|
|
)
|
|
|
model.vision["blocks"][i]["ln1"].bias.data.copy_(
|
|
|
get_tensor(f"{prefix}.norm1.bias")
|
|
|
)
|
|
|
model.vision["blocks"][i]["ln2"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.norm2.weight")
|
|
|
)
|
|
|
model.vision["blocks"][i]["ln2"].bias.data.copy_(
|
|
|
get_tensor(f"{prefix}.norm2.bias")
|
|
|
)
|
|
|
|
|
|
|
|
|
model.vision["blocks"][i]["attn"]["qkv"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.attn.qkv.weight")
|
|
|
)
|
|
|
model.vision["blocks"][i]["attn"]["qkv"].bias.data.copy_(
|
|
|
get_tensor(f"{prefix}.attn.qkv.bias")
|
|
|
)
|
|
|
model.vision["blocks"][i]["attn"]["proj"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.attn.proj.weight")
|
|
|
)
|
|
|
model.vision["blocks"][i]["attn"]["proj"].bias.data.copy_(
|
|
|
get_tensor(f"{prefix}.attn.proj.bias")
|
|
|
)
|
|
|
|
|
|
|
|
|
model.vision["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.mlp.fc1.weight")
|
|
|
)
|
|
|
model.vision["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
|
|
|
get_tensor(f"{prefix}.mlp.fc1.bias")
|
|
|
)
|
|
|
model.vision["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.mlp.fc2.weight")
|
|
|
)
|
|
|
model.vision["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
|
|
|
get_tensor(f"{prefix}.mlp.fc2.bias")
|
|
|
)
|
|
|
|
|
|
model.vision["post_ln"].weight.data.copy_(
|
|
|
get_tensor("vision_encoder.encoder.model.visual.norm.weight")
|
|
|
)
|
|
|
model.vision["post_ln"].bias.data.copy_(
|
|
|
get_tensor("vision_encoder.encoder.model.visual.norm.bias")
|
|
|
)
|
|
|
|
|
|
model.vision["proj_mlp"]["fc1"].weight.data.copy_(
|
|
|
get_tensor("vision_encoder.projection.mlp.fc1.weight")
|
|
|
)
|
|
|
model.vision["proj_mlp"]["fc1"].bias.data.copy_(
|
|
|
get_tensor("vision_encoder.projection.mlp.fc1.bias")
|
|
|
)
|
|
|
model.vision["proj_mlp"]["fc2"].weight.data.copy_(
|
|
|
get_tensor("vision_encoder.projection.mlp.fc2.weight")
|
|
|
)
|
|
|
model.vision["proj_mlp"]["fc2"].bias.data.copy_(
|
|
|
get_tensor("vision_encoder.projection.mlp.fc2.bias")
|
|
|
)
|
|
|
|
|
|
|
|
|
model.text.wte.data.copy_(get_tensor("text_model.transformer.embd.wte.weight"))
|
|
|
|
|
|
for i in range(len(model.text["blocks"])):
|
|
|
prefix = f"text_model.transformer.h.{i}"
|
|
|
|
|
|
|
|
|
model.text["blocks"][i]["ln"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.ln.weight")
|
|
|
)
|
|
|
model.text["blocks"][i]["ln"].bias.data.copy_(get_tensor(f"{prefix}.ln.bias"))
|
|
|
|
|
|
|
|
|
model.text["blocks"][i]["attn"]["qkv"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.mixer.Wqkv.weight")
|
|
|
)
|
|
|
model.text["blocks"][i]["attn"]["qkv"].bias.data.copy_(
|
|
|
get_tensor(f"{prefix}.mixer.Wqkv.bias")
|
|
|
)
|
|
|
model.text["blocks"][i]["attn"]["proj"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.mixer.out_proj.weight")
|
|
|
)
|
|
|
model.text["blocks"][i]["attn"]["proj"].bias.data.copy_(
|
|
|
get_tensor(f"{prefix}.mixer.out_proj.bias")
|
|
|
)
|
|
|
|
|
|
|
|
|
model.text["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.mlp.fc1.weight")
|
|
|
)
|
|
|
model.text["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
|
|
|
get_tensor(f"{prefix}.mlp.fc1.bias")
|
|
|
)
|
|
|
model.text["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
|
|
|
get_tensor(f"{prefix}.mlp.fc2.weight")
|
|
|
)
|
|
|
model.text["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
|
|
|
get_tensor(f"{prefix}.mlp.fc2.bias")
|
|
|
)
|
|
|
|
|
|
model.text["post_ln"].weight.data.copy_(get_tensor("text_model.lm_head.ln.weight"))
|
|
|
model.text["post_ln"].bias.data.copy_(get_tensor("text_model.lm_head.ln.bias"))
|
|
|
|
|
|
model.text["lm_head"].weight.data.copy_(
|
|
|
get_tensor("text_model.lm_head.linear.weight")
|
|
|
)
|
|
|
model.text["lm_head"].bias.data.copy_(get_tensor("text_model.lm_head.linear.bias"))
|
|
|
|
|
|
|
|
|
model.region.coord_features.data.copy_(
|
|
|
get_tensor("region_model.coordinate_features.weight").T
|
|
|
)
|
|
|
model.region["coord_encoder"].weight.data.copy_(
|
|
|
get_tensor("region_model.coordinate_encoder.weight")
|
|
|
)
|
|
|
model.region["coord_encoder"].bias.data.copy_(
|
|
|
get_tensor("region_model.coordinate_encoder.bias")
|
|
|
)
|
|
|
|
|
|
model.region["coord_decoder"]["fc1"].weight.data.copy_(
|
|
|
get_tensor("region_model.coordinate_decoder.fc1.weight")
|
|
|
)
|
|
|
model.region["coord_decoder"]["fc1"].bias.data.copy_(
|
|
|
get_tensor("region_model.coordinate_decoder.fc1.bias")
|
|
|
)
|
|
|
model.region["coord_decoder"]["fc2"].weight.data.copy_(
|
|
|
get_tensor("region_model.coordinate_decoder.fc2.weight")
|
|
|
)
|
|
|
model.region["coord_decoder"]["fc2"].bias.data.copy_(
|
|
|
get_tensor("region_model.coordinate_decoder.fc2.bias")
|
|
|
)
|
|
|
|
|
|
model.region.size_features.data.copy_(
|
|
|
get_tensor("region_model.size_features.weight").T
|
|
|
)
|
|
|
model.region["size_encoder"].weight.data.copy_(
|
|
|
get_tensor("region_model.size_encoder.weight")
|
|
|
)
|
|
|
model.region["size_encoder"].bias.data.copy_(
|
|
|
get_tensor("region_model.size_encoder.bias")
|
|
|
)
|
|
|
|
|
|
model.region["size_decoder"]["fc1"].weight.data.copy_(
|
|
|
get_tensor("region_model.size_decoder.fc1.weight")
|
|
|
)
|
|
|
model.region["size_decoder"]["fc1"].bias.data.copy_(
|
|
|
get_tensor("region_model.size_decoder.fc1.bias")
|
|
|
)
|
|
|
model.region["size_decoder"]["fc2"].weight.data.copy_(
|
|
|
get_tensor("region_model.size_decoder.fc2.weight")
|
|
|
)
|
|
|
model.region["size_decoder"]["fc2"].bias.data.copy_(
|
|
|
get_tensor("region_model.size_decoder.fc2.bias")
|
|
|
)
|
|
|
|
|
|
|
|
|
def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
    """Load weights from a safetensors file into a MoondreamModel instance.

    Tensor names in the file may contain a ``._orig_mod`` segment
    (presumably left by a compiled/wrapped module — confirm); lookups are done
    with that segment removed. Every tensor is cast to ``torch.float16`` before
    being handed to `_load_weights`.

    Args:
        weights_file: Path to the ``.safetensors`` file.
        model: Module whose parameters are overwritten in place.
    """
    with safetensors_open(weights_file) as get_tensor:
        # Map the cleaned-up name to the name actually stored in the file.
        stored_name = {}
        for key in get_tensor.keys():
            stored_name[key.replace("._orig_mod", "")] = key

        def fetch_half(name: str) -> torch.Tensor:
            return get_tensor(stored_name[name]).to(dtype=torch.float16)

        _load_weights(fetch_half, model)
|
|
|
|
|
|
|
|
|
def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
    """Load weights from a PyTorch file into a MoondreamModel instance.

    The checkpoint is read with ``weights_only=True`` and mapped onto the
    device that a freshly created tensor lands on. ``._orig_mod`` is stripped
    from every key and every tensor is cast to ``torch.float16`` before being
    handed to `_load_weights`.

    Args:
        weights_file: Path to the PyTorch checkpoint file.
        model: Module whose parameters are overwritten in place.
    """
    # Device of a freshly created tensor, i.e. the current default device.
    target_device = str(torch.empty(0).device)
    state = torch.load(weights_file, map_location=target_device, weights_only=True)

    half_state = {}
    for name, tensor in state.items():
        half_state[name.replace("._orig_mod", "")] = tensor.to(dtype=torch.float16)
    # Drop the reference to the raw checkpoint dict; only the cast copy is needed.
    del state

    _load_weights(half_state.__getitem__, model)
|
|
|
|
|
|
|
|
|
def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
    """
    Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance.

    The loader is chosen from the file name: a path ending in ``.safetensors``
    goes through `load_weights_from_safetensors`; any other path is treated as
    a PyTorch checkpoint. After loading, every parameter's storage is made
    contiguous.

    Args:
        weights_file: Path to weights file (either .safetensors or .pt)
        model: MoondreamModel instance to load weights into
    """
    is_safetensors = weights_file.endswith(".safetensors")
    loader = load_weights_from_safetensors if is_safetensors else load_weights_from_pt
    loader(weights_file, model)

    # Ensure every parameter is backed by contiguous memory.
    for parameter in model.parameters():
        parameter.data = parameter.data.contiguous()
|
|
|
|