File size: 5,425 Bytes
ff83b1d ce60410 ff83b1d ce60410 ff83b1d ce60410 ff83b1d ce60410 ff83b1d ce60410 ff83b1d ce60410 ff83b1d ce60410 ff83b1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import torch
from torch import nn
from torch import Tensor, LongTensor
from transformers import AutoTokenizer, AutoModel
try:
from peft import LoraConfig, get_peft_model
PEFT_AVAILABLE = True
except ImportError:
PEFT_AVAILABLE = False
from typing import Optional, List
class WordTransformerEncoder(nn.Module):
    """
    Encodes sentences into word-level embeddings using a pretrained MLM transformer.

    Each sentence arrives pre-split into words. The tokenizer may split a word into
    several subtokens; the subtoken embeddings from the transformer's last layer are
    mean-pooled back into exactly one embedding per word.

    Optionally enables LoRA fine-tuning adapters (requires the `peft` package).
    """
    def __init__(
        self,
        model_name: str,
        use_lora: bool = False,
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        lora_target_modules: Optional[List[str]] = None
    ):
        """
        Args:
            model_name: Hugging Face model id or local path, passed to
                `AutoTokenizer.from_pretrained` / `AutoModel.from_pretrained`.
            use_lora: If True, wrap the transformer with LoRA adapters.
            lora_r: LoRA rank.
            lora_alpha: LoRA scaling factor.
            lora_dropout: Dropout probability inside the LoRA layers.
            lora_target_modules: Module names to inject adapters into. Defaults to
                ["query", "value"] (BERT-style attention projection names);
                other architectures may need an explicit list.

        Raises:
            ImportError: If `use_lora` is True but `peft` is not installed.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        if use_lora:
            if not PEFT_AVAILABLE:
                raise ImportError("peft is required for LoRA fine-tuning. Install with `pip install peft`.")
            if lora_target_modules is None:
                # BERT-style attention projection names; presumably the intended
                # backbone family — other architectures (e.g. GPT-2's "c_attn")
                # need an explicit list.
                lora_target_modules = ["query", "value"]
            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                task_type="FEATURE_EXTRACTION"
            )
            self.model = get_peft_model(self.model, lora_config)
            print(f"LoRA enabled: r={lora_r}, alpha={lora_alpha}, target_modules={lora_target_modules}")

    def forward(self, words: list[list[str]]) -> Tensor:
        """
        Build word embeddings for a batch of pre-split sentences.

        - Tokenizes input sentences into subtokens.
        - Passes the subtokens through the pre-trained transformer model.
        - Aggregates subtoken embeddings into word embeddings using mean pooling.

        Args:
            words: Batch of sentences, each a list of words.

        Returns:
            Tensor of shape [batch_size, n_words, embedding_size], where n_words is
            the largest word count in the batch; shorter sentences are zero-padded
            at the tail (their unused word slots keep the zero initialization from
            the aggregation step).
        """
        batch_size = len(words)
        # BPE tokenization: split words into subtokens, e.g. ['kidding'] -> ['▁ki', 'dding'].
        # NOTE(review): `subtokens.word_ids(...)` below requires a *fast* tokenizer;
        # a slow tokenizer would fail there — confirm the chosen model ships one.
        subtokens = self.tokenizer(
            words,
            padding=True,
            truncation=True,
            is_split_into_words=True,
            return_tensors='pt'
        )
        subtokens = subtokens.to(self.model.device)
        # Index words from 1 and reserve 0 for special subtokens (e.g. <s>, </s>, padding, etc.).
        # Such numeration makes the following aggregation easier.
        words_ids = torch.stack([
            torch.tensor(
                [word_id + 1 if word_id is not None else 0 for word_id in subtokens.word_ids(batch_idx)],
                dtype=torch.long,
                device=self.model.device
            )
            for batch_idx in range(batch_size)
        ])
        # Run model and extract subtoken embeddings from the last layer.
        subtokens_embeddings = self.model(**subtokens).last_hidden_state
        # Aggregate subtoken embeddings into word embeddings.
        # [batch_size, n_words, embedding_size]
        words_embeddings = self._aggregate_subtokens_embeddings(subtokens_embeddings, words_ids)
        return words_embeddings

    def _aggregate_subtokens_embeddings(
        self,
        subtokens_embeddings: Tensor,  # [batch_size, n_subtokens, embedding_size]
        words_ids: LongTensor          # [batch_size, n_subtokens]
    ) -> Tensor:
        """
        Aggregate subtoken embeddings into word embeddings by averaging.

        This method ensures that multiple subtokens corresponding to a single word
        are combined into a single embedding.

        Args:
            subtokens_embeddings: Last-layer transformer outputs,
                [batch_size, n_subtokens, embedding_size].
            words_ids: Per-subtoken word index, [batch_size, n_subtokens];
                0 marks special/padding subtokens, real words start at 1.

        Returns:
            Tensor of shape [batch_size, max_word_id, embedding_size]. Word slots
            that received no subtokens (shorter sentences) stay zero.
        """
        batch_size, n_subtokens, embedding_size = subtokens_embeddings.shape
        # The number of words in a sentence plus an "auxiliary" word in the beginning.
        # Convert the 0-d tensor to a plain int so it is unambiguously a size.
        n_words = int(words_ids.max().item()) + 1
        words_embeddings = torch.zeros(
            size=(batch_size, n_words, embedding_size),
            dtype=subtokens_embeddings.dtype,
            device=self.model.device
        )
        words_ids_expanded = words_ids.unsqueeze(-1).expand(batch_size, n_subtokens, embedding_size)
        # Use scatter_reduce_ to average embeddings of subtokens corresponding to the same word.
        # All the padding and special subtokens will be aggregated into an "auxiliary" first embedding,
        # namely into words_embeddings[:, 0, :]. include_self=False keeps the zero
        # initialization out of the mean.
        words_embeddings.scatter_reduce_(
            dim=1,
            index=words_ids_expanded,
            src=subtokens_embeddings,
            reduce="mean",
            include_self=False
        )
        # Now remove the auxiliary word in the beginning.
        words_embeddings = words_embeddings[:, 1:, :]
        return words_embeddings

    def get_embedding_size(self) -> int:
        """Returns the embedding size of the transformer model, e.g. 768 for BERT."""
        return self.model.config.hidden_size

    def get_embeddings_layer(self):
        """
        Returns the model's embeddings module.

        NOTE(review): relies on the backbone exposing an `.embeddings` attribute
        (true for BERT-style models); a LoRA-wrapped model reaches it only via
        peft's attribute forwarding — confirm for the architectures in use.
        """
        return self.model.embeddings

    def get_transformer_layers(self) -> list[nn.Module]:
        """
        Return a flat list of all transformer-*block* layers, excluding embeddings/poolers, etc.
        """
        layers = []
        for sub in self.model.modules():
            # Collect all ModuleLists — presumably only the encoder block stack is
            # stored in a ModuleList for the supported backbones; verify for new ones.
            if isinstance(sub, nn.ModuleList):
                layers.extend(list(sub))
        return layers