# src/utils/embedding_injector.py

from transformers import GPT2LMHeadModel

from utils.emb_comp_utils import get_peft_embedding

def inject_hashed_embedding(base_model: GPT2LMHeadModel, tokenizer) -> GPT2LMHeadModel:
    """
    Replaces the base model's input embedding layer with a hashed embedding layer,
    using the get_peft_embedding function. Returns the modified model.
    """
    print("[ℹ] Injecting hashed embedding layer...")

    # Build the hashed embedding layer. len(tokenizer) is used rather than
    # tokenizer.vocab_size so that any special tokens added after loading the
    # tokenizer are also covered.
    new_embedding = get_peft_embedding(
        vocab_size=len(tokenizer),
        hidden_size=base_model.config.hidden_size,
        num_hashes=2,
        num_buckets=8192,  # or 4096 for a smaller, more compressed table
        device=base_model.device,
    )

    # Swap in the new input embedding. set_input_embeddings() is the
    # supported transformers API; for GPT-2 it is equivalent to assigning
    # base_model.transformer.wte directly. Note that GPT-2 normally ties
    # lm_head.weight to wte.weight, so the output projection keeps using the
    # original embedding matrix after this swap (a hashed table cannot be
    # tied to the output head).
    base_model.set_input_embeddings(new_embedding)

    return base_model
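

# ---------------------------------------------------------------------------
# For reference: a minimal sketch of what the hashed embedding built by
# get_peft_embedding might look like. utils/emb_comp_utils.py is not shown
# here, so HashedEmbeddingSketch below is a hypothetical illustration of the
# hashing trick, not the project's actual implementation: each token id is
# mapped by num_hashes independent hash functions into a table with
# num_buckets rows, and the selected rows are summed. The table size is
# decoupled from the vocabulary size, which is where the compression comes
# from (e.g. 8192 buckets instead of 50257 vocabulary rows).
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class HashedEmbeddingSketch(nn.Module):
    """Hypothetical hashing-trick embedding; for illustration only."""

    def __init__(self, hidden_size: int, num_hashes: int, num_buckets: int):
        super().__init__()
        self.num_buckets = num_buckets
        self.table = nn.Embedding(num_buckets, hidden_size)
        # Fixed random (a, b) pairs: one affine hash function per row.
        self.register_buffer(
            "hash_params", torch.randint(1, 2**31 - 1, (num_hashes, 2))
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        a, b = self.hash_params[:, 0], self.hash_params[:, 1]
        # (batch, seq, num_hashes) bucket indices via affine hashing.
        buckets = (input_ids.unsqueeze(-1) * a + b) % self.num_buckets
        # Sum the num_hashes looked-up rows into one (batch, seq, hidden).
        return self.table(buckets).sum(dim=-2)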
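

if __name__ == "__main__":
    # Quick smoke test (illustrative; assumes a standard GPT-2 checkpoint and
    # that get_peft_embedding returns a drop-in nn.Module mapping token ids
    # to hidden states): inject the hashed embedding and run a forward pass.
    from transformers import GPT2TokenizerFast

    tok = GPT2TokenizerFast.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model = inject_hashed_embedding(model, tok)

    batch = tok("Hashed embeddings shrink the input table.", return_tensors="pt")
    logits = model(**batch).logits
    print(logits.shape)  # torch.Size([1, seq_len, 50257])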