be1newinner committed · Commit a320056 · verified · 1 Parent(s): 72c9298

Update README.md

Files changed (1): README.md +80 -0
README.md CHANGED
@@ -90,6 +90,86 @@ embeddings = outputs
 print(embeddings)
 ```
 
+### 3. Generate Fixed-Size 768-Dimensional Embeddings as Output
+
+```python
+import numpy as np
+import onnxruntime as ort
+from transformers import AutoTokenizer
+from huggingface_hub import hf_hub_download
+
+class GemmaEmbedder:
+    """
+    A class to generate embeddings using a Gemma model in ONNX format.
+    """
+
+    def __init__(self, model_repo="be1newinner/embeddinggemma-300m-onnx"):
+        """
+        Initializes the GemmaEmbedder by loading the tokenizer and ONNX model.
+
+        Args:
+            model_repo (str): The repository ID of the ONNX model on Hugging Face Hub.
+        """
+        # Load the tokenizer from the Hugging Face Hub
+        self.tokenizer = AutoTokenizer.from_pretrained(model_repo)
+
+        # Download and load the ONNX model
+        onnx_model_path = hf_hub_download(repo_id=model_repo, filename="model.onnx")
+        self.session = ort.InferenceSession(onnx_model_path)
+
+    def generate(self, text: str):
+        """
+        Generates a fixed-size embedding for the input text.
+
+        Args:
+            text (str): The input text to embed.
+
+        Returns:
+            np.ndarray: The generated embedding as a NumPy array.
+        """
+        # Tokenize the input text, padding and truncating to a consistent length
+        inputs = self.tokenizer(
+            text, return_tensors="np", padding=True, truncation=True
+        )
+
+        # Run the ONNX model to get the last hidden states
+        outputs = self.session.run(None, dict(inputs))
+
+        # Perform mean pooling over valid tokens to get a fixed-size embedding
+        last_hidden_states = outputs[0]
+        input_mask_expanded = np.expand_dims(inputs["attention_mask"], -1).astype(float)
+        sum_embeddings = np.sum(last_hidden_states * input_mask_expanded, 1)
+        sum_mask = np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=None)
+
+        pooled_embeddings = sum_embeddings / sum_mask
+        return pooled_embeddings
+
+
+# Create a global instance of the embedder to avoid reloading the model
+embedder = GemmaEmbedder()
+
+
+def generate(text: str):
+    """
+    A convenience function to generate embeddings using the global embedder instance.
+
+    Args:
+        text (str): The input text to embed.
+
+    Returns:
+        np.ndarray: The generated embedding.
+    """
+    return embedder.generate(text)
+
+
+if __name__ == "__main__":
+    # Example usage of the generate function
+    embeddings = generate("Example input text")
+    print(embeddings)
+    print(f"Embedding shape: {embeddings.shape}")
+
+```
+
 ## Citation
 
 If you use this model in your work, please cite:
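
For a quick check that the pooled output really is a fixed 768-dimensional vector, here is a minimal sketch (not part of the commit above) that compares two embeddings with cosine similarity. It assumes the `generate` helper from the added README snippet is already defined in the session; the example sentences are placeholders.

```python
# Minimal sketch, assuming the `generate` helper from the README snippet
# above is in scope (this is not part of the commit itself).
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity between two 1-D vectors
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# generate() returns shape (1, 768) for a single string, so take row 0
emb_a = generate("How do I reset my password?")[0]
emb_b = generate("Steps to change an account password")[0]

assert emb_a.shape == (768,)  # fixed-size output, per the section heading
print(f"cosine similarity: {cosine_similarity(emb_a, emb_b):.4f}")
```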