import os
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from omegaconf import DictConfig, OmegaConf
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPTextModel,
    CLIPTokenizerFast,
    T5EncoderModel,
    T5TokenizerFast,
)
from transformers.tokenization_utils_base import BatchEncoding

from common.fs import download_and_extract
from common.logger import get_logger

logger = get_logger(__name__)

MODEL_TYPES = {
    "clip": (CLIPTokenizerFast, CLIPTextModel),
    "t5": (T5TokenizerFast, T5EncoderModel),
    "llm14b": (AutoTokenizer, AutoModelForCausalLM),
}


@dataclass
class TextEncoderOutput:
    embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]]
    masks: Union[torch.BoolTensor, List[torch.BoolTensor]]
    pooled: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]]

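# Expected config layout for TextEncoder (keys inferred from the accesses below;
# concrete values are only illustrative):
#
#   models:
#     - type: clip | t5 | llm14b       # selects the (tokenizer, model) pair in MODEL_TYPES
#       path: <checkpoint archive>     # resolved via download_and_extract
#       max_length: <int>              # tokenizer model_max_length
#       mask: true                     # optional; attend only to real tokens (default true)
#       layer: last                    # optional; last | penultimate | penultimate_nonorm
#   output:                            # optional
#     embedding_and_mask: channel_concat | last
#     pooled: channel_concat | first | last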
class TextEncoder(nn.Module):
    def __init__(self, config: DictConfig):
        super().__init__()
        self.config = config
        self.tokenizers = []
        self.models = nn.ModuleList([])

        # Avoid forking the fast-tokenizer thread pool inside dataloader workers.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        for encoder_config in config.models:
            tokenizer_cls, model_cls = MODEL_TYPES[encoder_config.type]
            path = download_and_extract(encoder_config.path)
            max_length = encoder_config.max_length

            if encoder_config.type == "llm14b":
                tokenizer = tokenizer_cls.from_pretrained(
                    path,
                    model_max_length=max_length,
                    use_fast=False,
                    trust_remote_code=True,
                    padding_side="right",
                    truncation_side="right",
                    add_eod_token=True,
                )
                # The LLM tokenizer ships without a pad token; reuse the
                # end-of-text token so padded positions stay in-vocabulary.
                tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
                model = model_cls.from_pretrained(path, trust_remote_code=True, bf16=True)
            else:
                tokenizer = tokenizer_cls.from_pretrained(path, model_max_length=max_length)
                model = model_cls.from_pretrained(path, torch_dtype=torch.bfloat16)
            self.tokenizers.append(tokenizer)
            self.models.append(model)

    def forward(self, text: Union[str, List[str]]) -> TextEncoderOutput:
        embeddings, masks, pooled = [], [], []

        for encoder_config, tokenizer, model in zip(
            self.config.models, self.tokenizers, self.models
        ):
            if encoder_config.type == "llm14b":
                use_mask = encoder_config.get("mask", True)
                tokens = tokenizer(
                    text,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                ).to(model.device)
                token_ids = tokens["input_ids"]
                attention_mask = tokens["attention_mask"]
                # Append an end-of-text token right after the last real token of
                # each sequence (clamped to the final position when the prompt
                # fills the whole window) and mark it as attended.
                num_tokens = attention_mask.sum(dim=1)
                range_ids = torch.arange(len(token_ids), device=token_ids.device, dtype=torch.long)
                token_ids[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = (
                    tokenizer.pad_token_id
                )
                attention_mask[range_ids, num_tokens.clamp(max=token_ids.size(1) - 1)] = 1
                tokens = BatchEncoding({"input_ids": token_ids, "attention_mask": attention_mask})
                # Run only the decoder stack; logits from the LM head are not needed.
                output = model.transformer(
                    input_ids=tokens.input_ids,
                    attention_mask=attention_mask if use_mask else None,
                    output_hidden_states=False,
                    use_cache=False,
                )
                emb = output.last_hidden_state

                embeddings.append(emb)
                masks.append(
                    tokens.attention_mask.bool() if use_mask else tokens.attention_mask > -1
                )
            else:
                tokens = tokenizer(
                    text=text,
                    truncation=True,
                    padding="max_length",
                    return_tensors="pt",
                )

                use_mask = encoder_config.get("mask", True)
                input_ids = tokens.input_ids.to(model.device)
                attention_mask = tokens.attention_mask.to(model.device)
                output = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask if use_mask else None,
                    output_hidden_states=True,
                )

                layer = encoder_config.get("layer", "last")
                if layer == "last":
                    embeddings.append(output.last_hidden_state)
                elif layer == "penultimate":
                    embeddings.append(model.text_model.final_layer_norm(output.hidden_states[-2]))
                elif layer == "penultimate_nonorm":
                    embeddings.append(output.hidden_states[-2])
                else:
                    raise NotImplementedError(f"Unknown layer type: {layer}.")

                # When masking is disabled, emit an all-True mask so downstream
                # consumers still receive a boolean tensor of the right shape.
                masks.append(attention_mask.bool() if use_mask else attention_mask > -1)

                if hasattr(output, "pooler_output"):
                    pooled.append(output.pooler_output)

        output_config = self.config.get("output") or OmegaConf.create()
        embedding_output_type = output_config.get("embedding_and_mask", "undefined")
        pooled_output_type = output_config.get("pooled", "undefined")

        if embedding_output_type == "undefined" and len(self.models) == 1:
            embeddings = embeddings[0]
            masks = masks[0]
        elif embedding_output_type == "channel_concat":
            embeddings = torch.cat(embeddings, dim=-1)
            masks = sum(masks).bool()
        elif embedding_output_type == "last":
            embeddings = embeddings[-1]
            masks = masks[-1]
        else:
            raise NotImplementedError(f"output.embedding_and_mask: {embedding_output_type}")

        if pooled_output_type == "undefined":
            pooled = None
        elif pooled_output_type == "channel_concat":
            pooled = torch.cat(pooled, dim=-1)
        elif pooled_output_type == "first":
            pooled = pooled[0]
        elif pooled_output_type == "last":
            pooled = pooled[-1]
        else:
            raise NotImplementedError(f"output.pooled: {pooled_output_type}")

        return TextEncoderOutput(embeddings, masks, pooled)
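
# Example usage (illustrative sketch; checkpoint paths and max_length values are
# placeholders, and channel_concat requires all encoders to share the same
# sequence length):
#
#   config = OmegaConf.create(
#       {
#           "models": [
#               {"type": "clip", "path": "<clip-checkpoint>", "max_length": 77},
#               {"type": "t5", "path": "<t5-checkpoint>", "max_length": 77},
#           ],
#           "output": {"embedding_and_mask": "channel_concat", "pooled": "first"},
#       }
#   )
#   encoder = TextEncoder(config).cuda()
#   out = encoder(["a painting of a fox in the snow"])
#   # out.embeddings: [B, L, C_clip + C_t5], out.masks: [B, L], out.pooled: [B, C_clip]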