| import torch | |
| from transformers import ( | |
| AutoTokenizer, AutoModel, | |
| BertTokenizer, BertModel, | |
| CLIPTokenizer, CLIPTextModel | |
| ) | |
| import torch.nn as nn | |
| import pytorch_lightning as pl | |
| from typing import List | |
| from abc import ABC, abstractmethod | |
| import random | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
def taglist_index_to_sentence(taglist_vocab, tag_vocab, taglist_indices, subsample: bool = True):
    """
    Convert a tensor or list of taglist indices to a list of tag sentences.

    Optionally, randomly shuffle and sample a subset of tags for each sentence
    (a lightweight text augmentation for training).

    Args:
        taglist_vocab: List of tuples of tag IDs.
        tag_vocab: Dictionary mapping tag ID to tag string.
        taglist_indices: Tensor or list of indices into taglist_vocab.
        subsample: If True, randomly subsample and shuffle the tags in each
            sentence; uses the module-level ``random`` state (no local seed).

    Returns:
        tag_sentences: List of strings (tag sentences), one per input index.
    """
    if isinstance(taglist_indices, torch.Tensor):
        # Flatten so nested/batched index tensors are handled uniformly.
        taglist_indices = taglist_indices.view(-1).tolist()
    tag_sentences = []
    for idx in taglist_indices:
        tag_ids = taglist_vocab[idx]
        # Normalize tags: lowercase and turn 'key=value' into 'key value'.
        tags = [tag_vocab[tid].lower().replace('=', ' ') for tid in tag_ids]
        if subsample and len(tags) > 1:
            n_sample = random.randint(1, len(tags))  # Choose how many tags to keep
            tags = random.sample(tags, n_sample)     # Sample without replacement
            random.shuffle(tags)                     # Randomize order
        tag_sentences.append(' '.join(tags))
    return tag_sentences
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool token embeddings over the valid (non-padding) positions.

    Padded positions are zeroed out via the attention mask, then the
    per-sequence sum is divided by the number of valid tokens.
    """
    valid = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~valid, 0.0).sum(dim=1)
    counts = attention_mask.sum(dim=1, keepdim=True)
    return summed / counts
class BaseTextEncoder(nn.Module, ABC):
    """Abstract base class for sentence encoders backed by a pretrained model.

    Subclasses populate ``tokenizer``, ``model``, and ``embedding_dim`` in
    ``__init__`` and implement ``encode``.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.tokenizer = None       # set by subclasses
        self.model = None           # set by subclasses
        self.embedding_dim = None   # set by subclasses

    # BUG FIX: was missing @abstractmethod, so a subclass that forgot to
    # implement encode() would silently return None instead of failing fast.
    @abstractmethod
    def encode(self, sentences: List[str], device: str = 'cpu') -> torch.Tensor:
        """
        Encode a list of sentences into a tensor of embeddings.
        Must be implemented by subclasses.
        """
        ...
class BertTextEncoder(BaseTextEncoder):
    """Sentence encoder that uses BERT's pooler ([CLS]-based) output."""

    def __init__(self, model_name='bert-base-uncased'):
        super().__init__(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.embedding_dim = self.model.config.hidden_size

    def encode(self, sentences, device='cpu'):
        """Return the pooler output for each sentence as a [B, D] tensor."""
        self.model.to(device)
        batch = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = self.model(**batch)
        return outputs.pooler_output
class CLIPTextEncoder(BaseTextEncoder):
    """Sentence encoder using the CLIP text tower; embeds each sentence at its EOS token."""

    def __init__(self, model_name='openai/clip-vit-large-patch14', local_tokenizer_path=None):
        """
        Args:
            model_name: HuggingFace model id for the CLIP text model.
            local_tokenizer_path: Optional local snapshot directory; when given,
                both tokenizer and model are loaded from it instead of the hub.
        """
        super().__init__(model_name)
        # BUG FIX: the parameter was previously clobbered by a hard-coded,
        # user-specific cache path, which made the argument and the else-branch
        # dead code. The caller-supplied path is now respected.
        if local_tokenizer_path is not None:
            self.tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
            self.model = CLIPTextModel.from_pretrained(local_tokenizer_path)
        else:
            self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
            # NOTE(review): the (previously unreachable) hub branch forced
            # from_flax=True; standard torch weights exist for this model, so
            # the flag was dropped — confirm no flax-only checkpoint is needed.
            self.model = CLIPTextModel.from_pretrained(model_name)
        self.embedding_dim = self.model.config.hidden_size

    def encode(self, sentences, device='cpu'):
        """Return, per sentence, the hidden state at the final EOS token (CLIP convention)."""
        self.model.to(device)
        inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        input_ids = inputs['input_ids']
        eos_token_id = self.tokenizer.eos_token_id
        pad_token_id = self.tokenizer.pad_token_id
        outputs = self.model(**inputs)
        last_hidden = outputs.last_hidden_state  # [B, T, D]
        embeddings = []
        for i in range(input_ids.size(0)):
            input_seq = input_ids[i]
            eos_positions = (input_seq == eos_token_id).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                eos_idx = eos_positions[-1]  # take last EOS (safe for duplicates)
            else:
                eos_idx = (input_seq != pad_token_id).sum() - 1  # fallback to last non-padding token
            embeddings.append(last_hidden[i, eos_idx, :])
        return torch.stack(embeddings)
class E5TextEncoder(BaseTextEncoder):
    """Sentence encoder using E5 with the 'query:' prefix and mask-aware mean pooling."""

    def __init__(self, model_name='intfloat/e5-base'):
        super().__init__(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.pooler = None  # pooler output is never used; drop it
        self.embedding_dim = self.model.config.hidden_size

    def encode(self, sentences, device='cpu'):
        """Mean-pool the last hidden states over non-padding tokens, [B, D]."""
        self.model.to(device)
        prefixed = [f"query: {s}" for s in sentences]  # official prompt for e5 (for features as per documentation)
        batch = self.tokenizer(prefixed, return_tensors='pt', padding=True, truncation=True).to(device)
        hidden = self.model(**batch).last_hidden_state
        return average_pool(hidden, batch['attention_mask'])
class GritLMTextEncoder(BaseTextEncoder):
    """Sentence encoder using nomic-bert with mean pooling and a linear projection to 768-d."""

    def __init__(self, model_name='nomic-ai/nomic-bert-base-punc'):
        super().__init__(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.embedding_dim = self.model.config.hidden_size
        self.proj_head = nn.Linear(self.embedding_dim, 768)  # to match other encoders

    def encode(self, sentences, device='cpu'):
        """Mean-pool last hidden states, then project to the shared 768-d space."""
        self.model.to(device)
        # BUG FIX: proj_head must also live on the target device; previously
        # only self.model was moved, so encode() failed whenever `device`
        # differed from wherever the module was initialized.
        self.proj_head.to(device)
        inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = self.model(**inputs)
        pooled = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
        return self.proj_head(pooled)
class TextEncoder(pl.LightningModule):
    """LightningModule that maps taglist index tensors to text embeddings.

    Selects one of several pretrained backbones by name and converts taglist
    indices into tag sentences (with random tag subsampling) before encoding.
    """

    def __init__(self, taglist_vocab: List[tuple], tag_vocab: dict, model_name='bert'):
        """
        Args:
            taglist_vocab: List of tuples of tag IDs (index -> tag-ID tuple).
            tag_vocab: Dictionary mapping tag ID to tag string.
            model_name: Backbone key: 'bert', 'clip', 'e5', or 'gritlm' (case-insensitive).

        Raises:
            ValueError: If model_name is not one of the supported backbones.
        """
        super().__init__()
        self.taglist_vocab = taglist_vocab
        self.tag_vocab = tag_vocab
        model_name = model_name.lower()
        encoder_map = {
            'bert': lambda: BertTextEncoder('bert-base-uncased'),
            'clip': lambda: CLIPTextEncoder('openai/clip-vit-large-patch14'),
            'e5': lambda: E5TextEncoder('intfloat/e5-base'),
            'gritlm': lambda: GritLMTextEncoder('nomic-ai/nomic-bert-base-punc')
        }
        if model_name not in encoder_map:
            raise ValueError(f"Unsupported model_name: {model_name}. Choose from {list(encoder_map.keys())}")
        print(f"Text backbone: {model_name}")
        self.encoder = encoder_map[model_name]()  # Instantiate the selected encoder
        # Expose the backbone's embedding size so downstream heads can size
        # themselves (restores the previously commented-out attribute; all
        # backbones set embedding_dim in their __init__).
        self.embedding_dim = self.encoder.embedding_dim

    def forward(self, taglist_tensor: torch.Tensor) -> torch.Tensor:
        """Encode a batch of taglist indices into embeddings (tags randomly subsampled)."""
        tag_indices = taglist_tensor.tolist()
        tag_sentences = taglist_index_to_sentence(self.taglist_vocab, self.tag_vocab, tag_indices, subsample=True)  # randomize subsampling tags
        embeddings = self.encoder.encode(tag_sentences, device=self.device)
        return embeddings

    def encode_raw_text(self, raw_text: str) -> torch.Tensor:
        """
        Encode a single raw string into an embedding for queries
        """
        return self.encoder.encode([raw_text], device=self.device)[0]

    def encode_batch(self, raw_texts: List[str]) -> torch.Tensor:
        """
        Encode a batch of raw strings into embeddings for queries
        """
        return self.encoder.encode(raw_texts, device=self.device)
| # import torch | |
| # from transformers import ( | |
| # AutoTokenizer, AutoModel, | |
| # BertTokenizer, BertModel, | |
| # CLIPTokenizer, CLIPTextModel | |
| # ) | |
| # import torch.nn as nn | |
| # import pytorch_lightning as pl | |
| # from typing import List | |
| # from abc import ABC, abstractmethod | |
| # import random | |
| # import os | |
| # os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # def taglist_index_to_sentence(taglist_vocab, tag_vocab, taglist_indices, subsample: bool = True): | |
| # """ | |
| # Convert a tensor or list of taglist indices to a list of tag sentences. | |
| # Optionally, randomly shuffle and sample a subset of tags for each sentence. | |
| # Args: | |
| # taglist_vocab: List of tuples of tag IDs. | |
| # tag_vocab: Dictionary mapping tag ID to tag string. | |
| # taglist_indices: Tensor or list of indices into taglist_vocab. | |
| # seed: Random seed for reproducibility. | |
| # subsample: If True, randomly subsample tags in each sentence. | |
| # Returns: | |
| # tag_sentences: List of strings (tag sentences). | |
| # """ | |
| # if isinstance(taglist_indices, torch.Tensor): | |
| # taglist_indices = taglist_indices.view(-1).tolist() | |
| # tag_sentences = [] | |
| # for idx in taglist_indices: | |
| # tag_ids = taglist_vocab[idx] | |
| # tags = [tag_vocab[tid].lower().replace('=', ' ') for tid in tag_ids] | |
| # if subsample and len(tags) > 1: | |
| # n_sample = random.randint(1, len(tags)) # Choose how many tags to keep | |
| # tags = random.sample(tags, n_sample) # Sample without replacement | |
| # random.shuffle(tags) # Randomize order | |
| # sentence = ' '.join(tags) | |
| # tag_sentences.append(sentence) | |
| # return tag_sentences | |
| # def average_pool(last_hidden_states, attention_mask): | |
| # masked_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) | |
| # return masked_hidden.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True) | |
| # class BaseTextEncoder(nn.Module, ABC): | |
| # def __init__(self, model_name: str): | |
| # super().__init__() | |
| # self.model_name = model_name | |
| # self.tokenizer = None | |
| # self.model = None | |
| # self.embedding_dim = None | |
| # @abstractmethod | |
| # def encode(self, sentences: List[str], device: str = 'cpu') -> torch.Tensor: | |
| # """ | |
| # Encode a list of sentences into a tensor of embeddings. | |
| # Must be implemented by subclasses. | |
| # """ | |
| # pass | |
| # class BertTextEncoder(BaseTextEncoder): | |
| # def __init__(self, model_name='bert-base-uncased'): | |
| # super().__init__(model_name) | |
| # self.tokenizer = BertTokenizer.from_pretrained(model_name) | |
| # self.model = BertModel.from_pretrained(model_name) | |
| # self.embedding_dim = self.model.config.hidden_size | |
| # def encode(self, sentences, device='cpu'): | |
| # self.model.to(device) | |
| # inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device) | |
| # return self.model(**inputs).pooler_output | |
| # class CLIPTextEncoder(BaseTextEncoder): | |
| # def __init__(self, model_name='openai/clip-vit-large-patch14'): | |
| # super().__init__(model_name) | |
| # self.tokenizer = CLIPTokenizer.from_pretrained(model_name) | |
| # self.model = CLIPTextModel.from_pretrained(model_name) | |
| # self.embedding_dim = self.model.config.hidden_size | |
| # def encode(self, sentences, device='cpu'): | |
| # self.model.to(device) | |
| # inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device) | |
| # input_ids = inputs['input_ids'] | |
| # eos_token_id = self.tokenizer.eos_token_id | |
| # pad_token_id = self.tokenizer.pad_token_id | |
| # outputs = self.model(**inputs) | |
| # last_hidden = outputs.last_hidden_state # [B, T, D] | |
| # batch_size = input_ids.size(0) | |
| # embeddings = [] | |
| # for i in range(batch_size): | |
| # input_seq = input_ids[i] | |
| # eos_positions = (input_seq == eos_token_id).nonzero(as_tuple=True)[0] | |
| # if len(eos_positions) > 0: | |
| # eos_idx = eos_positions[-1] # take last EOS (safe for duplicates) | |
| # else: | |
| # eos_idx = (input_seq != pad_token_id).sum() - 1 # fallback to last non-padding token | |
| # embeddings.append(last_hidden[i, eos_idx, :]) | |
| # return torch.stack(embeddings) | |
| # class E5TextEncoder(BaseTextEncoder): | |
| # def __init__(self, model_name='intfloat/e5-base'): | |
| # super().__init__(model_name) | |
| # self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # self.model = AutoModel.from_pretrained(model_name) | |
| # self.model.pooler = None | |
| # self.embedding_dim = self.model.config.hidden_size | |
| # def encode(self, sentences, device='cpu'): | |
| # self.model.to(device) | |
| # sentences = [f"query: {s}" for s in sentences] # official prompt for e5 (for features as per documentation) | |
| # inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device) | |
| # outputs = self.model(**inputs) | |
| # return average_pool(outputs.last_hidden_state, inputs['attention_mask']) | |
| # class GritLMTextEncoder(BaseTextEncoder): | |
| # def __init__(self, model_name='nomic-ai/nomic-bert-base-punc'): | |
| # super().__init__(model_name) | |
| # self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # self.model = AutoModel.from_pretrained(model_name) | |
| # self.embedding_dim = self.model.config.hidden_size | |
| # self.proj_head = nn.Linear(self.embedding_dim, 768) # to match other encoders | |
| # def encode(self, sentences, device='cpu'): | |
| # self.model.to(device) | |
| # inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device) | |
| # outputs = self.model(**inputs) | |
| # pooled = average_pool(outputs.last_hidden_state, inputs['attention_mask']) | |
| # return self.proj_head(pooled) | |
| # class TextEncoder(pl.LightningModule): | |
| # def __init__(self, taglist_vocab: List[tuple], tag_vocab: dict, model_name='bert'): | |
| # super().__init__() | |
| # self.taglist_vocab = taglist_vocab | |
| # self.tag_vocab = tag_vocab | |
| # model_name = model_name.lower() | |
| # encoder_map = { | |
| # 'bert': lambda: BertTextEncoder('bert-base-uncased'), | |
| # 'clip': lambda: CLIPTextEncoder('openai/clip-vit-large-patch14'), | |
| # 'e5': lambda: E5TextEncoder('intfloat/e5-base'), | |
| # 'gritlm': lambda: GritLMTextEncoder('nomic-ai/nomic-bert-base-punc') | |
| # } | |
| # if model_name not in encoder_map: | |
| # raise ValueError(f"Unsupported model_name: {model_name}. Choose from {list(encoder_map.keys())}") | |
| # print(f"Text backbone: {model_name}") | |
| # self.encoder = encoder_map[model_name]() # Instantiate the selected encoder | |
| # # self.embedding_dim = 768 | |
| # def forward(self, taglist_tensor: torch.Tensor) -> torch.Tensor: | |
| # tag_indices = taglist_tensor.tolist() | |
| # tag_sentences = taglist_index_to_sentence(self.taglist_vocab, self.tag_vocab, tag_indices, subsample=True) # randomize subsampling tags | |
| # embeddings = self.encoder.encode(tag_sentences, device=self.device) | |
| # return embeddings | |
| # def encode_raw_text(self, raw_text: str) -> torch.Tensor: | |
| # """ | |
| # Encode a single raw string into an embedding for queries | |
| # """ | |
| # return self.encoder.encode([raw_text], device=self.device)[0] |