| ---
|
| base_model: google/xtr-base-en
|
| license: apache-2.0
|
| tags:
|
| - arxiv:2304.01982
|
| ---
|
|
|
| XTR-ONNX
|
| ---
|
|
|
| This model is Google's XTR-base-en model exported to ONNX format.
|
|
|
| Original XTR model: https://huggingface.co/google/xtr-base-en
|
|
|
| With inputs of up to 512 tokens, this model outputs a 128-dimensional vector for each token.
|
|
|
| XTR's demo notebook uses only one special token: EOS (end of sequence).
|
|
|
| ## Using this model
|
| This model can be plugged into LintDB to index data into a database.
|
|
|
| ### In LintDB
|
| ```python
|
| # create an XTR index
|
| config = ldb.Configuration()
|
| config.num_subquantizers = 64
|
| config.dim = 128
|
| config.nbits = 4
|
| config.quantizer_type = ldb.IndexEncoding_XTR
|
| index = ldb.IndexIVF(f"experiments/goog", config)
|
|
|
| # build a collection on top of the index
|
| opts = ldb.CollectionOptions()
|
| opts.model_file = "assets/xtr/encoder.onnx"
|
| opts.tokenizer_file = "assets/xtr/spiece.model"
|
|
|
| collection = ldb.Collection(index, opts)
|
|
|
| collection.train(chunks, 50, 10)
|
|
|
| for i, snip in enumerate(chunks):
|
| collection.add(0, i, snip, {'docid': f'{i}'})
|
| ```
|
|
|
| ## Creating this model
|
| To create this model, we combined XTR's T5 encoder model
|
| with its dense projection layer (768 → 128 dimensions). The code used to do this is below. Credit to yaman on GitHub
|
| for this solution.
|
|
|
| ```python
|
| from pathlib import Path
|
| import numpy as np
| import onnx
| import onnxruntime as ort
| import torch
| import torch.nn as nn
| from sentence_transformers import SentenceTransformer
| from sentence_transformers import models
| from transformers import AutoTokenizer
| from transformers import T5EncoderModel
|
|
|
| # https://github.com/huggingface/optimum/issues/1519
|
|
|
| class CombinedModel(nn.Module):
|
| def __init__(self, transformer_model, dense_model):
|
| super(CombinedModel, self).__init__()
|
| self.transformer = transformer_model
|
| self.dense = dense_model
|
|
|
| def forward(self, input_ids, attention_mask):
|
| outputs = self.transformer(input_ids, attention_mask=attention_mask)
|
| token_embeddings = outputs['last_hidden_state']
|
| return self.dense({'sentence_embedding': token_embeddings})
|
|
|
|
|
| save_directory = "onnx/"
|
|
|
| # Load a model from transformers and export it to ONNX
|
| tokenizer = AutoTokenizer.from_pretrained(path)
|
|
|
| # load the t5 base encoder model.
|
| transformer_model = T5EncoderModel.from_pretrained(path)
|
|
|
| dense_model = models.Dense(
|
| in_features=768,
|
| out_features=128,
|
| bias=False,
|
| activation_function= nn.Identity()
|
| )
|
|
|
| state_dict = torch.load(os.path.join(path, '2_Dense', dense_filename))
|
| dense_model.load_state_dict(state_dict)
|
|
|
| model = CombinedModel(transformer_model, dense_model)
|
|
|
| model.eval()
|
|
|
| input_text = "Who founded google"
|
| inputs = tokenizer(input_text, padding='longest', truncation=True, max_length=128, return_tensors='pt')
|
|
|
| input_ids = inputs['input_ids']
|
| attention_mask = inputs['attention_mask']
|
|
|
| torch.onnx.export(
|
| model,
|
| (input_ids, attention_mask),
|
| "combined_model.onnx",
|
| export_params=True,
|
| opset_version=17,
|
| do_constant_folding=True,
|
| input_names = ['input_ids', 'attention_mask'],
|
| output_names = ['contextual'],
|
| dynamic_axes={
|
| 'input_ids': {0 : 'batch_size', 1: 'seq_length'}, # variable length axes
|
| 'attention_mask': {0 : 'batch_size', 1: 'seq_length'},
|
| 'contextual' : {0 : 'batch_size', 1: 'seq_length'}
|
| }
|
| )
|
|
|
| onnx.checker.check_model("combined_model.onnx")
|
|
|
| combined_model = onnx.load("combined_model.onnx")
|
|
|
| import onnxruntime as ort
|
| ort_session = ort.InferenceSession("combined_model.onnx")
|
| output = ort_session.run(None, {'input_ids': input_ids.numpy(), 'attention_mask': attention_mask.numpy()})
|
|
|
|
|
| # Run the PyTorch model
|
| pytorch_output = model(input_ids, attention_mask)
|
| print(pytorch_output['sentence_embedding'])
|
|
|
| print(output[0])
|
| # Compare the outputs
|
| # print("Are the outputs close?", np.allclose(pytorch_output.detach().numpy(), output[0], atol=1e-6))
|
|
|
| # Calculate the differences between the outputs
|
| differences = pytorch_output['sentence_embedding'].detach().numpy() - output[0]
|
|
|
| # Print the standard deviation of the differences
|
| print("Standard deviation of the differences:", np.std(differences))
|
|
|
| print("pytorch_output size:", pytorch_output['sentence_embedding'].size())
|
| print("onnx_output size:", output[0].shape)
|
| ``` |