Shri commited on
Commit
0cca1ec
·
1 Parent(s): 06cff93

fix: onnx error fix

Browse files
Files changed (3) hide show
  1. src/chatbot/embedding.py +54 -48
  2. src/main.py +0 -2
  3. src/profile/router.py +0 -3
src/chatbot/embedding.py CHANGED
@@ -1,65 +1,65 @@
1
  # to run this file you need model.onnx_data on the assets/onnx folder or you can obtain it from here.: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/tree/main/onnx
2
-
3
  import asyncio
4
  import os
5
  from typing import List
6
 
7
  import numpy as np
8
- import onnxruntime as ort
 
9
  from transformers import AutoTokenizer
10
 
11
  BASE_DIR = os.path.dirname(__file__)
12
 
13
  TOKENIZER_DIR = os.path.abspath(os.path.join(BASE_DIR, "..", "assets", "tokenizer"))
14
 
15
- MODEL_DIR = os.path.abspath(
16
- os.path.join(BASE_DIR, "..", "assets", "onnx", "model.onnx")
17
- )
18
 
19
 
20
  class EmbeddingModel:
21
  def __init__(self):
22
- print(TOKENIZER_DIR)
23
  self.tokenizer = AutoTokenizer.from_pretrained(
24
  TOKENIZER_DIR, local_files_only=True
25
  )
26
 
27
- sess_options = ort.SessionOptions()
28
- providers = ["CPUExecutionProvider"]
29
-
30
- self.session = ort.InferenceSession(
31
- MODEL_DIR, sess_options, providers=providers
32
- )
33
-
34
- self.input_names = [inp.name for inp in self.session.get_inputs()]
35
- self.output_names = [out.name for out in self.session.get_outputs()]
36
-
37
- def _run_sync(
38
- self, input_ids: np.ndarray, attention_mask: np.ndarray
39
- ) -> List[float]:
40
- inputs = {}
41
-
42
- if "input_ids" in self.input_names:
43
- inputs["input_ids"] = input_ids
44
- else:
45
- inputs[self.input_names[0]] = input_ids
46
-
47
- if "attention_mask" in self.input_names:
48
- inputs["attention_mask"] = attention_mask
49
- elif len(self.input_names) > 1:
50
- inputs[self.input_names[1]] = attention_mask
51
-
52
- outputs = self.session.run(self.output_names, inputs)
53
- emb = outputs[0]
54
-
55
- if emb.ndim == 3:
56
- emb_vector = emb.mean(axis=1)[0]
57
- elif emb.ndim == 2:
58
- emb_vector = emb[0]
59
- else:
60
- emb_vector = np.asarray(emb).flatten()
61
-
62
- return emb_vector.astype(float).tolist()
63
 
64
  async def embed_text(self, text: str, max_length: int = 512) -> List[float]:
65
 
@@ -76,12 +76,18 @@ class EmbeddingModel:
76
  np.int64
77
  )
78
 
79
- loop = asyncio.get_event_loop()
80
- vector = await loop.run_in_executor(
81
- None, self._run_sync, input_ids, attention_mask
82
- )
83
-
84
- return vector
85
 
86
 
87
  embedding_model = EmbeddingModel()
 
 
 
 
 
 
 
1
  # to run this file you need model.onnx_data on the assets/onnx folder or you can obtain it from here.: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/tree/main/onnx
 
2
  import asyncio
3
  import os
4
  from typing import List
5
 
6
  import numpy as np
7
+
8
+ # import onnxruntime as ort
9
  from transformers import AutoTokenizer
10
 
11
  BASE_DIR = os.path.dirname(__file__)
12
 
13
  TOKENIZER_DIR = os.path.abspath(os.path.join(BASE_DIR, "..", "assets", "tokenizer"))
14
 
15
+ # MODEL_DIR = os.path.abspath(
16
+ # os.path.join(BASE_DIR, "..", "assets", "onnx", "model.onnx")
17
+ # )
18
 
19
 
20
  class EmbeddingModel:
21
  def __init__(self):
22
+ # print(TOKENIZER_DIR)
23
  self.tokenizer = AutoTokenizer.from_pretrained(
24
  TOKENIZER_DIR, local_files_only=True
25
  )
26
 
27
+ # sess_options = ort.SessionOptions()
28
+ # providers = ["CPUExecutionProvider"]
29
+ #
30
+ # self.session = ort.InferenceSession(
31
+ # MODEL_DIR, sess_options, providers=providers
32
+ # )
33
+ #
34
+ # self.input_names = [inp.name for inp in self.session.get_inputs()]
35
+ # self.output_names = [out.name for out in self.session.get_outputs()]
36
+
37
+ # def _run_sync(
38
+ # self, input_ids: np.ndarray, attention_mask: np.ndarray
39
+ # ) -> List[float]:
40
+ # inputs = {}
41
+ #
42
+ # if "input_ids" in self.input_names:
43
+ # inputs["input_ids"] = input_ids
44
+ # else:
45
+ # inputs[self.input_names[0]] = input_ids
46
+ #
47
+ # if "attention_mask" in self.input_names:
48
+ # inputs["attention_mask"] = attention_mask
49
+ # elif len(self.input_names) > 1:
50
+ # inputs[self.input_names[1]] = attention_mask
51
+ #
52
+ # outputs = self.session.run(self.output_names, inputs)
53
+ # emb = outputs[0]
54
+ #
55
+ # if emb.ndim == 3:
56
+ # emb_vector = emb.mean(axis=1)[0]
57
+ # elif emb.ndim == 2:
58
+ # emb_vector = emb[0]
59
+ # else:
60
+ # emb_vector = np.asarray(emb).flatten()
61
+ #
62
+ # return emb_vector.astype(float).tolist()
63
 
64
  async def embed_text(self, text: str, max_length: int = 512) -> List[float]:
65
 
 
76
  np.int64
77
  )
78
 
79
+ # loop = asyncio.get_event_loop()
80
+ # vector = await loop.run_in_executor(
81
+ # None, self._run_sync, input_ids, attention_mask
82
+ # )
83
+ # return vector
84
+ return input_ids.flatten().tolist()
85
 
86
 
87
  embedding_model = EmbeddingModel()
88
+
89
+
90
+ async def test_tokenizer():
91
+ text = "What does the company telll about moonlighting"
92
+ tokens = await embedding_model.embed_text(text)
93
+ print("Tokenized text:", tokens)
src/main.py CHANGED
@@ -1,11 +1,9 @@
1
  from fastapi import FastAPI
2
 
3
- from src.assets.router import router as assets
4
  from src.auth.router import router as auth_router
5
  from src.chatbot.router import router as chatbot
6
  from src.core.database import init_db
7
  from src.home.router import router as home_router
8
- from src.leave.router import router as leave
9
  from src.profile.router import router as profile
10
 
11
  app = FastAPI(title="Yuvabe App API")
 
1
  from fastapi import FastAPI
2
 
 
3
  from src.auth.router import router as auth_router
4
  from src.chatbot.router import router as chatbot
5
  from src.core.database import init_db
6
  from src.home.router import router as home_router
 
7
  from src.profile.router import router as profile
8
 
9
  app = FastAPI(title="Yuvabe App API")
src/profile/router.py CHANGED
@@ -10,9 +10,6 @@ from fastapi import APIRouter, Depends
10
  from sqlmodel.ext.asyncio.session import AsyncSession
11
  from src.core.database import get_async_session
12
  from src.auth.utils import get_current_user
13
- from src.assets.schemas import BaseResponse
14
- from src.assets.service import list_user_assets
15
- from src.leave.utils import send_email
16
  from fastapi import APIRouter, Depends, HTTPException
17
  from sqlmodel import select
18
  from sqlmodel.ext.asyncio.session import AsyncSession
 
10
  from sqlmodel.ext.asyncio.session import AsyncSession
11
  from src.core.database import get_async_session
12
  from src.auth.utils import get_current_user
 
 
 
13
  from fastapi import APIRouter, Depends, HTTPException
14
  from sqlmodel import select
15
  from sqlmodel.ext.asyncio.session import AsyncSession