Shri committed on
Commit
d562e10
·
1 Parent(s): 0cca1ec

fix: json,onnx model id

Browse files
Files changed (1) hide show
  1. src/chatbot/embedding.py +11 -4
src/chatbot/embedding.py CHANGED
@@ -1,4 +1,6 @@
1
  # to run this file you need model.onnx_data on the assets/onnx folder or you can obtain it from here: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/tree/main/onnx
 
 
2
  import asyncio
3
  import os
4
  from typing import List
@@ -10,7 +12,8 @@ from transformers import AutoTokenizer
10
 
11
  BASE_DIR = os.path.dirname(__file__)
12
 
13
- TOKENIZER_DIR = os.path.abspath(os.path.join(BASE_DIR, "..", "assets", "tokenizer"))
 
14
 
15
  # MODEL_DIR = os.path.abspath(
16
  # os.path.join(BASE_DIR, "..", "assets", "onnx", "model.onnx")
@@ -20,9 +23,7 @@ TOKENIZER_DIR = os.path.abspath(os.path.join(BASE_DIR, "..", "assets", "tokenize
20
  class EmbeddingModel:
21
  def __init__(self):
22
  # print(TOKENIZER_DIR)
23
- self.tokenizer = AutoTokenizer.from_pretrained(
24
- TOKENIZER_DIR, local_files_only=True
25
- )
26
 
27
  # sess_options = ort.SessionOptions()
28
  # providers = ["CPUExecutionProvider"]
@@ -84,6 +85,12 @@ class EmbeddingModel:
84
  return input_ids.flatten().tolist()
85
 
86
 
 
 
 
 
 
 
87
  embedding_model = EmbeddingModel()
88
 
89
 
 
1
  # to run this file you need model.onnx_data on the assets/onnx folder or you can obtain it from here: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/tree/main/onnx
2
+ # the model can also be loaded directly via AutoModel.from_pretrained using the same id "onnx-community/embeddinggemma-300m-ONNX"
3
+
4
  import asyncio
5
  import os
6
  from typing import List
 
12
 
13
  BASE_DIR = os.path.dirname(__file__)
14
 
15
+ # TOKENIZER_DIR = os.path.abspath(os.path.join(BASE_DIR, "..", "assets", "tokenizer"))
16
+ TOKENIZER_DIR = "onnx-community/embeddinggemma-300m-ONNX"
17
 
18
  # MODEL_DIR = os.path.abspath(
19
  # os.path.join(BASE_DIR, "..", "assets", "onnx", "model.onnx")
 
23
  class EmbeddingModel:
24
  def __init__(self):
25
  # print(TOKENIZER_DIR)
26
+ self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
 
 
27
 
28
  # sess_options = ort.SessionOptions()
29
  # providers = ["CPUExecutionProvider"]
 
85
  return input_ids.flatten().tolist()
86
 
87
 
88
+ def cleanup(self):
89
+ if self.session:
90
+ self.session = None
91
+ print("ONNX runtime session closed.")
92
+
93
+
94
  embedding_model = EmbeddingModel()
95
 
96