pluto90 committed on
Commit
d6dd5a9
·
verified ·
1 Parent(s): 5e47fb2

Update app/core/embedding_engine.py

Browse files
Files changed (1) hide show
  1. app/core/embedding_engine.py +151 -65
app/core/embedding_engine.py CHANGED
@@ -1,66 +1,152 @@
1
- # embedding_engine.py
2
-
3
- import uuid
4
- from qdrant_client import QdrantClient, models
5
- from qdrant_client.http.models import Distance, VectorParams
6
- from sentence_transformers import SentenceTransformer
7
- from app.core.config import QDRANT_URL, QDRANT_API_KEY
8
-
9
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
10
-
11
- qdrant = QdrantClient(
12
- url=QDRANT_URL,
13
- api_key=QDRANT_API_KEY,
14
- check_compatibility=False
15
- )
16
-
17
- COLLECTION_NAME = "smartnotes"
18
- BATCH_SIZE = 100
19
-
20
-
21
- def ensure_collection():
22
- collections = qdrant.get_collections().collections
23
- if COLLECTION_NAME not in [c.name for c in collections]:
24
- qdrant.create_collection(
25
- collection_name=COLLECTION_NAME,
26
- vectors_config=VectorParams(
27
- size=384,
28
- distance=Distance.COSINE
29
- ),
30
- )
31
-
32
- # βœ… Add this part
33
- qdrant.create_payload_index(
34
- collection_name=COLLECTION_NAME,
35
- field_name="doc_id",
36
- field_schema="keyword"
37
- )
38
-
39
-
40
-
41
- def embed_and_store(text_chunks, doc_id):
42
- """Embed chunks and store them in Qdrant efficiently."""
43
- ensure_collection()
44
- print(f"πŸ”Ή Embedding {len(text_chunks)} chunks...")
45
-
46
- # Generate embeddings
47
- vectors = embedder.encode(text_chunks, show_progress_bar=True).tolist()
48
-
49
- # Prepare points
50
- points = [
51
- models.PointStruct(
52
- id=str(uuid.uuid4()),
53
- vector=vectors[i],
54
- payload={"doc_id": doc_id, "text": text_chunks[i]},
55
- )
56
- for i in range(len(vectors))
57
- ]
58
-
59
- # βœ… Upsert in small batches to avoid timeouts
60
- print("πŸ”Ή Uploading to Qdrant in batches...")
61
- for i in range(0, len(points), BATCH_SIZE):
62
- batch = points[i:i + BATCH_SIZE]
63
- qdrant.upsert(collection_name=COLLECTION_NAME, points=batch)
64
- print(f" β†’ Uploaded batch {i // BATCH_SIZE + 1}/{len(points) // BATCH_SIZE + 1}")
65
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  print("βœ… All embeddings stored successfully!")
 
1
+ # # embedding_engine.py
2
+ # import uuid
3
+ # from qdrant_client import QdrantClient, models
4
+ # from qdrant_client.http.models import Distance, VectorParams
5
+ # from sentence_transformers import SentenceTransformer
6
+ # from app.core.config import QDRANT_URL, QDRANT_API_KEY
7
+
8
+ # embedder = SentenceTransformer("all-MiniLM-L6-v2")
9
+
10
+ # qdrant = QdrantClient(
11
+ # url=QDRANT_URL,
12
+ # api_key=QDRANT_API_KEY,
13
+ # check_compatibility=False
14
+ # )
15
+
16
+ # COLLECTION_NAME = "smartnotes"
17
+ # BATCH_SIZE = 100
18
+
19
+
20
+ # def ensure_collection():
21
+ # collections = qdrant.get_collections().collections
22
+ # if COLLECTION_NAME not in [c.name for c in collections]:
23
+ # qdrant.create_collection(
24
+ # collection_name=COLLECTION_NAME,
25
+ # vectors_config=VectorParams(
26
+ # size=384,
27
+ # distance=Distance.COSINE
28
+ # ),
29
+ # )
30
+
31
+ # # βœ… Add this part
32
+ # qdrant.create_payload_index(
33
+ # collection_name=COLLECTION_NAME,
34
+ # field_name="doc_id",
35
+ # field_schema="keyword"
36
+ # )
37
+
38
+
39
+
40
+ # def embed_and_store(text_chunks, doc_id):
41
+ # """Embed chunks and store them in Qdrant efficiently."""
42
+ # ensure_collection()
43
+ # print(f"πŸ”Ή Embedding {len(text_chunks)} chunks...")
44
+
45
+ # # Generate embeddings
46
+ # vectors = embedder.encode(text_chunks, show_progress_bar=True).tolist()
47
+
48
+ # # Prepare points
49
+ # points = [
50
+ # models.PointStruct(
51
+ # id=str(uuid.uuid4()),
52
+ # vector=vectors[i],
53
+ # payload={"doc_id": doc_id, "text": text_chunks[i]},
54
+ # )
55
+ # for i in range(len(vectors))
56
+ # ]
57
+
58
+ # # βœ… Upsert in small batches to avoid timeouts
59
+ # print("πŸ”Ή Uploading to Qdrant in batches...")
60
+ # for i in range(0, len(points), BATCH_SIZE):
61
+ # batch = points[i:i + BATCH_SIZE]
62
+ # qdrant.upsert(collection_name=COLLECTION_NAME, points=batch)
63
+ # print(f" β†’ Uploaded batch {i // BATCH_SIZE + 1}/{len(points) // BATCH_SIZE + 1}")
64
+
65
+ # print("βœ… All embeddings stored successfully!")
66
+
67
+
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+
82
+ # embedding_engine.py
83
+
84
+ import uuid
85
+ from qdrant_client import QdrantClient, models
86
+ from qdrant_client.http.models import Distance, VectorParams
87
+ from sentence_transformers import SentenceTransformer
88
+ from app.core.config import QDRANT_URL, QDRANT_API_KEY
89
+ # from config import QDRANT_URL, QDRANT_API_KEY
90
+
91
+ # embedder = SentenceTransformer("all-MiniLM-L6-v2")
92
+ # embedder.save("models/all-MiniLM-L6-v2")
93
+
94
+
95
+ MODEL_PATH = "models/all-MiniLM-L6-v2"
96
+ embedder = SentenceTransformer(MODEL_PATH)
97
+
98
+ qdrant = QdrantClient(
99
+ url=QDRANT_URL,
100
+ api_key=QDRANT_API_KEY,
101
+ check_compatibility=False
102
+ )
103
+
104
+ COLLECTION_NAME = "smartnotes"
105
+ BATCH_SIZE = 100
106
+
107
+
108
def ensure_collection():
    """Make sure the Qdrant collection and its ``doc_id`` payload index exist.

    Lists the collections known to the ``qdrant`` client and creates
    ``COLLECTION_NAME`` (384-dimensional vectors, cosine distance) when it is
    missing, then requests a keyword payload index on the ``doc_id`` field.
    """
    existing_names = [entry.name for entry in qdrant.get_collections().collections]
    if COLLECTION_NAME not in existing_names:
        # 384 is presumably the output dimension of the all-MiniLM-L6-v2
        # embedder loaded at module level - confirm if the model changes.
        vector_settings = VectorParams(size=384, distance=Distance.COSINE)
        qdrant.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=vector_settings,
        )

    # NOTE(review): requested on every call, not only right after creation;
    # this assumes Qdrant treats re-creating an existing index as a no-op.
    qdrant.create_payload_index(
        collection_name=COLLECTION_NAME,
        field_name="doc_id",
        field_schema="keyword",
    )
125
+
126
+
127
def embed_and_store(text_chunks, doc_id):
    """Embed text chunks and upsert them into the Qdrant collection in batches.

    Args:
        text_chunks: Sized sequence of strings to embed; each chunk becomes
            one stored point.
        doc_id: Value written to every point's payload under ``"doc_id"``.

    Every point receives a fresh random UUID4 string as its id and a payload
    of ``{"doc_id": doc_id, "text": <chunk>}``. Points are sent in slices of
    ``BATCH_SIZE`` so that a single request never carries the whole document.
    """
    ensure_collection()
    print(f"πŸ”Ή Embedding {len(text_chunks)} chunks...")

    # encode() output is turned into plain nested lists before building points.
    vectors = embedder.encode(text_chunks, show_progress_bar=True).tolist()

    # Pair each vector with the chunk it was computed from.
    points = [
        models.PointStruct(
            id=str(uuid.uuid4()),
            vector=vector,
            payload={"doc_id": doc_id, "text": chunk},
        )
        for vector, chunk in zip(vectors, text_chunks)
    ]

    # Ceiling division: `len(points) // BATCH_SIZE + 1` over-counted by one
    # whenever the point count was an exact multiple of BATCH_SIZE.
    total_batches = (len(points) + BATCH_SIZE - 1) // BATCH_SIZE

    print("πŸ”Ή Uploading to Qdrant in batches...")
    for batch_number, start in enumerate(range(0, len(points), BATCH_SIZE), start=1):
        batch = points[start:start + BATCH_SIZE]
        qdrant.upsert(collection_name=COLLECTION_NAME, points=batch)
        print(f" β†’ Uploaded batch {batch_number}/{total_batches}")

    print("βœ… All embeddings stored successfully!")