Commit · c4a0174
Parent(s): none (initial commit)
initial commit with the 0.1 version of the app
Files changed:
- .gitignore +22 -0
- .python-version +1 -0
- README.md +0 -0
- main.py +236 -0
- pyproject.toml +19 -0
- uv.lock +0 -0
.gitignore
ADDED
@@ -0,0 +1,22 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# Virtual environments
+.venv
+
+# environment variables
+.env
+
+# Hugging Face cache
+.hf_cache/
+
+# Milvus database
+milvus_binary_quantized.db
+
+# data
+documents/
.python-version
ADDED
@@ -0,0 +1 @@
+3.12
README.md
ADDED
File without changes
main.py
ADDED
@@ -0,0 +1,236 @@
+import os
+from dotenv import load_dotenv
+from langchain.chat_models import init_chat_model
+from llama_index.core import SimpleDirectoryReader
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+import numpy as np
+from pymilvus import MilvusClient, DataType
+import logging
+from langchain_core.messages import HumanMessage
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+load_dotenv()
+
+DOCS_DIR = "documents"
+MODEL_NAME = "gpt-4.1"
+TEMPERATURE = 0.2
+COLLECTION_NAME = "fast_rag"
+
+
+def batch_iterate(items, batch_size):
+    """Iterate over items in batches."""
+    for i in range(0, len(items), batch_size):
+        yield items[i:i + batch_size]
+
+
+llm = init_chat_model(MODEL_NAME, model_provider="openai", temperature=TEMPERATURE)
+
+## Generate binary embeddings
+def generate_binary_embeddings():
+    """Generate binary embeddings from documents."""
+    try:
+        # Define loader
+        loader = SimpleDirectoryReader(
+            input_dir=DOCS_DIR,
+            required_exts=[".pdf"],
+            recursive=True,
+        )
+
+        docs = loader.load_data()
+        documents = [doc.text for doc in docs]
+
+        if not documents:
+            logger.error("No documents found in the documents directory.")
+            return [], []
+
+        # Generate embeddings
+        embedding_model = HuggingFaceEmbedding(
+            model_name="BAAI/bge-large-en-v1.5",
+            trust_remote_code=True,
+            cache_folder=".hf_cache",
+        )
+
+        binary_embeddings = []
+
+        for context in batch_iterate(documents, batch_size=512):
+            # generate float32 embeddings
+            batch_embeddings = embedding_model.get_text_embedding_batch(context)
+
+            # convert float32 to binary vectors
+            embeds_array = np.array(batch_embeddings)
+            binary_embeds = np.where(embeds_array > 0, 1, 0).astype(np.uint8)
+
+            # convert to bytes array
+            packed_embeds = np.packbits(binary_embeds, axis=1)
+            byte_embeds = [vec.tobytes() for vec in packed_embeds]
+
+            binary_embeddings.extend(byte_embeds)
+
+        logger.info(f"Generated {len(binary_embeddings)} binary embeddings")
+        return documents, binary_embeddings
+
+    except Exception as e:
+        logger.error(f"Error generating embeddings: {e}")
+        return [], []
+
+
+documents, binary_embeddings = generate_binary_embeddings()
+
+## Vector indexing
+client = MilvusClient("milvus_binary_quantized.db")
+
+# Initialize client and schema
+def create_collection(documents, embeddings):
+    try:
+        if client.has_collection(COLLECTION_NAME):
+            logger.info(f"Collection {COLLECTION_NAME} already exists, dropping it...")
+            client.drop_collection(COLLECTION_NAME)
+
+        # Initialize schema
+        schema = client.create_schema(
+            auto_id=True,
+            enable_dynamic_fields=True,
+        )
+    except Exception as e:
+        logger.error(f"Error creating collection: {e}")
+        return None
+
+    # Add primary key field
+    schema.add_field(
+        field_name="id",
+        datatype=DataType.INT64,
+        is_primary=True,
+        auto_id=True,
+    )
+
+    # Add fields to schema
+    schema.add_field(
+        field_name="context",
+        datatype=DataType.VARCHAR,
+        max_length=65535,  # max length for VARCHAR
+    )
+    schema.add_field(
+        field_name="binary_vector",
+        datatype=DataType.BINARY_VECTOR,
+        dim=1024,  # dimension for binary vector
+    )
+
+    # Create index params for binary vector
+    index_params = client.prepare_index_params()
+    index_params.add_index(
+        field_name="binary_vector",
+        index_name="binary_vector_index",
+        index_type="BIN_FLAT",  # Exact search for binary vectors
+        metric_type="HAMMING",  # Hamming distance for binary vectors
+    )
+
+    # Create collection with schema and index
+    client.create_collection(
+        collection_name=COLLECTION_NAME,
+        schema=schema,
+        index_params=index_params,
+    )
+
+    # Insert data into collection
+    client.insert(
+        collection_name=COLLECTION_NAME,
+        data=[
+            {
+                "context": context,
+                "binary_vector": binary_embedding
+            }
+            for context, binary_embedding in zip(documents, embeddings)
+        ]
+    )
+
+create_collection(documents, binary_embeddings)
+
+
+def get_query_embeddings(query: str) -> bytes:
+    """Get query embeddings."""
+    try:
+        embedding_model = HuggingFaceEmbedding(
+            model_name="BAAI/bge-large-en-v1.5",
+            trust_remote_code=True,
+            cache_folder=".hf_cache",
+        )
+    except Exception as e:
+        logger.error(f"Error getting query embeddings: {e}")
+        return None
+
+    # Generate float32 embeddings
+    query_embedding = embedding_model.get_text_embedding(query)
+
+    # Convert float32 to binary vector
+    binary_vector = np.where(np.array(query_embedding) > 0, 1, 0).astype(np.uint8)
+
+    # Convert to bytes array
+    packed_vector = np.packbits(binary_vector, axis=0)
+
+    return packed_vector.tobytes()
+
+
+def search_documents(query: str, limit: int = 5):
+    """Search documents using binary embeddings."""
+    try:
+        binary_query = get_query_embeddings(query)
+        if binary_query is None:
+            logger.error("Failed to generate query embeddings")
+            return []
+
+        search_results = client.search(
+            collection_name=COLLECTION_NAME,
+            data=[binary_query],
+            anns_field="binary_vector",
+            search_params={
+                "metric_type": "HAMMING",
+            },
+            output_fields=["context"],
+            limit=limit,
+        )
+
+        # logger.info(f"Search results: {search_results}")
+
+        if not search_results:
+            logger.error("No search results found")
+            return []
+
+        # MilvusClient returns hits as dicts, so index into the entity payload
+        contexts = [res["entity"]["context"] for res in search_results[0]]
+
+        return contexts
+
+    except Exception as e:
+        logger.error(f"Error searching documents: {e}")
+        return []
+
+
+# Test the search functionality
+query = "authors of the document"
+contexts = search_documents(query, limit=5)
+
+prompt = f"""
+# Role and objective
+You are a helpful assistant that can answer questions about the following context.
+
+# Instructions
+Given the context information, answer the user's query.
+If the context information is not relevant to the user's query, say "I don't know".
+
+# Context
+{contexts}
+
+# User's query
+{query}
+
+# Answer
+"""
+
+human_message = HumanMessage(content=prompt)
+print(f"Human message: {human_message}")
+
+response = llm.invoke(input=[human_message])
+
+print(f"Response from the model: {response.content}")
pyproject.toml
ADDED
@@ -0,0 +1,19 @@
+[project]
+name = "rag-w-binary-quant"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "black>=25.1.0",
+    "dotenv>=0.9.9",
+    "isort>=6.0.1",
+    "langchain>=0.3.27",
+    "langchain-community>=0.3.27",
+    "langchain-openai>=0.3.28",
+    "llama-index>=0.13.0",
+    "llama-index-embeddings-huggingface>=0.6.0",
+    "logging>=0.4.9.6",
+    "numpy>=2.3.2",
+    "pymilvus>=2.5.14",
+]
|
uv.lock
ADDED
The diff for this file is too large to render.