Update README.md
Browse files
README.md
CHANGED
|
@@ -90,6 +90,86 @@ embeddings = outputs
|
|
| 90 |
print(embeddings)
|
| 91 |
```
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
## Citation
|
| 94 |
|
| 95 |
If you use this model in your work, please cite:
|
|
|
|
| 90 |
print(embeddings)
|
| 91 |
```
|
| 92 |
|
| 93 |
+
### 3. Generate Fixed-Size (768-Dimensional) Embeddings
|
| 94 |
+
|
| 95 |
+
```python
|
| 96 |
+
import numpy as np
|
| 97 |
+
import onnxruntime as ort
|
| 98 |
+
from transformers import AutoTokenizer
|
| 99 |
+
from huggingface_hub import hf_hub_download
|
| 100 |
+
|
| 101 |
+
class GemmaEmbedder:
|
| 102 |
+
"""
|
| 103 |
+
A class to generate embeddings using a Gemma model in ONNX format.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def __init__(self, model_repo="be1newinner/embeddinggemma-300m-onnx"):
|
| 107 |
+
"""
|
| 108 |
+
Initializes the GemmaEmbedder by loading the tokenizer and ONNX model.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
model_repo (str): The repository ID of the ONNX model on Hugging Face Hub.
|
| 112 |
+
"""
|
| 113 |
+
# Load the tokenizer from the Hugging Face Hub
|
| 114 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
| 115 |
+
|
| 116 |
+
# Download and load the ONNX model
|
| 117 |
+
onnx_model_path = hf_hub_download(repo_id=model_repo, filename="model.onnx")
|
| 118 |
+
self.session = ort.InferenceSession(onnx_model_path)
|
| 119 |
+
|
| 120 |
+
def generate(self, text: str):
|
| 121 |
+
"""
|
| 122 |
+
Generates a fixed-size embedding for the input text.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
text (str): The input text to embed.
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
np.ndarray: The generated embedding as a NumPy array.
|
| 129 |
+
"""
|
| 130 |
+
# Tokenize the input text, padding and truncating to a consistent length
|
| 131 |
+
inputs = self.tokenizer(
|
| 132 |
+
text, return_tensors="np", padding=True, truncation=True
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# Run the ONNX model to get the last hidden states
|
| 136 |
+
outputs = self.session.run(None, dict(inputs))
|
| 137 |
+
|
| 138 |
+
# Perform mean pooling to get a fixed-size embedding
|
| 139 |
+
last_hidden_states = outputs[0]
|
| 140 |
+
input_mask_expanded = np.expand_dims(inputs["attention_mask"], -1).astype(float)
|
| 141 |
+
sum_embeddings = np.sum(last_hidden_states * input_mask_expanded, 1)
|
| 142 |
+
sum_mask = np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=None)
|
| 143 |
+
|
| 144 |
+
pooled_embeddings = sum_embeddings / sum_mask
|
| 145 |
+
return pooled_embeddings
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Create a global instance of the embedder to avoid reloading the model
|
| 149 |
+
embedder = GemmaEmbedder()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def generate(text: str):
|
| 153 |
+
"""
|
| 154 |
+
A convenience function to generate embeddings using the global embedder instance.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
text (str): The input text to embed.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
np.ndarray: The generated embedding.
|
| 161 |
+
"""
|
| 162 |
+
return embedder.generate(text)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
|
| 166 |
+
# Example usage of the generate function
|
| 167 |
+
embeddings = generate("Example input text")
|
| 168 |
+
print(embeddings)
|
| 169 |
+
print(f"Embedding shape: {embeddings.shape}")
|
| 170 |
+
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
## Citation
|
| 174 |
|
| 175 |
If you use this model in your work, please cite:
|