# Authentica/detree/model/text_embedding.py
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, TaskType, PeftModel, get_peft_model
class TextEmbeddingModel(nn.Module):
"""Wrapper around a Hugging Face model with optional LoRA adapters."""
def __init__(
self,
model_name,
output_hidden_states=False,
lora=False,
infer=False,
use_pooling="average",
lora_r=128,
lora_alpha=256,
lora_dropout=0,
adapter_path=None,
):
super(TextEmbeddingModel, self).__init__()
self.model_name = model_name
self.use_pooling = use_pooling
self.lora = lora
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
model_kwargs = {"trust_remote_code": True}
if output_hidden_states:
model_kwargs["output_hidden_states"] = True
self.model = AutoModel.from_pretrained(model_name, **model_kwargs)
if self.lora:
            peft_config = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                inference_mode=infer,
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
            )
self.model = get_peft_model(self.model, peft_config)
if adapter_path is not None:
self.load_adapter(adapter_path, is_trainable=not infer)
else:
self.model.print_trainable_parameters()
elif adapter_path is not None:
self.model = AutoModel.from_pretrained(adapter_path, **model_kwargs)
    def pooling(self, model_output, attention_mask, hidden_states=False):
        """Pool token representations into sentence embeddings.

        With hidden_states=True, model_output is the stacked per-layer states of
        shape (num_layers, batch, seq_len, hidden) and the result is
        (batch, num_layers, hidden); otherwise model_output is the last hidden
        state of shape (batch, seq_len, hidden) and the result is (batch, hidden).
        """
        if hidden_states:
            if self.use_pooling == "average":
                # Zero out padded positions before averaging over the sequence.
                model_output = model_output.masked_fill(~attention_mask[None, ..., None].bool(), 0.0)
                emb = model_output.sum(dim=2) / attention_mask.sum(dim=1)[..., None]
            elif self.use_pooling == "max":
                emb = model_output.masked_fill(~attention_mask[None, ..., None].bool(), float("-inf"))
                emb, _ = emb.max(dim=2)
            elif self.use_pooling == "cls":
                emb = model_output[:, :, 0]
            else:
                raise ValueError("Pooling method not supported")
            # (num_layers, batch, hidden) -> (batch, num_layers, hidden)
            emb = emb.permute(1, 0, 2)
        else:
            if self.use_pooling == "average":
                # Zero out padded positions before averaging over the sequence.
                model_output = model_output.masked_fill(~attention_mask[..., None].bool(), 0.0)
                emb = model_output.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
            elif self.use_pooling == "max":
                emb = model_output.masked_fill(~attention_mask[..., None].bool(), float("-inf"))
                emb, _ = emb.max(dim=1)
            elif self.use_pooling == "cls":
                emb = model_output[:, 0]
            else:
                raise ValueError("Pooling method not supported")
        return emb
    def forward(self, encoded_batch, hidden_states=False, return_all_emb=False):
        if "t5" in self.model_name.lower():
            # Encoder-decoder models need decoder inputs even for feature
            # extraction; a single start token (id 0) per example suffices.
            input_ids = encoded_batch['input_ids']
            decoder_input_ids = torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
model_output = self.model(**encoded_batch,
decoder_input_ids=decoder_input_ids)
else:
model_output = self.model(**encoded_batch)
if isinstance(model_output, tuple):
model_output = model_output[0]
if isinstance(model_output, dict):
if hidden_states:
model_output = model_output["hidden_states"]
model_output = torch.stack(model_output, dim=0)
else:
model_output = model_output["last_hidden_state"]
emb = self.pooling(model_output, encoded_batch['attention_mask'], hidden_states)
        if return_all_emb:
            return emb, model_output
return emb
def save_pretrained(self, save_directory: str, save_tokenizer: bool = True):
os.makedirs(save_directory, exist_ok=True)
        # A PeftModel saves only the adapter weights here; a plain AutoModel
        # saves the full checkpoint. The call is the same in both cases.
        self.model.save_pretrained(save_directory)
if save_tokenizer:
self.tokenizer.save_pretrained(save_directory)
def load_adapter(self, adapter_path: str, is_trainable: bool = False):
if not self.lora or not isinstance(self.model, PeftModel):
raise ValueError("LoRA is not enabled for this model instance.")
        # Re-wrap the underlying transformer with the trained adapter rather than
        # stacking a second adapter on top of the freshly initialised one.
        self.model = PeftModel.from_pretrained(
            self.model.get_base_model(),
            adapter_path,
            is_trainable=is_trainable,
        )
self.model.print_trainable_parameters()
def merge_and_unload(self):
if not isinstance(self.model, PeftModel):
raise ValueError("The current model does not contain a LoRA adapter to merge.")
merged_model = self.model.merge_and_unload()
return merged_model
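# Illustrative LoRA lifecycle for TextEmbeddingModel above (not exercised in this
# file; the checkpoint path is a placeholder, not a path from this repository):
#   model = TextEmbeddingModel(name, lora=True)              # attach fresh adapters
#   ... train ...
#   model.save_pretrained("adapters/")                       # adapter weights only
#   model = TextEmbeddingModel(name, lora=True, infer=True, adapter_path="adapters/")
#   merged = model.merge_and_unload()                        # fold adapters into the base weights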
class ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
    def __init__(self, hidden_size, num_labels):
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.out_proj = nn.Linear(hidden_size, num_labels)
def forward(self, x):
x = self.dense(x)
x = torch.tanh(x)
x = self.out_proj(x)
return x
class TextClassificationModel(nn.Module):
    """Sentence classifier: a LoRA-adapted text encoder with a linear head."""
    def __init__(self, opt, dim=2):
        super(TextClassificationModel, self).__init__()
        self.model = TextEmbeddingModel(
            opt.model_name, lora=True, use_pooling=opt.pooling,
            lora_r=opt.lora_r, lora_alpha=opt.lora_alpha, infer=True,
        )
        self.root_classifier = nn.Linear(opt.embedding_dim, dim)
def forward(self, encoded_batch):
q = self.model(encoded_batch)
        out = self.root_classifier(q)
return out
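# Minimal usage sketch (illustrative only): embed two sentences with average
# pooling. The model name and sentences below are assumptions for demonstration
# and are not taken from this repository's configuration.
if __name__ == "__main__":
    embedder = TextEmbeddingModel(
        "sentence-transformers/all-MiniLM-L6-v2",  # placeholder encoder; any Hub encoder should work
        use_pooling="average",
    )
    batch = embedder.tokenizer(
        ["A short example sentence.", "Another example."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeddings = embedder(batch)
    print(embeddings.shape)  # (batch_size, hidden_size)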