serverdaun committed on
Commit
c4a0174
·
0 Parent(s):

initial commit with the 0.1 version of the app

Browse files
Files changed (6) hide show
  1. .gitignore +22 -0
  2. .python-version +1 -0
  3. README.md +0 -0
  4. main.py +236 -0
  5. pyproject.toml +19 -0
  6. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ # environment variables
13
+ .env
14
+
15
+ # Hugging Face cache
16
+ .hf_cache/
17
+
18
+ # Milvus database
19
+ milvus_binary_quantized.db
20
+
21
+ # data
22
+ documents/
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
README.md ADDED
File without changes
main.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langchain.chat_models import init_chat_model
4
+ from llama_index.core import SimpleDirectoryReader
5
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
6
+ import numpy as np
7
+ from pymilvus import MilvusClient, DataType
8
+ import logging
9
+ from langchain_core.messages import HumanMessage
10
+
11
# Configure root logging so INFO-level (and above) records are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Load environment variables from a local .env file
# (presumably the credentials needed by the chat model -- confirm).
load_dotenv()

DOCS_DIR = "documents"  # directory scanned recursively for PDF files
MODEL_NAME = "gpt-4.1"  # chat model name passed to init_chat_model
TEMPERATURE = 0.2  # temperature passed to the chat model
COLLECTION_NAME = "fast_rag"  # Milvus collection that stores the indexed documents
21
+
22
+
23
def batch_iterate(items, batch_size):
    """Yield consecutive slices of ``items`` with at most ``batch_size`` elements.

    Args:
        items: A sliceable sequence that supports ``len()`` (list, tuple, str, ...).
        batch_size: Maximum number of elements per batch; must be positive.

    Yields:
        Slices of ``items`` in their original order. The last batch may be
        shorter than ``batch_size``.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A zero step makes range() fail with an unrelated-looking message and a
    # negative step silently produces no batches at all, so reject both here.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]
27
+
28
+
29
# Chat model used to answer the query; created once when the module is loaded.
llm = init_chat_model(MODEL_NAME, model_provider="openai", temperature=TEMPERATURE)
30
+
31
+ ## Generate binary embeddings
32
def generate_binary_embeddings():
    """Load the PDFs under ``DOCS_DIR`` and embed them as packed binary vectors.

    Every loaded text is embedded with ``BAAI/bge-large-en-v1.5``. Each float
    component is then reduced to a single bit (1 when the value is > 0,
    otherwise 0) and the bits are packed eight per byte.

    Returns:
        A ``(documents, binary_embeddings)`` tuple of two equal-length lists:
        the document texts and, for each text, its packed vector as ``bytes``.
        ``([], [])`` is returned when no usable text is found or when any
        step fails (the failure is logged).
    """
    try:
        loader = SimpleDirectoryReader(
            input_dir=DOCS_DIR,
            required_exts=[".pdf"],
            recursive=True,
        )

        # Skip entries with no extractable text (e.g. blank or image-only
        # pages): embedding an empty string yields a vector that carries no
        # meaning yet would still compete in nearest-neighbour search.
        documents = [
            doc.text for doc in loader.load_data() if doc.text and doc.text.strip()
        ]

        if not documents:
            logger.error("No documents found in the documents directory.")
            return [], []

        embedding_model = HuggingFaceEmbedding(
            model_name="BAAI/bge-large-en-v1.5",
            trust_remote_code=True,
            cache_folder=".hf_cache",
        )

        binary_embeddings = []

        for batch in batch_iterate(documents, batch_size=512):
            # float embeddings for the whole batch, one row per document
            float_embeds = np.asarray(embedding_model.get_text_embedding_batch(batch))

            # one bit per dimension: 1 for positive components, 0 otherwise
            bits = (float_embeds > 0).astype(np.uint8)

            # pack 8 bits per byte along the embedding axis, then serialise rows
            packed = np.packbits(bits, axis=1)
            binary_embeddings.extend(row.tobytes() for row in packed)

        logger.info(f"Generated {len(binary_embeddings)} binary embeddings")
        return documents, binary_embeddings

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Error generating embeddings: %s", e)
        return [], []
78
+
79
+
80
# Embed all documents when the module is loaded; both lists are empty if
# generate_binary_embeddings() found nothing or failed.
documents, binary_embeddings = generate_binary_embeddings()

## Vector indexing
# MilvusClient is given a local file name (presumably a file-backed local
# database -- confirm); the same file name is listed in .gitignore.
client = MilvusClient("milvus_binary_quantized.db")
84
+
85
+ # Initialize client and schema
86
def create_collection(documents, embeddings):
    """(Re)create the Milvus collection and insert the documents with their vectors.

    An existing ``COLLECTION_NAME`` collection is dropped first, so every call
    rebuilds the index from scratch.

    Args:
        documents: Document texts, stored in the ``context`` field.
        embeddings: Packed binary vectors as ``bytes``, one per document.

    Returns:
        None. Failures are logged (with traceback) instead of being raised.
    """
    # zip() would silently drop the unmatched tail; refuse mismatched input
    # rather than index documents against the wrong vectors.
    if len(documents) != len(embeddings):
        logger.error(
            f"Error creating collection: got {len(documents)} documents "
            f"but {len(embeddings)} embeddings"
        )
        return None

    try:
        if client.has_collection(COLLECTION_NAME):
            logger.info(f"Collection {COLLECTION_NAME} already exists, dropping it...")
            client.drop_collection(COLLECTION_NAME)

        # The schema option is spelled ``enable_dynamic_field`` (singular);
        # an unrecognised keyword is ignored rather than rejected.
        schema = client.create_schema(
            auto_id=True,
            enable_dynamic_field=True,
        )

        # Add primary key field
        schema.add_field(
            field_name="id",
            datatype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )

        # Add fields to schema
        schema.add_field(
            field_name="context",
            datatype=DataType.VARCHAR,
            max_length=65535,  # max length for VARCHAR
        )
        schema.add_field(
            field_name="binary_vector",
            datatype=DataType.BINARY_VECTOR,
            dim=1024,  # dimension for binary vector (bits, not bytes)
        )

        # Create index params for binary vector
        index_params = client.prepare_index_params()
        index_params.add_index(
            field_name="binary_vector",
            index_name="binary_vector_index",
            index_type="BIN_FLAT",  # Exact search for binary vectors
            metric_type="HAMMING",  # Hamming distance for binary vectors
        )

        # Create collection with schema and index
        client.create_collection(
            collection_name=COLLECTION_NAME,
            schema=schema,
            index_params=index_params,
        )

        rows = [
            {"context": context, "binary_vector": binary_embedding}
            for context, binary_embedding in zip(documents, embeddings)
        ]
        if not rows:
            # Leave the (empty) collection in place so later searches still
            # have something to query, but do not send an empty insert.
            logger.warning(f"No documents to insert into {COLLECTION_NAME}")
            return None

        # Insert data into collection
        client.insert(collection_name=COLLECTION_NAME, data=rows)
    except Exception as e:
        # The original try block covered only the drop/schema step, so a
        # failing create_collection()/insert() bypassed this handler.
        logger.exception("Error creating collection: %s", e)
        return None

    return None
148
+
149
+ create_collection(documents, binary_embeddings)
150
+
151
+
152
# Lazily created embedding model shared by all queries; loading it is
# expensive, so it must not be re-instantiated on every call.
_QUERY_EMBEDDING_MODEL = None


def _load_query_embedding_model():
    """Return the shared query embedding model, creating it on first use."""
    global _QUERY_EMBEDDING_MODEL
    if _QUERY_EMBEDDING_MODEL is None:
        _QUERY_EMBEDDING_MODEL = HuggingFaceEmbedding(
            model_name="BAAI/bge-large-en-v1.5",
            trust_remote_code=True,
            cache_folder=".hf_cache",
        )
    return _QUERY_EMBEDDING_MODEL


def get_query_embeddings(query: str) -> bytes | None:
    """Embed ``query`` and return it as a packed binary vector.

    The float embedding is binarised the same way as the indexed documents
    (1 for components > 0, else 0) and packed eight bits per byte, so the
    result can be compared against the stored vectors.

    Args:
        query: The search text to embed.

    Returns:
        The packed binary vector as ``bytes``, or ``None`` if the model could
        not be loaded or the embedding failed (the error is logged).
    """
    try:
        embedding_model = _load_query_embedding_model()
        # Computing the embedding can fail too, so it belongs inside the try.
        query_embedding = embedding_model.get_text_embedding(query)
    except Exception as e:
        logger.exception("Error getting query embeddings: %s", e)
        return None

    # Convert float embedding to one bit per dimension
    bits = (np.asarray(query_embedding) > 0).astype(np.uint8)

    # Pack 8 bits per byte and serialise
    return np.packbits(bits, axis=0).tobytes()
174
+
175
+
176
def search_documents(query: str, limit: int = 5):
    """Return the stored ``context`` texts closest to ``query``.

    The query is converted to a binary vector with ``get_query_embeddings``
    and matched against the ``binary_vector`` field of ``COLLECTION_NAME``
    using the HAMMING metric.

    Args:
        query: The search text.
        limit: Maximum number of hits requested from the collection.

    Returns:
        A list of context strings for the hits of the query; an empty list
        if the query could not be embedded, nothing was returned, or any
        error occurred (errors are logged).
    """
    try:
        query_vector = get_query_embeddings(query)
        if query_vector is None:
            logger.error("Failed to generate query embeddings")
            return []

        hits_per_query = client.search(
            collection_name=COLLECTION_NAME,
            data=[query_vector],
            anns_field="binary_vector",
            search_params={"metric_type": "HAMMING"},
            output_fields=["context"],
            limit=limit,
        )

        if hits_per_query:
            # Only one query vector was submitted, so its hits are at index 0.
            return [hit.entity.context for hit in hits_per_query[0]]

        logger.error("No search results found")
        return []

    except Exception as error:
        logger.error(f"Error searching documents: {error}")
        return []
208
+
209
+
210
# Test the search functionality
query = "authors of the document"
contexts = search_documents(query, limit=5)

# Render the retrieved passages as plain text separated by a rule.
# Interpolating the list directly would put its Python repr (brackets,
# quotes, escaped newlines) into the prompt.
context_block = "\n\n---\n\n".join(contexts)

prompt = f"""
# Role and objective
You are a helpful assistant that can answer questions about the following context.

# Instructions
Given the context information, answer the user's query.
If the context information is not relevant to the user's query, say "I don't know".

# Context
{context_block}

# User's query
{query}

# Answer
"""

human_message = HumanMessage(content=prompt)
print(f"Human message: {human_message}")

response = llm.invoke(input=[human_message])

print(f"Response from the model: {response.content}")
pyproject.toml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "rag-w-binary-quant"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "black>=25.1.0",
9
+ "dotenv>=0.9.9",
10
+ "isort>=6.0.1",
11
+ "langchain>=0.3.27",
12
+ "langchain-community>=0.3.27",
13
+ "langchain-openai>=0.3.28",
14
+ "llama-index>=0.13.0",
15
+ "llama-index-embeddings-huggingface>=0.6.0",
16
+ "logging>=0.4.9.6",
17
+ "numpy>=2.3.2",
18
+ "pymilvus>=2.5.14",
19
+ ]
uv.lock ADDED
The diff for this file is too large to render. See raw diff