m97j committed on
Commit
5f612cd
·
1 Parent(s): 6aaa57e

Initial codes commit

Browse files
Files changed (3) hide show
  1. models/embedder.py +10 -2
  2. models/reranker.py +0 -3
  3. modules/utils.py +0 -10
models/embedder.py CHANGED
@@ -3,12 +3,20 @@ from typing import List
3
  import numpy as np
4
  import onnxruntime as ort
5
  from fastapi import Request
6
- from modules.utils import generate_position_ids
7
 
8
  def _l2_normalize(vec: np.ndarray) -> List[float]:
9
  norm = np.linalg.norm(vec) or 1.0
10
  return (vec / norm).tolist()
11
 
 
 
 
 
 
 
 
 
 
12
  def get_embedding(request: Request, text: str) -> List[float]:
13
  """
14
  request.app.state.embedder_sess : ONNX Runtime InferenceSession
@@ -19,7 +27,7 @@ def get_embedding(request: Request, text: str) -> List[float]:
19
 
20
  inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=256)
21
  input_ids = inputs["input_ids"]
22
- inputs["position_ids"] = generate_position_ids(input_ids)
23
  ort_inputs = {k: v for k, v in inputs.items()}
24
  ort_outs = sess.run(None, ort_inputs)
25
  print([arr.shape for arr in ort_outs])
 
3
  import numpy as np
4
  import onnxruntime as ort
5
  from fastapi import Request
 
6
 
7
  def _l2_normalize(vec: np.ndarray) -> List[float]:
8
  norm = np.linalg.norm(vec) or 1.0
9
  return (vec / norm).tolist()
10
 
11
+ def _generate_position_ids(input_ids: np.ndarray) -> np.ndarray:
12
+ """
13
+ input_ids: [batch_size, seq_len]
14
+ return: position_ids of shape [batch_size, seq_len] with int64 dtype
15
+ """
16
+ batch_size, seq_len = input_ids.shape
17
+ position_ids = np.arange(seq_len)[None, :].astype("int64")
18
+ return np.tile(position_ids, (batch_size, 1))
19
+
20
  def get_embedding(request: Request, text: str) -> List[float]:
21
  """
22
  request.app.state.embedder_sess : ONNX Runtime InferenceSession
 
27
 
28
  inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=256)
29
  input_ids = inputs["input_ids"]
30
+ inputs["position_ids"] = _generate_position_ids(input_ids)
31
  ort_inputs = {k: v for k, v in inputs.items()}
32
  ort_outs = sess.run(None, ort_inputs)
33
  print([arr.shape for arr in ort_outs])
models/reranker.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  from typing import List, Dict
4
  import onnxruntime as ort
5
  from fastapi import Request
6
- from modules.utils import generate_position_ids
7
 
8
  THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))
9
 
@@ -21,8 +20,6 @@ def rerank(request: Request, query: str, contexts: List[Dict]) -> List[Dict]:
21
 
22
  pairs = [(query, ctx["text"]) for ctx in contexts]
23
  inputs = tokenizer(pairs, return_tensors="np", padding=True, truncation=True, max_length=256)
24
- input_ids = inputs["input_ids"]
25
- inputs["position_ids"] = generate_position_ids(input_ids)
26
  ort_inputs = {k: v for k, v in inputs.items()}
27
  scores = sess.run(None, ort_inputs)[0] # [batch, 1] 형태
28
  scores = scores.squeeze(-1)
 
3
  from typing import List, Dict
4
  import onnxruntime as ort
5
  from fastapi import Request
 
6
 
7
  THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))
8
 
 
20
 
21
  pairs = [(query, ctx["text"]) for ctx in contexts]
22
  inputs = tokenizer(pairs, return_tensors="np", padding=True, truncation=True, max_length=256)
 
 
23
  ort_inputs = {k: v for k, v in inputs.items()}
24
  scores = sess.run(None, ort_inputs)[0] # [batch, 1] 형태
25
  scores = scores.squeeze(-1)
modules/utils.py CHANGED
@@ -1,6 +1,5 @@
1
  # rag/modules/utils.py
2
  import os
3
- import numpy as np
4
 
5
  def ensure_dir(path: str):
6
  os.makedirs(path, exist_ok=True)
@@ -11,12 +10,3 @@ def touch(path: str):
11
 
12
  def exists(path: str) -> bool:
13
  return os.path.exists(path)
14
-
15
- def generate_position_ids(input_ids: np.ndarray) -> np.ndarray:
16
- """
17
- input_ids: [batch_size, seq_len]
18
- return: position_ids of shape [batch_size, seq_len] with int64 dtype
19
- """
20
- batch_size, seq_len = input_ids.shape
21
- position_ids = np.arange(seq_len)[None, :].astype("int64")
22
- return np.tile(position_ids, (batch_size, 1))
 
1
  # rag/modules/utils.py
2
  import os
 
3
 
4
  def ensure_dir(path: str):
5
  os.makedirs(path, exist_ok=True)
 
10
 
11
  def exists(path: str) -> bool:
12
  return os.path.exists(path)