|
|
---
base_model:
- Alibaba-NLP/gte-multilingual-base
pipeline_tag: sentence-similarity
license: apache-2.0
---
|
|
This is the ONNX version of the [gte-multilingual-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-base) model.
|
|
|
|
|
This example is adapted from the original model repository for use with the ONNX version.
|
|
```python
|
|
# Requires transformers>=4.36.0
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer

input_texts = [
    "what is the capital of China?",
    "how to implement quick sort in python?",
    "北京",
    "快排算法介绍"
]

# Load the tokenizer from the original model repo (the ONNX export has no tokenizer files).
tokenizer = AutoTokenizer.from_pretrained('Alibaba-NLP/gte-multilingual-base')

# Load the exported ONNX model.
session = ort.InferenceSession("model.onnx")

# Tokenize the input texts (NumPy tensors for onnxruntime).
batch_dict = tokenizer(input_texts, max_length=8192, padding=True, truncation=True, return_tensors='np')

# Run inference; `None` requests all model outputs.
outputs = session.run(None, {
    "input_ids": batch_dict["input_ids"],
    "attention_mask": batch_dict["attention_mask"]
})

# The second output holds the last hidden states for this export — verify
# with session.get_outputs() if the export changes.
last_hidden_states = outputs[1]  # Shape: (batch_size, seq_len, hidden_size)

# Pool by taking the [CLS] token (first position) of each sequence, optionally
# truncated to a smaller dimension (Matryoshka-style); valid range is [128, 768].
dimension = 768
embeddings = last_hidden_states[:, 0, :dimension]  # Shape: (batch_size, dimension)

# Debug: check embeddings
print(f"Embeddings shape: {embeddings.shape}")
print(f"First few values of first embedding: {embeddings[0][:5]}")
print(f"First few values of second embedding: {embeddings[1][:5]}")

# L2-normalize so the dot product below equals cosine similarity.
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

# Similarity of the first text against all the others, scaled to [−100, 100].
scores = (embeddings[:1] @ embeddings[1:].T) * 100
print(scores.tolist())
# [[0.3016996383666992, 0.7503870129585266, 0.3203084468841553]]
|
|
```