Shri committed on
Commit
998ba81
·
1 Parent(s): 5ff7281

feat: chunk retrieval updated

Browse files
src/chatbot/embedding.py CHANGED
@@ -1,100 +1,71 @@
1
- # to run this file you need model.onnx_data on the assets/onnx folder or you can obtain it from here.: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/tree/main/onnx
2
- # model can also be loaded directly from autoModel.pretrained by using the same link "onnx-community/embeddinggemma-300m-ONNX"
3
-
4
- import asyncio
5
  import os
 
6
  from typing import List
 
 
 
7
 
8
- import numpy as np
9
 
10
- # import onnxruntime as ort
11
- from transformers import AutoTokenizer
 
 
12
 
13
- BASE_DIR = os.path.dirname(__file__)
 
 
 
 
 
 
 
 
 
14
 
15
- # TOKENIZER_DIR = os.path.abspath(os.path.join(BASE_DIR, "..", "assets", "tokenizer"))
16
- TOKENIZER_DIR = "onnx-community/embeddinggemma-300m-ONNX"
17
 
18
- # MODEL_DIR = os.path.abspath(
19
- # os.path.join(BASE_DIR, "..", "assets", "onnx", "model.onnx")
20
- # )
 
 
21
 
 
 
22
 
23
- class EmbeddingModel:
24
- def __init__(self):
25
- # print(TOKENIZER_DIR)
26
- self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
27
-
28
- # sess_options = ort.SessionOptions()
29
- # providers = ["CPUExecutionProvider"]
30
- #
31
- # self.session = ort.InferenceSession(
32
- # MODEL_DIR, sess_options, providers=providers
33
- # )
34
- #
35
- # self.input_names = [inp.name for inp in self.session.get_inputs()]
36
- # self.output_names = [out.name for out in self.session.get_outputs()]
37
-
38
- # def _run_sync(
39
- # self, input_ids: np.ndarray, attention_mask: np.ndarray
40
- # ) -> List[float]:
41
- # inputs = {}
42
- #
43
- # if "input_ids" in self.input_names:
44
- # inputs["input_ids"] = input_ids
45
- # else:
46
- # inputs[self.input_names[0]] = input_ids
47
- #
48
- # if "attention_mask" in self.input_names:
49
- # inputs["attention_mask"] = attention_mask
50
- # elif len(self.input_names) > 1:
51
- # inputs[self.input_names[1]] = attention_mask
52
- #
53
- # outputs = self.session.run(self.output_names, inputs)
54
- # emb = outputs[0]
55
- #
56
- # if emb.ndim == 3:
57
- # emb_vector = emb.mean(axis=1)[0]
58
- # elif emb.ndim == 2:
59
- # emb_vector = emb[0]
60
- # else:
61
- # emb_vector = np.asarray(emb).flatten()
62
- #
63
- # return emb_vector.astype(float).tolist()
64
-
65
- async def embed_text(self, text: str, max_length: int = 512) -> List[float]:
66
 
67
  encoded = self.tokenizer(
68
  text,
69
- return_tensors="np",
70
  truncation=True,
71
- padding="longest",
72
  max_length=max_length,
 
73
  )
74
 
75
  input_ids = encoded["input_ids"].astype(np.int64)
76
- attention_mask = encoded.get("attention_mask", np.ones_like(input_ids)).astype(
77
- np.int64
 
 
 
 
 
 
78
  )
 
79
 
80
- # loop = asyncio.get_event_loop()
81
- # vector = await loop.run_in_executor(
82
- # None, self._run_sync, input_ids, attention_mask
83
- # )
84
- # return vector
85
- return input_ids.flatten().tolist()
86
-
87
 
88
- def cleanup(self):
89
- if self.session:
90
- self.session = None
91
- print("ONNX runtime session closed.")
92
 
 
 
 
93
 
94
- embedding_model = EmbeddingModel()
95
 
96
 
97
- async def test_tokenizer():
98
- text = "What does the company telll about moonlighting"
99
- tokens = await embedding_model.embed_text(text)
100
- print("Tokenized text:", tokens)
 
 
 
 
 
1
  import os
2
+ import numpy as np
3
  from typing import List
4
+ import onnxruntime as ort
5
+ from transformers import AutoTokenizer
6
+ from huggingface_hub import hf_hub_download
7
 
8
+ MODEL_ID = "onnx-community/embeddinggemma-300m-ONNX"
9
 
10
class EmbeddingModel:
    """Text-embedding wrapper around the ONNX export identified by ``MODEL_ID``.

    Construction loads the tokenizer, fetches the ONNX graph and its
    ``model.onnx_data`` companion file through the Hugging Face Hub, and opens
    an ONNX Runtime session on the CPU execution provider.
    """

    def __init__(self):
        print("🔵 Loading tokenizer…")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

        print("🔵 Downloading ONNX model files…")
        self.model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="onnx/model.onnx",
        )
        # The returned path is not passed to the session; the download is kept
        # for its side effect. Presumably model.onnx references this file as
        # external weight data resolved relative to model.onnx — TODO confirm.
        self.data_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="onnx/model.onnx_data",
        )

        print("🔵 Creating inference session…")
        self.session = ort.InferenceSession(
            self.model_path,
            providers=["CPUExecutionProvider"],
        )

        # Cached once so every inference call can build its feed dict cheaply.
        self.input_names = [i.name for i in self.session.get_inputs()]
        self.output_names = [o.name for o in self.session.get_outputs()]

    def _build_feed(self, input_ids, attention_mask):
        """Map the two tensors onto the session's input names.

        Inputs are matched by their conventional names when the graph exposes
        them; otherwise the positional order (ids first, mask second) is used.
        A graph with a single unnamed input only receives the token ids.
        """
        names = self.input_names
        if "input_ids" in names and "attention_mask" in names:
            return {"input_ids": input_ids, "attention_mask": attention_mask}

        feed = {names[0]: input_ids}
        if len(names) > 1:
            feed[names[1]] = attention_mask
        return feed

    def _run_sync(self, input_ids, attention_mask) -> List[float]:
        """Run the blocking ONNX inference and pool the result into one vector.

        Returns the L2-normalised, attention-masked mean of the first output
        for the first (only) sequence in the batch.
        """
        outputs = self.session.run(
            self.output_names,
            self._build_feed(input_ids, attention_mask),
        )
        # Assumes outputs[0] holds per-token hidden states shaped
        # (batch, seq, dim) — TODO confirm against the exported graph.
        # NOTE(review): if the graph also exposes an already-pooled sentence
        # embedding output, check whether that should be used instead.
        last_hidden = outputs[0]

        mask = attention_mask[..., None]
        # Floor the divisor at 1 so an all-zero mask gives a zero vector
        # instead of NaNs from a 0/0 division.
        token_counts = np.maximum(mask.sum(axis=1), 1)
        pooled = (last_hidden * mask).sum(axis=1) / token_counts

        vec = pooled[0]

        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm

        return vec.tolist()

    async def embed_text(self, text: str, max_length: int = 512) -> List[float]:
        """Embed ``text`` and return a unit-length vector as a list of floats.

        Args:
            text: The string to embed.
            max_length: Token limit passed to the tokenizer; longer inputs are
                truncated.

        Returns:
            The pooled, L2-normalised embedding of ``text``.
        """
        # Function-scope import: the module's import block does not bring
        # asyncio into scope.
        import asyncio

        # Tokenisation stays on the calling thread; only the ONNX inference is
        # offloaded below.
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="np",
        )

        input_ids = encoded["input_ids"].astype(np.int64)
        attention_mask = encoded["attention_mask"].astype(np.int64)

        # session.run() is a blocking call; running it in the default executor
        # keeps the event loop free to serve other requests in the meantime.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._run_sync, input_ids, attention_mask
        )
69
 
70
 
71
+ embedding_model = EmbeddingModel()
 
 
 
src/chatbot/router.py CHANGED
@@ -8,7 +8,8 @@ from sqlalchemy import text
8
  from sqlmodel.ext.asyncio.session import AsyncSession
9
 
10
  from src.core.database import get_async_session
11
-
 
12
  from .embedding import embedding_model
13
  from .schemas import (
14
  SemanticSearchRequest,
@@ -21,42 +22,6 @@ from .service import process_pdf_and_store
21
 
22
  router = APIRouter(prefix="/chatbot", tags=["chatbot"])
23
 
24
-
25
- # before hitting this endpoint make sure the model.data & model.onnx_data is available on the asset/onnx folder
26
- @router.post("/upload-pdf", response_model=UploadKBResponse)
27
- async def upload_pdf(
28
- file: UploadFile = File(...),
29
- name: str = Form(...),
30
- description: Optional[str] = Form(None),
31
- session: AsyncSession = Depends(get_async_session),
32
- ):
33
- if not file.filename.endswith(".pdf"):
34
- raise HTTPException(
35
- status_code=400, detail="Only PDF files are supported for now."
36
- )
37
-
38
- tmp_dir = tempfile.mkdtemp()
39
- tmp_path = os.path.join(tmp_dir, file.filename)
40
- try:
41
- with open(tmp_path, "wb") as out_f:
42
- shutil.copyfileobj(file.file, out_f)
43
-
44
- with open(tmp_path, "rb") as fobj:
45
- result = await process_pdf_and_store(fobj, name, description, session)
46
-
47
- return UploadKBResponse(
48
- kb_id=result["kb_id"],
49
- name=result["name"],
50
- chunks_stored=result["chunks_stored"],
51
- )
52
- finally:
53
- try:
54
- os.remove(tmp_path)
55
- os.rmdir(tmp_dir)
56
- except Exception:
57
- pass
58
-
59
-
60
  @router.post("/tokenize", response_model=TokenizeResponse)
61
  async def tokenize_text(payload: TokenizeRequest):
62
  try:
@@ -88,14 +53,14 @@ async def semantic_search(
88
  q_vector = payload.embedding
89
  top_k = payload.top_k or 3
90
 
91
- # Convert Python list → pgvector string format
92
  q_vector_str = "[" + ",".join(str(x) for x in q_vector) + "]"
93
 
94
  sql = text(
95
  """
96
- SELECT id, kb_id, chunk_text, embedding <=> :query_vec AS score
 
97
  FROM knowledge_chunk
98
- ORDER BY embedding <=> :query_vec
99
  LIMIT :top_k
100
  """
101
  )
@@ -104,7 +69,7 @@ async def semantic_search(
104
  sql, {"query_vec": q_vector_str, "top_k": top_k}
105
  )
106
  rows = result.fetchall()
107
-
108
  return [
109
  SemanticSearchResult(
110
  chunk_id=str(r.id),
@@ -115,3 +80,79 @@ async def semantic_search(
115
  for r in rows
116
  ]
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from sqlmodel.ext.asyncio.session import AsyncSession
9
 
10
  from src.core.database import get_async_session
11
+ from .schemas import ManualTextRequest
12
+ from .service import store_manual_text
13
  from .embedding import embedding_model
14
  from .schemas import (
15
  SemanticSearchRequest,
 
22
 
23
  router = APIRouter(prefix="/chatbot", tags=["chatbot"])
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @router.post("/tokenize", response_model=TokenizeResponse)
26
  async def tokenize_text(payload: TokenizeRequest):
27
  try:
 
53
  q_vector = payload.embedding
54
  top_k = payload.top_k or 3
55
 
 
56
  q_vector_str = "[" + ",".join(str(x) for x in q_vector) + "]"
57
 
58
  sql = text(
59
  """
60
+ SELECT id, kb_id, chunk_text,
61
+ embedding <#> :query_vec AS score
62
  FROM knowledge_chunk
63
+ ORDER BY embedding <#> :query_vec ASC
64
  LIMIT :top_k
65
  """
66
  )
 
69
  sql, {"query_vec": q_vector_str, "top_k": top_k}
70
  )
71
  rows = result.fetchall()
72
+
73
  return [
74
  SemanticSearchResult(
75
  chunk_id=str(r.id),
 
80
  for r in rows
81
  ]
82
 
83
+ # Before hitting this endpoint, make sure that model.data and model.onnx_data are available in the asset/onnx folder.
84
+ # @router.post("/upload-pdf", response_model=UploadKBResponse)
85
+ # async def upload_pdf(
86
+ # file: UploadFile = File(...),
87
+ # name: str = Form(...),
88
+ # description: Optional[str] = Form(None),
89
+ # session: AsyncSession = Depends(get_async_session),
90
+ # ):
91
+ # if not file.filename.endswith(".pdf"):
92
+ # raise HTTPException(
93
+ # status_code=400, detail="Only PDF files are supported for now."
94
+ # )
95
+
96
+ # tmp_dir = tempfile.mkdtemp()
97
+ # tmp_path = os.path.join(tmp_dir, file.filename)
98
+ # try:
99
+ # with open(tmp_path, "wb") as out_f:
100
+ # shutil.copyfileobj(file.file, out_f)
101
+
102
+ # with open(tmp_path, "rb") as fobj:
103
+ # result = await process_pdf_and_store(fobj, name, description, session)
104
+
105
+ # return UploadKBResponse(
106
+ # kb_id=result["kb_id"],
107
+ # name=result["name"],
108
+ # chunks_stored=result["chunks_stored"],
109
+ # )
110
+ # finally:
111
+ # try:
112
+ # os.remove(tmp_path)
113
+ # os.rmdir(tmp_dir)
114
+ # except Exception:
115
+ # pass
116
+
117
+ # @router.post("/manual-add-chunk")
118
+ # async def manual_add_chunk(
119
+ # payload: ManualTextRequest,
120
+ # session: AsyncSession = Depends(get_async_session)
121
+ # ):
122
+ # return await store_manual_text(
123
+ # kb_id=payload.kb_id,
124
+ # text=payload.text,
125
+ # session=session
126
+ # )
127
+
128
+ # @router.post("/test-semantic", response_model=list[SemanticSearchResult])
129
+ # async def test_semantic(
130
+ # query: str,
131
+ # top_k: int = 3,
132
+ # session: AsyncSession = Depends(get_async_session)
133
+ # ):
134
+
135
+ # embedding = await embedding_model.embed_text(query)
136
+
137
+ # q_vec = "[" + ",".join(map(str, embedding)) + "]"
138
+
139
+ # sql = text("""
140
+ # SELECT id, kb_id, chunk_text,
141
+ # embedding <#> :vec AS score
142
+ # FROM knowledge_chunk
143
+ # ORDER BY embedding <#> :vec ASC
144
+ # LIMIT :k
145
+ # """)
146
+
147
+ # result = await session.execute(sql, {"vec": q_vec, "k": top_k})
148
+ # rows = result.fetchall()
149
+
150
+ # return [
151
+ # SemanticSearchResult(
152
+ # chunk_id=str(r.id),
153
+ # kb_id=str(r.kb_id),
154
+ # text=r.chunk_text,
155
+ # score=float(r.score),
156
+ # )
157
+ # for r in rows
158
+ # ]
src/chatbot/schemas.py CHANGED
@@ -34,3 +34,7 @@ class SemanticSearchResult(BaseModel):
34
  kb_id: str
35
  text: str
36
  score: float
 
 
 
 
 
34
  kb_id: str
35
  text: str
36
  score: float
37
+
38
class ManualTextRequest(BaseModel):
    """Request body for adding a single piece of text to a knowledge base."""

    # Identifier of the knowledge base the text is attached to.
    kb_id: uuid.UUID
    # The text to store.
    text: str
src/chatbot/service.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
-
3
  from sqlmodel.ext.asyncio.session import AsyncSession
4
-
5
  from .embedding import embedding_model
6
  from .models import KnowledgeBase, KnowledgeChunk
7
  from .utils import (
@@ -43,3 +43,29 @@ async def process_pdf_and_store(
43
  await session.commit()
44
 
45
  return {"kb_id": kb.id, "name": kb_name, "chunks_stored": len(chunk_objs)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from uuid import UUID
3
  from sqlmodel.ext.asyncio.session import AsyncSession
4
+ from sqlmodel import select
5
  from .embedding import embedding_model
6
  from .models import KnowledgeBase, KnowledgeChunk
7
  from .utils import (
 
43
  await session.commit()
44
 
45
  return {"kb_id": kb.id, "name": kb_name, "chunks_stored": len(chunk_objs)}
46
+
47
async def store_manual_text(kb_id: UUID, text: str, session: AsyncSession):
    """Embed ``text`` and append it as a new chunk of knowledge base ``kb_id``.

    The new chunk's ``chunk_index`` is the number of chunks currently stored
    for that knowledge base.

    Args:
        kb_id: Identifier of the knowledge base receiving the chunk.
        text: Chunk text to embed and store.
        session: Async database session used for the count query and insert.

    Returns:
        A dict with the ``kb_id``, the assigned ``chunk_index``, a ``status``
        of ``"stored"`` and the original ``text``.
    """
    # Function-scope import: the module's import block does not bring ``func``
    # into scope.
    from sqlalchemy import func

    embedding = await embedding_model.embed_text(text)

    # Count on the database side rather than loading every existing chunk
    # (text and embedding included) just to take len() of the result.
    # NOTE(review): two concurrent calls can still read the same count and
    # assign the same index; a uniqueness constraint would be needed to rule
    # that out.
    count_stmt = (
        select(func.count())
        .select_from(KnowledgeChunk)
        .where(KnowledgeChunk.kb_id == kb_id)
    )
    result = await session.execute(count_stmt)
    next_index = result.scalar_one()

    new_chunk = KnowledgeChunk(
        kb_id=kb_id,
        chunk_index=next_index,
        chunk_text=text,
        embedding=embedding,
    )

    session.add(new_chunk)
    await session.commit()

    return {
        "kb_id": kb_id,
        "chunk_index": next_index,
        "status": "stored",
        "text": text,
    }
src/main.py CHANGED
@@ -13,7 +13,7 @@ app = FastAPI(title="Yuvabe App API")
13
 
14
  app.include_router(home_router, prefix="/home", tags=["Home"])
15
 
16
- init_db()
17
 
18
  app.include_router(auth_router)
19
 
 
13
 
14
  app.include_router(home_router, prefix="/home", tags=["Home"])
15
 
16
+ # init_db()
17
 
18
  app.include_router(auth_router)
19