import os

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, TaskType, PeftModel, get_peft_model

class TextEmbeddingModel(nn.Module):
    """Wrapper around a Hugging Face model with optional LoRA adapters."""

    def __init__(
        self,
        model_name,
        output_hidden_states=False,
        lora=False,
        infer=False,
        use_pooling="average",
        lora_r=128,
        lora_alpha=256,
        lora_dropout=0.0,
        adapter_path=None,
    ):
        super().__init__()
        self.model_name = model_name
        self.use_pooling = use_pooling
        self.lora = lora
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        model_kwargs = {"trust_remote_code": True}
        if output_hidden_states:
            model_kwargs["output_hidden_states"] = True
        self.model = AutoModel.from_pretrained(model_name, **model_kwargs)
        if self.lora:
            peft_config = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                inference_mode=infer,
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
            )
            self.model = get_peft_model(self.model, peft_config)
            if adapter_path is not None:
                self.load_adapter(adapter_path, is_trainable=not infer)
            else:
                self.model.print_trainable_parameters()
        elif adapter_path is not None:
            # Without LoRA, adapter_path is treated as a fully fine-tuned checkpoint.
            self.model = AutoModel.from_pretrained(adapter_path, **model_kwargs)

    def pooling(self, model_output, attention_mask, hidden_states=False):
        if hidden_states:
            # model_output: (num_layers, batch, seq_len, hidden)
            if self.use_pooling == "average":
                # Zero out padded positions, then average over the sequence dim.
                model_output = model_output.masked_fill(~attention_mask[None, ..., None].bool(), 0.0)
                emb = model_output.sum(dim=2) / attention_mask.sum(dim=1)[..., None]
            elif self.use_pooling == "max":
                emb = model_output.masked_fill(~attention_mask[None, ..., None].bool(), float("-inf"))
                emb, _ = emb.max(dim=2)
            elif self.use_pooling == "cls":
                emb = model_output[:, :, 0]
            else:
                raise ValueError("Pooling method not supported")
            # (num_layers, batch, hidden) -> (batch, num_layers, hidden)
            emb = emb.permute(1, 0, 2)
        else:
            # model_output: (batch, seq_len, hidden)
            if self.use_pooling == "average":
                model_output = model_output.masked_fill(~attention_mask[..., None].bool(), 0.0)
                emb = model_output.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
            elif self.use_pooling == "max":
                emb = model_output.masked_fill(~attention_mask[..., None].bool(), float("-inf"))
                emb, _ = emb.max(dim=1)
            elif self.use_pooling == "cls":
                emb = model_output[:, 0]
            else:
                raise ValueError("Pooling method not supported")
        return emb

    def forward(self, encoded_batch, hidden_states=False, return_all_emb=False):
        if "t5" in self.model_name.lower():
            # Encoder-decoder models need decoder input ids; feed a single
            # start token per sequence so the encoder output can be pooled.
            input_ids = encoded_batch["input_ids"]
            decoder_input_ids = torch.zeros(
                (input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device
            )
            model_output = self.model(**encoded_batch, decoder_input_ids=decoder_input_ids)
        else:
            model_output = self.model(**encoded_batch)
        if isinstance(model_output, tuple):
            model_output = model_output[0]
        if isinstance(model_output, dict):
            if hidden_states:
                # Stack per-layer states into (num_layers, batch, seq_len, hidden).
                model_output = torch.stack(model_output["hidden_states"], dim=0)
            else:
                model_output = model_output["last_hidden_state"]
        emb = self.pooling(model_output, encoded_batch["attention_mask"], hidden_states)
        if return_all_emb:
            return emb, model_output
        return emb

    def save_pretrained(self, save_directory: str, save_tokenizer: bool = True):
        os.makedirs(save_directory, exist_ok=True)
        # For a PeftModel this writes only the adapter weights; otherwise the
        # full model is written.
        self.model.save_pretrained(save_directory)
        if save_tokenizer:
            self.tokenizer.save_pretrained(save_directory)

    def load_adapter(self, adapter_path: str, is_trainable: bool = False):
        if not self.lora or not isinstance(self.model, PeftModel):
            raise ValueError("LoRA is not enabled for this model instance.")
        # Re-wrap the underlying transformer with the saved adapter instead of
        # stacking a second adapter on top of the current PeftModel.
        self.model = PeftModel.from_pretrained(
            self.model.get_base_model(),
            adapter_path,
            is_trainable=is_trainable,
        )
        self.model.print_trainable_parameters()

    def merge_and_unload(self):
        if not isinstance(self.model, PeftModel):
            raise ValueError("The current model does not contain a LoRA adapter to merge.")
        merged_model = self.model.merge_and_unload()
        return merged_model
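

# Usage sketch (illustrative; not part of the original module). It shows the
# intended call pattern: tokenize with the bundled tokenizer, then call the
# wrapper to get pooled sentence embeddings. The backbone name and max_length
# below are assumptions chosen for the example.
def _embedding_example():
    model = TextEmbeddingModel("bert-base-uncased", use_pooling="average")  # assumed backbone
    batch = model.tokenizer(
        ["a first sentence", "a second, longer sentence"],
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    with torch.no_grad():
        emb = model(batch)  # shape: (batch_size, hidden_size)
    return emb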


class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, hidden_size, num_labels):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x


class TextClassificationModel(nn.Module):
    def __init__(self, opt, dim=2):
        super().__init__()
        self.model = TextEmbeddingModel(
            opt.model_name,
            lora=True,
            use_pooling=opt.pooling,
            lora_r=opt.lora_r,
            lora_alpha=opt.lora_alpha,
            infer=True,
        )
        self.root_classifier = nn.Linear(opt.embedding_dim, dim)

    def forward(self, encoded_batch):
        q = self.model(encoded_batch)
        out = self.root_classifier(q)
        return out
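

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative; not part of the original module).
    # The option values are assumptions: embedding_dim must match the hidden
    # size of the chosen backbone (768 for bert-base-uncased).
    from types import SimpleNamespace

    opt = SimpleNamespace(
        model_name="bert-base-uncased",
        pooling="average",
        lora_r=128,
        lora_alpha=256,
        embedding_dim=768,
    )
    clf = TextClassificationModel(opt)
    batch = clf.model.tokenizer(
        ["an example sentence"], padding=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        logits = clf(batch)
    print(logits.shape)  # expected: torch.Size([1, 2])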