Upload ConlluTokenClassificationPipeline
- encoder.py +7 -15
- model.safetensors +2 -2
encoder.py CHANGED
@@ -2,7 +2,7 @@ import torch
 from torch import nn
 from torch import Tensor, LongTensor
 
-from transformers import AutoTokenizer,
+from transformers import AutoTokenizer, AutoModel
 
 try:
     from peft import LoraConfig, get_peft_model
@@ -28,30 +28,23 @@ class WordTransformerEncoder(nn.Module):
     ):
         super().__init__()
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model =
+        self.model = AutoModel.from_pretrained(model_name)
 
         if use_lora:
             if not PEFT_AVAILABLE:
                 raise ImportError("peft is required for LoRA fine-tuning. Install with `pip install peft`.")
             if lora_target_modules is None:
-
+                lora_target_modules = ["query", "value"]
             lora_config = LoraConfig(
                 r=lora_r,
                 lora_alpha=lora_alpha,
                 target_modules=lora_target_modules,
                 lora_dropout=lora_dropout,
                 bias="none",
-                task_type="
+                task_type="FEATURE_EXTRACTION"
             )
-            print("DEBUG: model class =", type(self.model))
-            for name, module in self.model.named_modules():
-                if "proj" in name:
-                    print("DEBUG: found module", name, "->", module)
             self.model = get_peft_model(self.model, lora_config)
-            print("LoRA enabled
-            for name, param in self.model.named_parameters():
-                if "lora" in name:
-                    print("LoRA param:", name, param.shape)
+            print(f"LoRA enabled: r={lora_r}, alpha={lora_alpha}, target_modules={lora_target_modules}")
 
     def forward(self, words: list[list[str]]) -> Tensor:
         """
@@ -84,8 +77,7 @@ class WordTransformerEncoder(nn.Module):
         ])
 
         # Run model and extract subtokens embeddings from the last layer.
-
-        subtokens_embeddings = outputs.hidden_states[-1]
+        subtokens_embeddings = self.model(**subtokens).last_hidden_state
 
         # Aggreate subtokens embeddings into words embeddings.
         # [batch_size, n_words, embedding_size]
@@ -134,7 +126,7 @@ class WordTransformerEncoder(nn.Module):
 
     def get_embeddings_layer(self):
         """Returns the embeddings model."""
-        return self.model.
+        return self.model.embeddings
 
     def get_transformer_layers(self) -> list[nn.Module]:
         """
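A note on the LoRA hunk above: when lora_target_modules is left unset, the commit now falls back to ["query", "value"], which matches the attention projection module names in BERT-style encoders. For reference, a minimal standalone sketch of the same peft pattern; the base model name and the hyperparameter values are illustrative, not taken from this commit:

from transformers import AutoModel
from peft import LoraConfig, get_peft_model

# Illustrative base model; the encoder takes model_name as a constructor argument.
model = AutoModel.from_pretrained("bert-base-cased")

lora_config = LoraConfig(
    r=8,                                # rank of the low-rank update matrices
    lora_alpha=16,                      # scaling applied to the update
    target_modules=["query", "value"],  # the commit's new default
    lora_dropout=0.1,
    bias="none",
    task_type="FEATURE_EXTRACTION",     # bare encoder, no task head
)
model = get_peft_model(model, lora_config)

# The base weights are frozen; only the injected adapter matrices train.
model.print_trainable_parameters()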
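In the forward hunk, the model call and the embedding extraction are now folded into one line; last_hidden_state is the same tensor as hidden_states[-1], but it does not require calling the model with output_hidden_states=True. The aggregation step mentioned in the comment falls outside this diff; one common way to pool subtoken embeddings into the documented [batch_size, n_words, embedding_size] shape is sketched below as a hypothetical helper (not the commit's code), built on a fast tokenizer's word_ids():

import torch
from transformers import AutoModel, AutoTokenizer

def pool_words(words: list[list[str]], tokenizer, model) -> torch.Tensor:
    # Tokenize pre-split words so each subtoken keeps a pointer to its word.
    enc = tokenizer(words, is_split_into_words=True,
                    padding=True, truncation=True, return_tensors="pt")
    hidden = model(**enc).last_hidden_state  # [batch, n_subtokens, hidden]

    sentences = []
    for i, sent in enumerate(words):
        word_ids = enc.word_ids(batch_index=i)  # subtoken -> word index, None for specials
        pooled = []
        for w in range(len(sent)):
            # Mean-pool every subtoken that belongs to word w.
            idx = [j for j, wid in enumerate(word_ids) if wid == w]
            pooled.append(hidden[i, idx].mean(dim=0))
        sentences.append(torch.stack(pooled))

    # Pad sentences to a common length: [batch_size, n_words, embedding_size].
    return torch.nn.utils.rnn.pad_sequence(sentences, batch_first=True)

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModel.from_pretrained("bert-base-cased")
embeddings = pool_words([["Colorless", "green", "ideas"]], tokenizer, model)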
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:417921753fb771613766e89226a2fcbdcd259ac9a4c9acbfa55cce7ccb4e1222
+size 1134190536
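The weights file itself lives in Git LFS, so the repository only versions this pointer; the oid line is the SHA-256 digest of the real model.safetensors (about 1.1 GB per the size line). A quick way to check a downloaded copy against the updated pointer, assuming the file sits in the current directory:

import hashlib

EXPECTED_OID = "417921753fb771613766e89226a2fcbdcd259ac9a4c9acbfa55cce7ccb4e1222"
EXPECTED_SIZE = 1134190536  # bytes

digest = hashlib.sha256()
size = 0
with open("model.safetensors", "rb") as f:
    # Stream in 1 MiB chunks so the ~1.1 GB file never sits in memory at once.
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)
        size += len(chunk)

assert size == EXPECTED_SIZE, f"size mismatch: {size} != {EXPECTED_SIZE}"
assert digest.hexdigest() == EXPECTED_OID, "sha256 mismatch"
print("model.safetensors matches the LFS pointer")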