pluto90 committed on
Commit
21c1bff
·
verified ·
1 Parent(s): f06dea6

Update app/core/embedding_engine.py

Browse files
Files changed (1) hide show
  1. app/core/embedding_engine.py +106 -233
app/core/embedding_engine.py CHANGED
@@ -1,233 +1,106 @@
1
- # # embedding_engine.py
2
-
3
- # import uuid, time
4
- # from qdrant_client import QdrantClient, models
5
- # from qdrant_client.http.models import Distance, VectorParams
6
- # from qdrant_client.http.exceptions import UnexpectedResponse
7
- # from sentence_transformers import SentenceTransformer
8
- # from app.core.config import QDRANT_URL, QDRANT_API_KEY
9
-
10
- # MODEL_PATH = "app/core/models/bge-base-en-v1.5"
11
- # embedder = SentenceTransformer(MODEL_PATH)
12
-
13
- # qdrant = QdrantClient(
14
- # url=QDRANT_URL,
15
- # api_key=QDRANT_API_KEY,
16
- # check_compatibility=False
17
- # )
18
-
19
- # COLLECTION_NAME = "smartnotes"
20
- # BATCH_SIZE = 10
21
-
22
-
23
- # def ensure_collection():
24
- # collections = qdrant.get_collections().collections
25
- # if COLLECTION_NAME not in [c.name for c in collections]:
26
- # qdrant.create_collection(
27
- # collection_name=COLLECTION_NAME,
28
- # vectors_config=VectorParams(
29
- # size=768,
30
- # distance=Distance.COSINE
31
- # ),
32
- # )
33
-
34
- # # ✅ Add this part
35
- # qdrant.create_payload_index(
36
- # collection_name=COLLECTION_NAME,
37
- # field_name="doc_id",
38
- # field_schema="keyword"
39
- # )
40
-
41
-
42
- # def embed_and_store(text_chunks, doc_id):
43
- # print(f"📊 Embedding and storing {len(text_chunks)} chunks...")
44
- # ensure_collection()
45
-
46
- # print(f"🔹 Embedding {len(text_chunks)} chunks...")
47
-
48
- # vectors = embed_documents(text_chunks)
49
-
50
- # points = [
51
- # models.PointStruct(
52
- # id=str(uuid.uuid4()),
53
- # vector=vectors[i],
54
- # payload={
55
- # "doc_id": doc_id,
56
- # "text": text_chunks[i],
57
- # "chunk_id": i,
58
- # "length": len(text_chunks[i])
59
- # },
60
- # )
61
- # for i in range(len(vectors))
62
- # ]
63
-
64
- # print("🔹 Uploading to Qdrant in batches...")
65
-
66
- # for i in range(0, len(points), BATCH_SIZE):
67
- # batch = points[i:i + BATCH_SIZE]
68
-
69
- # success = False
70
- # retries = 3
71
-
72
- # while not success and retries > 0:
73
- # try:
74
- # qdrant.upsert(
75
- # collection_name=COLLECTION_NAME,
76
- # points=batch
77
- # )
78
- # success = True
79
- # print(f" → Uploaded batch {i // BATCH_SIZE + 1}")
80
-
81
- # except Exception as e:
82
- # print("❌ Qdrant error:", e)
83
- # retries -= 1
84
- # time.sleep(1.5) # 🔥 increase wait
85
-
86
- # if not success:
87
- # print("⚠️ Skipping batch after retries")
88
-
89
- # time.sleep(0.4) # 🔥 throttle
90
-
91
-
92
-
93
- # def embed_documents(texts):
94
- # vectors= []
95
-
96
- # for i in range(0, len(texts), 32):
97
- # batch = texts[i:i+32]
98
- # batch_vectors = embedder.encode(batch, show_progress_bar=False)
99
- # vectors.extend(batch_vectors.tolist())
100
-
101
- # return vectors
102
-
103
-
104
- # def embed_query(text):
105
- # return embedder.encode(
106
- # f"query: {text}",
107
- # normalize_embeddings=True
108
- # )
109
-
110
-
111
-
112
-
113
-
114
-
115
-
116
-
117
-
118
-
119
-
120
-
121
-
122
-
123
-
124
-
125
-
126
-
127
-
128
- # embedding_engine.py
129
- import uuid, time
130
- from qdrant_client import QdrantClient, models
131
- from qdrant_client.http.models import Distance, VectorParams
132
- from sentence_transformers import SentenceTransformer
133
- from app.core.config import QDRANT_URL, QDRANT_API_KEY
134
-
135
- MODEL_PATH = "app/core/models/bge-base-en-v1.5"
136
- embedder = SentenceTransformer(MODEL_PATH)
137
-
138
- qdrant = QdrantClient(
139
- url=QDRANT_URL,
140
- api_key=QDRANT_API_KEY,
141
- check_compatibility=False
142
- )
143
-
144
- COLLECTION_NAME = "smartnotes"
145
- BATCH_SIZE = 5 # ✅ reduced for free tier
146
-
147
-
148
- def ensure_collection():
149
- collections = qdrant.get_collections().collections
150
- if COLLECTION_NAME not in [c.name for c in collections]:
151
- qdrant.create_collection(
152
- collection_name=COLLECTION_NAME,
153
- vectors_config=VectorParams(size=768, distance=Distance.COSINE),
154
- )
155
- qdrant.create_payload_index(
156
- collection_name=COLLECTION_NAME,
157
- field_name="doc_id",
158
- field_schema="keyword"
159
- )
160
-
161
-
162
- def embed_and_store(text_chunks, doc_id):
163
- print(f"📊 Final chunks being embedded: {len(text_chunks)}")
164
- ensure_collection()
165
-
166
- vectors = embed_documents(text_chunks) # ✅ now uses correct doc prefix
167
-
168
- points = [
169
- models.PointStruct(
170
- id=str(uuid.uuid4()),
171
- vector=vectors[i],
172
- payload={
173
- "doc_id": doc_id,
174
- "text": text_chunks[i],
175
- "chunk_id": i,
176
- "length": len(text_chunks[i])
177
- },
178
- )
179
- for i in range(len(vectors))
180
- ]
181
-
182
- failed_batches = []
183
-
184
- for i in range(0, len(points), BATCH_SIZE):
185
- batch = points[i:i + BATCH_SIZE]
186
- batch_num = i // BATCH_SIZE + 1
187
- success = False
188
-
189
- for attempt in range(4): # ✅ 4 attempts with exponential backoff
190
- try:
191
- qdrant.upsert(collection_name=COLLECTION_NAME, points=batch)
192
- success = True
193
- print(f" → Batch {batch_num} uploaded")
194
- break
195
- except Exception as e:
196
- wait = 2 ** attempt # 1s, 2s, 4s, 8s
197
- print(f" ⚠️ Batch {batch_num} attempt {attempt+1} failed: {e} | retrying in {wait}s")
198
- time.sleep(wait)
199
-
200
- if not success:
201
- failed_batches.append(batch_num)
202
- print(f" ❌ Batch {batch_num} permanently failed")
203
-
204
- time.sleep(0.6) # ✅ throttle between successful batches
205
-
206
- if failed_batches:
207
- # ✅ raise so the caller (routes.py) knows something went wrong
208
- raise RuntimeError(f"Failed to upload batches: {failed_batches}")
209
-
210
- print(f"✅ All batches uploaded for doc_id={doc_id}")
211
-
212
-
213
- def embed_documents(texts):
214
- """Embed document chunks with correct BGE prefix and normalization."""
215
- prefixed = [f"Represent this sentence: {t}" for t in texts] # ✅ correct BGE doc prefix
216
- vectors = []
217
- for i in range(0, len(prefixed), 32):
218
- batch = prefixed[i:i + 32]
219
- batch_vectors = embedder.encode(
220
- batch, normalize_embeddings=True, show_progress_bar=False)
221
-
222
- vectors.extend(batch_vectors.tolist())
223
- return vectors
224
-
225
-
226
- def embed_query(text):
227
- """Embed a search query — BGE uses 'query:' prefix for retrieval."""
228
- return embedder.encode(
229
- f"query: {text}",
230
- normalize_embeddings=True
231
- ).tolist() # ✅ always return list, not numpy array
232
-
233
-
 
1
+ # embedding_engine.py
2
+ import uuid, time
3
+ from qdrant_client import QdrantClient, models
4
+ from qdrant_client.http.models import Distance, VectorParams
5
+ from sentence_transformers import SentenceTransformer
6
+ from app.core.config import QDRANT_URL, QDRANT_API_KEY
7
+
8
+ MODEL_PATH = "app/core/models/bge-base-en-v1.5"
9
+ embedder = SentenceTransformer(MODEL_PATH)
10
+
11
+ qdrant = QdrantClient(
12
+ url=QDRANT_URL,
13
+ api_key=QDRANT_API_KEY,
14
+ check_compatibility=False
15
+ )
16
+
17
+ COLLECTION_NAME = "smartnotes"
18
+ BATCH_SIZE = 5 # ✅ reduced for free tier
19
+
20
+
21
+ def ensure_collection():
22
+ collections = qdrant.get_collections().collections
23
+ if COLLECTION_NAME not in [c.name for c in collections]:
24
+ qdrant.create_collection(
25
+ collection_name=COLLECTION_NAME,
26
+ vectors_config=VectorParams(size=768, distance=Distance.COSINE),
27
+ )
28
+ qdrant.create_payload_index(
29
+ collection_name=COLLECTION_NAME,
30
+ field_name="doc_id",
31
+ field_schema="keyword"
32
+ )
33
+
34
+
35
+ def embed_and_store(text_chunks, doc_id):
36
+ print(f"📊 Final chunks being embedded: {len(text_chunks)}")
37
+ ensure_collection()
38
+
39
+ vectors = embed_documents(text_chunks) # ✅ now uses correct doc prefix
40
+
41
+ points = [
42
+ models.PointStruct(
43
+ id=str(uuid.uuid4()),
44
+ vector=vectors[i],
45
+ payload={
46
+ "doc_id": doc_id,
47
+ "text": text_chunks[i],
48
+ "chunk_id": i,
49
+ "length": len(text_chunks[i])
50
+ },
51
+ )
52
+ for i in range(len(vectors))
53
+ ]
54
+
55
+ failed_batches = []
56
+
57
+ for i in range(0, len(points), BATCH_SIZE):
58
+ batch = points[i:i + BATCH_SIZE]
59
+ batch_num = i // BATCH_SIZE + 1
60
+ success = False
61
+
62
+ for attempt in range(4): # ✅ 4 attempts with exponential backoff
63
+ try:
64
+ qdrant.upsert(collection_name=COLLECTION_NAME, points=batch)
65
+ success = True
66
+ print(f" → Batch {batch_num} uploaded")
67
+ break
68
+ except Exception as e:
69
+ wait = 2 ** attempt # 1s, 2s, 4s, 8s
70
+ print(f" ⚠️ Batch {batch_num} attempt {attempt+1} failed: {e} | retrying in {wait}s")
71
+ time.sleep(wait)
72
+
73
+ if not success:
74
+ failed_batches.append(batch_num)
75
+ print(f" ❌ Batch {batch_num} permanently failed")
76
+
77
+ time.sleep(0.6) # ✅ throttle between successful batches
78
+
79
+ if failed_batches:
80
+ # ✅ raise so the caller (routes.py) knows something went wrong
81
+ raise RuntimeError(f"Failed to upload batches: {failed_batches}")
82
+
83
+ print(f"✅ All batches uploaded for doc_id={doc_id}")
84
+
85
+
86
+ def embed_documents(texts):
87
+ """Embed document chunks with correct BGE prefix and normalization."""
88
+ prefixed = [f"Represent this sentence: {t}" for t in texts] # ✅ correct BGE doc prefix
89
+ vectors = []
90
+ for i in range(0, len(prefixed), 32):
91
+ batch = prefixed[i:i + 32]
92
+ batch_vectors = embedder.encode(
93
+ batch, normalize_embeddings=True, show_progress_bar=False)
94
+
95
+ vectors.extend(batch_vectors.tolist())
96
+ return vectors
97
+
98
+
99
+ def embed_query(text):
100
+ """Embed a search query — BGE uses 'query:' prefix for retrieval."""
101
+ return embedder.encode(
102
+ f"query: {text}",
103
+ normalize_embeddings=True
104
+ ).tolist() # ✅ always return list, not numpy array
105
+
106
+