`lm_head.weight` issue
Recommendation:

Even though `tie_word_embeddings=True` is set in the config, the weights are not actually tied when the model is loaded via `AutoModel.from_pretrained()`. It would be very helpful to add a brief inline comment explaining the problem and the required fix:

```python
model.language_model.lm_head.weight = model.language_model.model.embed_tokens.weight
```

This would save others significant debugging time and prevent confusion when working with tied embeddings.
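For readers unfamiliar with weight tying: it simply means the LM head and the input embedding share a single parameter tensor. Here is a minimal self-contained sketch with toy layer sizes; the names `embed_tokens`/`lm_head` mirror the model's attribute names, but this is plain PyTorch, not the actual model:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real embed_tokens / lm_head modules (sizes are illustrative).
embed_tokens = nn.Embedding(10, 4)          # weight shape: (10, 4)
lm_head = nn.Linear(4, 10, bias=False)      # weight shape: (10, 4)

# Before tying: two independent parameter tensors with separate storage.
assert lm_head.weight.data_ptr() != embed_tokens.weight.data_ptr()

# The one-line fix: point lm_head at the embedding matrix so both share one storage.
lm_head.weight = embed_tokens.weight
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()

# Updating one now updates the other -- this is what "tied" means.
with torch.no_grad():
    embed_tokens.weight.zero_()
print(bool((lm_head.weight == 0).all()))  # prints True
```

After the re-assignment, the `data_ptr()` comparison used later in this thread returns `True`, which is the quickest way to confirm the tie took effect.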
Thank you!
Hi @nur-dev , thank you for reporting this issue!
I believe this is a problem with the `transformers` version. We have trained and tested the model on `transformers==4.57.1` so far, and on that version you should not have any problems loading the model:
```python
model = AutoModel.from_pretrained(
    "issai/Qolda",
    dtype=torch.bfloat16,
    use_flash_attn=True,  # optional
    trust_remote_code=True,
    device_map="auto",
).eval()
tokenizer = AutoTokenizer.from_pretrained(
    "issai/Qolda", trust_remote_code=True, use_fast=False
)
```
and verifying that the embeddings and LM head are indeed tied:

```python
model.language_model.lm_head.weight.data_ptr() == model.language_model.model.embed_tokens.weight.data_ptr()  # should return True
```
However, it seems that `transformers==5.x` introduced a `post_init()` step in models' `__init__()` method (issue). It explicitly registers the tied-weight mappings, including the `all_tied_weights_keys` object. We don't have that in our remote-code files yet. As it appears to be mandatory for newer `transformers` releases, we will update the scripts later (there are actually a few more compatibility issues to resolve).
For now, the following patches at model-loading time should be sufficient for inference:
```python
import contextlib

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel


@contextlib.contextmanager
def patch_meta_tensor_item():
    """Temporarily allow .item() calls on meta tensors during model initialization."""
    original_item = torch.Tensor.item

    def safe_item(self):
        if self.device.type == "meta":
            return 0.0
        return original_item(self)

    torch.Tensor.item = safe_item
    try:
        yield
    finally:
        torch.Tensor.item = original_item


@contextlib.contextmanager
def patch_missing_post_init():
    """Lazily provide the `all_tied_weights_keys` attribute that transformers 5.x expects."""
    original_getattr = nn.Module.__getattr__

    def patched_getattr(self, name):
        if name == "all_tied_weights_keys" and isinstance(self, PreTrainedModel):
            keys = self.get_expanded_tied_weights_keys(all_submodels=True)
            self.all_tied_weights_keys = keys
            return keys
        return original_getattr(self, name)

    nn.Module.__getattr__ = patched_getattr
    try:
        yield
    finally:
        nn.Module.__getattr__ = original_getattr


def load_model_and_tokenizer(path="issai/Qolda"):
    # Both patches are only active while the model and tokenizer are being loaded.
    with patch_meta_tensor_item(), patch_missing_post_init():
        model = AutoModel.from_pretrained(
            path,
            dtype=torch.bfloat16,
            use_flash_attn=True,  # optional
            trust_remote_code=True,
            device_map="auto",
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            path, trust_remote_code=True, use_fast=False
        )
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer()
```
And verifying the weight tying again:

```python
model.language_model.lm_head.weight.data_ptr() == model.language_model.model.embed_tokens.weight.data_ptr()  # should return True
```
Thank you again for pointing out this issue! If you hit this problem even with the older `transformers` version, please let us know.