# Hanrui/SpecForge: specforge/modeling/target/target_utils.py
import gc
import glob
import json
import os
from typing import Optional
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig
class TargetEmbeddingsAndHead(nn.Module):
"""
Efficiently loads only the embedding layer and lm_head from a pretrained model.
    Handles per-tensor reads from sharded safetensors checkpoints and weight tying correctly.
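
    Example (illustrative sketch; the model id and the input tensors below are placeholders,
    not tied to this repo):

        heads = TargetEmbeddingsAndHead.from_pretrained(
            "meta-llama/Llama-3.1-8B-Instruct",
            device="cuda",
            dtype=torch.bfloat16,
        )
        inputs_embeds = heads.embed_tokens(input_ids)   # input_ids: (batch, seq) token ids
        logits = heads.lm_head(hidden_states)           # hidden_states: (batch, seq, hidden)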
"""
def __init__(self, config):
super().__init__()
self.config = config
# Support for MLLMs with separate text_config
if hasattr(config, "text_config"):
self.embed_tokens = nn.Embedding(
config.text_config.vocab_size,
config.text_config.hidden_size,
padding_idx=config.text_config.pad_token_id,
)
self.lm_head = nn.Linear(
config.text_config.hidden_size,
config.text_config.vocab_size,
bias=False,
)
else:
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@classmethod
def from_pretrained(
cls,
model_path: str,
embed_key: Optional[str] = None,
lm_head_key: Optional[str] = None,
cache_dir: Optional[str] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
trust_remote_code: bool = False,
) -> "TargetEmbeddingsAndHead":
# 1. Load Config
config = AutoConfig.from_pretrained(
model_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code
)
instance = cls(config)
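        # Default tensor names below match most HF decoder-only checkpoints
        # (e.g. Llama/Qwen-style models store "model.embed_tokens.weight" and "lm_head.weight").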
if embed_key is None:
embed_key = "model.embed_tokens.weight"
if lm_head_key is None:
lm_head_key = "lm_head.weight"
# 2. Resolve Model Path
local_model_path = model_path
if not os.path.exists(local_model_path):
try:
local_model_path = snapshot_download(
repo_id=model_path,
cache_dir=cache_dir,
allow_patterns=["*.json", "*.safetensors", "*.bin", "*.model"],
)
except Exception as e:
print(f"Warning: Snapshot download failed or path check failed: {e}")
# 3. Handle Weight Tying
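        # When tie_word_embeddings is True, the checkpoint usually stores only the
        # embedding matrix; _load_weights then skips the lm_head tensor and shares
        # the embedding weights with the head instead.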
tie_weights = getattr(config, "tie_word_embeddings", False)
# 4. Load Weights
instance._load_weights(local_model_path, embed_key, lm_head_key, tie_weights)
# 5. Move to Device & Freeze
instance.to(device=device, dtype=dtype)
instance.eval()
instance.requires_grad_(False)
return instance
def _load_weights(
self, model_path: str, embed_key: str, lm_head_key: str, tie_weights: bool
):
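        """Locate the shard(s) containing the requested keys and copy the tensors into this module."""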
index_files = glob.glob(os.path.join(model_path, "*.index.json"))
weight_map = {}
files_to_load = {}
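        # Sharded checkpoints ship a *.index.json whose "weight_map" maps each tensor
        # name to the shard file that contains it, so only the shards holding the
        # embedding / lm_head tensors need to be opened.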
if index_files:
with open(index_files[0], "r") as f:
index = json.load(f)
weight_map = index.get("weight_map", {})
if embed_key in weight_map:
files_to_load[embed_key] = weight_map[embed_key]
else:
raise ValueError(
f"Embedding key '{embed_key}' not found in weight map."
)
if not tie_weights:
if lm_head_key in weight_map:
files_to_load[lm_head_key] = weight_map[lm_head_key]
else:
                    print(
                        f"Warning: '{lm_head_key}' not found in the weight map; "
                        "the checkpoint may tie the LM head to the embeddings without "
                        "setting tie_word_embeddings, leaving the head randomly initialized."
                    )
else:
safetensors = glob.glob(os.path.join(model_path, "*.safetensors"))
bins = glob.glob(os.path.join(model_path, "*.bin"))
target_file = safetensors[0] if safetensors else (bins[0] if bins else None)
if not target_file:
                raise FileNotFoundError(f"No checkpoint (*.safetensors or *.bin) found in {model_path}.")
files_to_load[embed_key] = os.path.basename(target_file)
if not tie_weights:
files_to_load[lm_head_key] = os.path.basename(target_file)
loaded_keys = set()
file_to_keys_map = {}
for key, filename in files_to_load.items():
full_path = os.path.join(model_path, filename)
if full_path not in file_to_keys_map:
file_to_keys_map[full_path] = []
file_to_keys_map[full_path].append(key)
        for file_path, keys in file_to_keys_map.items():
            loaded_keys.update(
                self._load_file_content(file_path, keys, embed_key, lm_head_key)
            )
if tie_weights:
            print(
                "Weight tying detected: sharing the embedding weights with the LM head."
            )
self.lm_head.weight = self.embed_tokens.weight
        if embed_key not in loaded_keys:
            raise RuntimeError(f"Failed to load embeddings ('{embed_key}') from {model_path}.")
        if not tie_weights and lm_head_key not in loaded_keys:
            print(
                f"Warning: LM head weights ('{lm_head_key}') were not loaded and tie_weights is False; "
                "the head remains randomly initialized."
            )
def _load_file_content(
self,
file_path: str,
keys_to_extract: list,
target_embed_key: str,
target_head_key: str,
):
"""Helper to load specific keys from a file"""
print(f"Loading {keys_to_extract} from {os.path.basename(file_path)}...")
state_dict_part = {}
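        # safe_open reads safetensors lazily: get_tensor() materializes only the
        # requested tensors instead of the whole shard.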
if file_path.endswith(".safetensors"):
with safe_open(file_path, framework="pt") as f:
for k in keys_to_extract:
if k in f.keys():
state_dict_part[k] = f.get_tensor(k)
else:
            print(
                f"Warning: loading full .bin checkpoint {os.path.basename(file_path)} into RAM; "
                "convert to safetensors to enable per-tensor reads."
            )
full_state = torch.load(file_path, map_location="cpu")
for k in keys_to_extract:
if k in full_state:
state_dict_part[k] = full_state[k]
del full_state
gc.collect()
        for k, tensor in state_dict_part.items():
            if k == target_embed_key:
                self.embed_tokens.weight.data.copy_(tensor)
                print(" -> Loaded Embeddings")
            elif k == target_head_key:
                if tensor.shape == self.lm_head.weight.data.shape:
                    self.lm_head.weight.data.copy_(tensor)
                    print(" -> Loaded LM Head")
                else:
                    raise RuntimeError(
                        f"Shape mismatch for {k}. Expected {self.lm_head.weight.shape}, got {tensor.shape}"
                    )
        # Report which keys were actually found so the caller can detect silent misses.
        return set(state_dict_part.keys())