import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, TaskType, PeftModel, get_peft_model
class TextEmbeddingModel(nn.Module):
"""Wrapper around a Hugging Face model with optional LoRA adapters."""
def __init__(
self,
model_name,
output_hidden_states=False,
lora=False,
infer=False,
use_pooling="average",
lora_r=128,
lora_alpha=256,
lora_dropout=0,
adapter_path=None,
):
super(TextEmbeddingModel, self).__init__()
self.model_name = model_name
self.use_pooling = use_pooling
self.lora = lora
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
model_kwargs = {"trust_remote_code": True}
if output_hidden_states:
model_kwargs["output_hidden_states"] = True
self.model = AutoModel.from_pretrained(model_name, **model_kwargs)
if self.lora:
peft_config = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
inference_mode=infer,
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.model = get_peft_model(self.model, peft_config)
if adapter_path is not None:
self.load_adapter(adapter_path, is_trainable=not infer)
else:
self.model.print_trainable_parameters()
elif adapter_path is not None:
self.model = AutoModel.from_pretrained(adapter_path, **model_kwargs)
    def pooling(self, model_output, attention_mask, hidden_states=False):
        """Pool token-level outputs into sentence embeddings.

        With ``hidden_states=True``, ``model_output`` is a stack of per-layer
        outputs of shape (num_layers, batch, seq_len, hidden) and pooling is
        applied to every layer; otherwise it is the last hidden state of shape
        (batch, seq_len, hidden).
        """
        if hidden_states:
            if self.use_pooling == "average":
                # Zero out padded positions so they do not contribute to the mean.
                model_output = model_output.masked_fill(~attention_mask[None, ..., None].bool(), 0.0)
                emb = model_output.sum(dim=2) / attention_mask.sum(dim=1)[..., None]
            elif self.use_pooling == "max":
                emb = model_output.masked_fill(~attention_mask[None, ..., None].bool(), float("-inf"))
                emb, _ = emb.max(dim=2)
            elif self.use_pooling == "cls":
                emb = model_output[:, :, 0]
            else:
                raise ValueError("Pooling method not supported")
            # (num_layers, batch, hidden) -> (batch, num_layers, hidden)
            emb = emb.permute(1, 0, 2)
        else:
            if self.use_pooling == "average":
                # Zero out padded positions so they do not contribute to the mean.
                model_output = model_output.masked_fill(~attention_mask[..., None].bool(), 0.0)
                emb = model_output.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
            elif self.use_pooling == "max":
                emb = model_output.masked_fill(~attention_mask[..., None].bool(), float("-inf"))
                emb, _ = emb.max(dim=1)
            elif self.use_pooling == "cls":
                emb = model_output[:, 0]
            else:
                raise ValueError("Pooling method not supported")
        return emb
    def forward(self, encoded_batch, hidden_states=False, return_all_emb=False):
        if "t5" in self.model_name.lower():
            # T5 is an encoder-decoder model: feed a single start token to the
            # decoder so that the encoder-side representations can be pooled.
            input_ids = encoded_batch["input_ids"]
            decoder_input_ids = torch.zeros(
                (input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device
            )
            model_output = self.model(**encoded_batch, decoder_input_ids=decoder_input_ids)
else:
model_output = self.model(**encoded_batch)
if isinstance(model_output, tuple):
model_output = model_output[0]
if isinstance(model_output, dict):
if hidden_states:
model_output = model_output["hidden_states"]
model_output = torch.stack(model_output, dim=0)
else:
model_output = model_output["last_hidden_state"]
emb = self.pooling(model_output, encoded_batch['attention_mask'], hidden_states)
        if return_all_emb:
return emb, model_output
return emb
def save_pretrained(self, save_directory: str, save_tokenizer: bool = True):
os.makedirs(save_directory, exist_ok=True)
        # PeftModel.save_pretrained writes only the adapter weights; a plain
        # model writes its full weights.
        self.model.save_pretrained(save_directory)
if save_tokenizer:
self.tokenizer.save_pretrained(save_directory)
def load_adapter(self, adapter_path: str, is_trainable: bool = False):
if not self.lora or not isinstance(self.model, PeftModel):
raise ValueError("LoRA is not enabled for this model instance.")
        # Unwrap to the raw base model before attaching the saved adapter;
        # self.model.base_model would hand over the LoRA wrapper itself.
        self.model = PeftModel.from_pretrained(
            self.model.get_base_model(),
            adapter_path,
            is_trainable=is_trainable,
        )
self.model.print_trainable_parameters()
def merge_and_unload(self):
if not isinstance(self.model, PeftModel):
raise ValueError("The current model does not contain a LoRA adapter to merge.")
merged_model = self.model.merge_and_unload()
return merged_model
class ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
    def __init__(self, hidden_size, num_labels):
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.out_proj = nn.Linear(hidden_size, num_labels)
def forward(self, x):
x = self.dense(x)
x = torch.tanh(x)
x = self.out_proj(x)
return x
class TextClassificationModel(nn.Module):
    """Sentence classifier: a LoRA-wrapped embedding model plus a linear head."""
    def __init__(self, opt, dim=2):
        super(TextClassificationModel, self).__init__()
        self.model = TextEmbeddingModel(
            opt.model_name,
            lora=True,
            use_pooling=opt.pooling,
            lora_r=opt.lora_r,
            lora_alpha=opt.lora_alpha,
            infer=True,
        )
        self.root_classifier = nn.Linear(opt.embedding_dim, dim)
def forward(self, encoded_batch):
q = self.model(encoded_batch)
        out = self.root_classifier(q)
return out
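
# Minimal usage sketch (illustrative only): the checkpoint name and output
# directory below are assumptions, not values pinned by this file, and
# downloading the model requires network access.
if __name__ == "__main__":
    model = TextEmbeddingModel("sentence-transformers/all-MiniLM-L6-v2")
    batch = model.tokenizer(
        ["a sample sentence", "another sample sentence"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        emb = model(batch)  # average-pooled sentence embeddings
    print(emb.shape)  # (batch_size, hidden_size)
    model.save_pretrained("./embedding_model")  # also saves the tokenizer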