lm_head.weight issue

#1
by nur-dev - opened

Recommendation:

Even though tie_word_embeddings=True is set in the config, the weights are not actually tied when the model is loaded via AutoModel.from_pretrained().

It would be very helpful to add a brief inline comment explaining the problem and the required fix:

model.language_model.lm_head.weight = model.language_model.model.embed_tokens.weight

to save others significant debugging time and prevent confusion when working with tied embeddings. Thank you!
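For context, the fix can be reproduced on toy modules to confirm that the assignment aliases the embedding matrix rather than copying it (a minimal sketch; the module names and sizes are illustrative, not taken from the actual model config):

```python
import torch.nn as nn

# Toy stand-ins for embed_tokens / lm_head (sizes are illustrative).
vocab_size, hidden = 8, 4
embed_tokens = nn.Embedding(vocab_size, hidden)
lm_head = nn.Linear(hidden, vocab_size, bias=False)

# Before tying: two independent parameter tensors.
assert lm_head.weight.data_ptr() != embed_tokens.weight.data_ptr()

# The fix: point lm_head at the embedding matrix (same storage, no copy).
lm_head.weight = embed_tokens.weight

# After tying: both names refer to the same tensor.
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()
```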

Institute of Smart Systems and Artificial Intelligence, Nazarbayev University org

Hi @nur-dev , thank you for reporting this issue!

I believe this is a problem with the transformers version. We trained and tested the model with transformers==4.57.1, and with that version you should not have any problems loading the model:

model = AutoModel.from_pretrained(
    "issai/Qolda",
    dtype=torch.bfloat16,
    use_flash_attn=True, # optional
    trust_remote_code=True,
    device_map="auto",
).eval()

tokenizer = AutoTokenizer.from_pretrained("issai/Qolda", trust_remote_code=True, use_fast=False)

and verifying that the embeddings and lm head are indeed tied:

model.language_model.lm_head.weight.data_ptr() == model.language_model.model.embed_tokens.weight.data_ptr() # should return True

However, it seems that transformers==5.x introduced a post_init() step in models' __init__() method (issue), which explicitly registers the tied-weight mappings, including the all_tied_weights_keys attribute.
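To illustrate the mechanism, here is a simplified stand-in for what the 5.x-style post_init() tying step does, written with plain torch modules (the class, its _tied_weights_keys attribute, and the post_init body are a hedged sketch of the transformers convention, not the real implementation):

```python
import torch.nn as nn

class TinyLM(nn.Module):
    # transformers convention: parameter keys that should alias the
    # input embedding matrix (sketch only, not the real mechanism).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, vocab_size=8, hidden=4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        self.post_init()  # transformers==5.x runs this step automatically

    def post_init(self):
        # Simplified tying step: alias the declared key to the embedding
        # weight instead of leaving an independent tensor.
        if "lm_head.weight" in self._tied_weights_keys:
            self.lm_head.weight = self.embed_tokens.weight

model = TinyLM()
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
```

Because our remote-code files do not register these mappings yet, the patches below reproduce the equivalent bookkeeping at load time.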

We don't have that in our remote_code files yet. Since it appears to be mandatory for newer transformers versions, we will update the scripts later (there are a few more compatibility issues as well).

For now, the following patches applied at model-loading time should be sufficient for inference:

import contextlib

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel

@contextlib.contextmanager
def patch_meta_tensor_item():
    # Calling .item() on a meta tensor raises; return a dummy 0.0 instead
    # so that loading with device_map="auto" does not fail.
    original_item = torch.Tensor.item

    def safe_item(self):
        if self.device.type == "meta":
            return 0.0
        return original_item(self)

    torch.Tensor.item = safe_item
    try:
        yield
    finally:
        torch.Tensor.item = original_item

@contextlib.contextmanager
def patch_missing_post_init():
    # Remote-code models that never call post_init() lack the
    # all_tied_weights_keys attribute that transformers==5.x expects;
    # compute and cache it lazily on first access.
    original_getattr = nn.Module.__getattr__

    def patched_getattr(self, name):
        if name == "all_tied_weights_keys" and isinstance(self, PreTrainedModel):
            keys = self.get_expanded_tied_weights_keys(all_submodels=True)
            self.all_tied_weights_keys = keys
            return keys
        return original_getattr(self, name)

    nn.Module.__getattr__ = patched_getattr
    try:
        yield
    finally:
        nn.Module.__getattr__ = original_getattr


def load_model_and_tokenizer(path="issai/Qolda"):
    with patch_meta_tensor_item(), patch_missing_post_init():
        model = AutoModel.from_pretrained(
            path,
            dtype=torch.bfloat16,
            use_flash_attn=True,
            trust_remote_code=True,
            device_map="auto",
        ).eval()

    tokenizer = AutoTokenizer.from_pretrained(
        path, trust_remote_code=True, use_fast=False
    )

    return model, tokenizer

model, tokenizer = load_model_and_tokenizer()

And verifying the weight tying again:

model.language_model.lm_head.weight.data_ptr() == model.language_model.model.embed_tokens.weight.data_ptr() # should return True

Thank you again for pointing out this issue! If you had this problem even with the older version of transformers, let us know.
