# src/utils/embedding_injector.py
from transformers import GPT2LMHeadModel

from utils.emb_comp_utils import get_peft_embedding


def inject_hashed_embedding(base_model: GPT2LMHeadModel, tokenizer) -> GPT2LMHeadModel:
    """
    Replace the base model's input embedding layer with a hashed embedding
    layer built by get_peft_embedding, and return the modified model.
    """
    print("[ℹ] Injecting hashed embedding layer...")

    # Build the hashed embedding layer sized to the model's hidden dimension.
    new_embedding = get_peft_embedding(
        vocab_size=tokenizer.vocab_size,
        hidden_size=base_model.config.hidden_size,
        num_hashes=2,
        num_buckets=8192,  # or 4096 if a smaller table is needed
        device=base_model.device,
    )

    # Swap in the new layer. Note: GPT-2 ties lm_head.weight to the original
    # transformer.wte.weight at init, so the output projection keeps the
    # original (now untied) weights after this assignment.
    base_model.transformer.wte = new_embedding
    return base_model
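

# Minimal usage sketch. This assumes the standard `gpt2` checkpoint is
# available (locally or via the Hugging Face hub); the hashed layer itself
# comes from utils.emb_comp_utils.get_peft_embedding, as imported above.
if __name__ == "__main__":
    import torch
    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model = inject_hashed_embedding(model, tokenizer)

    # Sanity check: a forward pass should still yield logits over the vocab.
    inputs = tokenizer("Hello, world!", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    print(f"[✓] Forward pass OK, logits shape: {tuple(logits.shape)}")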