m97j committed on
Commit
6e89dad
·
1 Parent(s): de65949

Initial code commit

Browse files
Files changed (3) hide show
  1. models/embedder.py +3 -0
  2. models/reranker.py +3 -0
  3. modules/utils.py +10 -0
models/embedder.py CHANGED
@@ -3,6 +3,7 @@ from typing import List
3
  import numpy as np
4
  import onnxruntime as ort
5
  from fastapi import Request
 
6
 
7
  def _l2_normalize(vec: np.ndarray) -> List[float]:
8
  norm = np.linalg.norm(vec) or 1.0
@@ -17,6 +18,8 @@ def get_embedding(request: Request, text: str) -> List[float]:
17
  sess: ort.InferenceSession = request.app.state.embedder_sess
18
 
19
  inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=256)
 
 
20
  ort_inputs = {k: v for k, v in inputs.items()}
21
  ort_outs = sess.run(None, ort_inputs)
22
  # 일반적으로 첫 번째 출력이 [batch, dim] 임베딩
 
3
  import numpy as np
4
  import onnxruntime as ort
5
  from fastapi import Request
6
+ from modules.utils import generate_position_ids
7
 
8
  def _l2_normalize(vec: np.ndarray) -> List[float]:
9
  norm = np.linalg.norm(vec) or 1.0
 
18
  sess: ort.InferenceSession = request.app.state.embedder_sess
19
 
20
  inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=256)
21
+ input_ids = inputs["input_ids"]
22
+ inputs["position_ids"] = generate_position_ids(input_ids)
23
  ort_inputs = {k: v for k, v in inputs.items()}
24
  ort_outs = sess.run(None, ort_inputs)
25
  # 일반적으로 첫 번째 출력이 [batch, dim] 임베딩
models/reranker.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  from typing import List, Dict
4
  import onnxruntime as ort
5
  from fastapi import Request
 
6
 
7
  THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))
8
 
@@ -20,6 +21,8 @@ def rerank(request: Request, query: str, contexts: List[Dict]) -> List[Dict]:
20
 
21
  pairs = [(query, ctx["text"]) for ctx in contexts]
22
  inputs = tokenizer(pairs, return_tensors="np", padding=True, truncation=True, max_length=256)
 
 
23
  ort_inputs = {k: v for k, v in inputs.items()}
24
  scores = sess.run(None, ort_inputs)[0] # [batch, 1] 형태
25
  scores = scores.squeeze(-1)
 
3
  from typing import List, Dict
4
  import onnxruntime as ort
5
  from fastapi import Request
6
+ from modules.utils import generate_position_ids
7
 
8
  THRESHOLD = float(os.getenv("RERANK_THRESHOLD", "0.3"))
9
 
 
21
 
22
  pairs = [(query, ctx["text"]) for ctx in contexts]
23
  inputs = tokenizer(pairs, return_tensors="np", padding=True, truncation=True, max_length=256)
24
+ input_ids = inputs["input_ids"]
25
+ inputs["position_ids"] = generate_position_ids(input_ids)
26
  ort_inputs = {k: v for k, v in inputs.items()}
27
  scores = sess.run(None, ort_inputs)[0] # [batch, 1] 형태
28
  scores = scores.squeeze(-1)
modules/utils.py CHANGED
@@ -1,5 +1,6 @@
1
  # rag/modules/utils.py
2
  import os
 
3
 
4
  def ensure_dir(path: str):
5
  os.makedirs(path, exist_ok=True)
@@ -10,3 +11,12 @@ def touch(path: str):
10
 
11
  def exists(path: str) -> bool:
12
  return os.path.exists(path)
 
 
 
 
 
 
 
 
 
 
1
  # rag/modules/utils.py
2
  import os
3
+ import numpy as np
4
 
5
  def ensure_dir(path: str):
6
  os.makedirs(path, exist_ok=True)
 
11
 
12
def exists(path: str) -> bool:
    """Return ``True`` when *path* refers to an existing file or directory."""
    return os.path.exists(path)
14
+
15
def generate_position_ids(input_ids: np.ndarray) -> np.ndarray:
    """Build position ids aligned with *input_ids*.

    Args:
        input_ids: token-id array of shape [batch_size, seq_len].

    Returns:
        int64 array of shape [batch_size, seq_len] where every row is
        the sequence 0, 1, ..., seq_len - 1.
    """
    batch_size, seq_len = input_ids.shape
    # Explicit int64: np.arange's default integer dtype is platform-dependent.
    row = np.arange(seq_len, dtype=np.int64)
    return np.repeat(row[np.newaxis, :], batch_size, axis=0)