Spaces:
Sleeping
Sleeping
Commit
·
2f81d82
1
Parent(s):
c7a21f9
Implement collection management in Milvus: drop collection during cleanup and on startup if no documents exist. Update embedding generation to clarify data types and improve error handling. Refactor vector store collection creation logic to avoid unnecessary drops.
Browse files- app.py +26 -0
- src/embedding_generator.py +4 -4
- src/rag_pipeline.py +3 -0
- src/vector_store.py +10 -13
app.py
CHANGED
|
@@ -38,6 +38,13 @@ def cleanup_documents():
|
|
| 38 |
for f in files:
|
| 39 |
if os.path.isfile(f):
|
| 40 |
os.remove(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
print("Cleanup complete.")
|
| 42 |
|
| 43 |
|
|
@@ -45,6 +52,23 @@ def cleanup_documents():
|
|
| 45 |
atexit.register(cleanup_documents)
|
| 46 |
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
def index_documents(file_list):
|
| 49 |
"""Index documents from a list of files."""
|
| 50 |
if not file_list:
|
|
@@ -113,4 +137,6 @@ with gr.Blocks() as demo:
|
|
| 113 |
if __name__ == "__main__":
|
| 114 |
# Ensure the documents directory exists from the start
|
| 115 |
os.makedirs(DOCS_DIR, exist_ok=True)
|
|
|
|
|
|
|
| 116 |
demo.launch()
|
|
|
|
| 38 |
for f in files:
|
| 39 |
if os.path.isfile(f):
|
| 40 |
os.remove(f)
|
| 41 |
+
# Also drop the Milvus collection to avoid stale state between restarts
|
| 42 |
+
try:
|
| 43 |
+
if milvus_client and milvus_client.has_collection(COLLECTION_NAME):
|
| 44 |
+
milvus_client.drop_collection(COLLECTION_NAME)
|
| 45 |
+
print(f"Dropped collection {COLLECTION_NAME} during cleanup.")
|
| 46 |
+
except Exception as e:
|
| 47 |
+
print(f"Error dropping collection during cleanup: {e}")
|
| 48 |
print("Cleanup complete.")
|
| 49 |
|
| 50 |
|
|
|
|
| 52 |
atexit.register(cleanup_documents)
|
| 53 |
|
| 54 |
|
| 55 |
+
def reset_collection_if_no_docs():
|
| 56 |
+
"""Drop existing collection on startup if there are no documents on disk."""
|
| 57 |
+
try:
|
| 58 |
+
os.makedirs(DOCS_DIR, exist_ok=True)
|
| 59 |
+
files = glob.glob(os.path.join(DOCS_DIR, "*"))
|
| 60 |
+
has_docs = any(os.path.isfile(f) for f in files)
|
| 61 |
+
if (
|
| 62 |
+
not has_docs
|
| 63 |
+
and milvus_client
|
| 64 |
+
and milvus_client.has_collection(COLLECTION_NAME)
|
| 65 |
+
):
|
| 66 |
+
milvus_client.drop_collection(COLLECTION_NAME)
|
| 67 |
+
print(f"No documents found. Dropped existing collection {COLLECTION_NAME}.")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(f"Error resetting collection on startup: {e}")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
def index_documents(file_list):
|
| 73 |
"""Index documents from a list of files."""
|
| 74 |
if not file_list:
|
|
|
|
| 137 |
if __name__ == "__main__":
|
| 138 |
# Ensure the documents directory exists from the start
|
| 139 |
os.makedirs(DOCS_DIR, exist_ok=True)
|
| 140 |
+
# Reset collection state if there are no documents at startup
|
| 141 |
+
reset_collection_if_no_docs()
|
| 142 |
demo.launch()
|
src/embedding_generator.py
CHANGED
|
@@ -32,7 +32,7 @@ def generate_document_embeddings(
|
|
| 32 |
Returns:
|
| 33 |
A list of document embeddings
|
| 34 |
"""
|
| 35 |
-
binary_embeddings = []
|
| 36 |
|
| 37 |
try:
|
| 38 |
for context in batch_iterate(documents, batch_size=512):
|
|
@@ -43,10 +43,10 @@ def generate_document_embeddings(
|
|
| 43 |
embeds_array = np.array(batch_embeddings)
|
| 44 |
binary_embeds = np.where(embeds_array > 0, 1, 0).astype(np.uint8)
|
| 45 |
|
| 46 |
-
# convert to bytes
|
| 47 |
packed_embeds = np.packbits(binary_embeds, axis=1)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
return binary_embeddings
|
| 51 |
except Exception as e:
|
| 52 |
print(f"Error generating document embeddings: {e}")
|
|
|
|
| 32 |
Returns:
|
| 33 |
A list of document embeddings
|
| 34 |
"""
|
| 35 |
+
binary_embeddings: list[bytes] = []
|
| 36 |
|
| 37 |
try:
|
| 38 |
for context in batch_iterate(documents, batch_size=512):
|
|
|
|
| 43 |
embeds_array = np.array(batch_embeddings)
|
| 44 |
binary_embeds = np.where(embeds_array > 0, 1, 0).astype(np.uint8)
|
| 45 |
|
| 46 |
+
# convert to bytes per vector
|
| 47 |
packed_embeds = np.packbits(binary_embeds, axis=1)
|
| 48 |
+
for row in packed_embeds:
|
| 49 |
+
binary_embeddings.append(row.tobytes())
|
| 50 |
return binary_embeddings
|
| 51 |
except Exception as e:
|
| 52 |
print(f"Error generating document embeddings: {e}")
|
src/rag_pipeline.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
|
|
| 1 |
from langchain.chat_models import init_chat_model
|
| 2 |
from langchain_core.messages import HumanMessage
|
| 3 |
|
|
|
|
|
|
|
| 4 |
from .config import MODEL_NAME, MODEL_PROVIDER, PROMPT, TEMPERATURE
|
| 5 |
|
| 6 |
llm = init_chat_model(
|
|
|
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
from langchain.chat_models import init_chat_model
|
| 3 |
from langchain_core.messages import HumanMessage
|
| 4 |
|
| 5 |
+
load_dotenv(override=True)
|
| 6 |
+
|
| 7 |
from .config import MODEL_NAME, MODEL_PROVIDER, PROMPT, TEMPERATURE
|
| 8 |
|
| 9 |
llm = init_chat_model(
|
src/vector_store.py
CHANGED
|
@@ -32,50 +32,47 @@ def create_collection_if_not_exists(
|
|
| 32 |
dim: The dimension of the binary vector
|
| 33 |
"""
|
| 34 |
try:
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
# Drop collection if it exists
|
| 37 |
-
if client.has_collection(collection_name):
|
| 38 |
-
print(f"Collection {collection_name} exists, dropping it...")
|
| 39 |
-
client.drop_collection(collection_name)
|
| 40 |
-
|
| 41 |
-
# Initialize client
|
| 42 |
schema = client.create_schema(
|
| 43 |
auto_id=True,
|
| 44 |
enable_dynamic_fields=True,
|
| 45 |
)
|
| 46 |
-
# Add primary key field
|
| 47 |
schema.add_field(
|
| 48 |
field_name="id",
|
| 49 |
datatype=DataType.INT64,
|
| 50 |
is_primary=True,
|
| 51 |
auto_id=True,
|
| 52 |
)
|
| 53 |
-
# Add fields to schema
|
| 54 |
schema.add_field(
|
| 55 |
field_name="context",
|
| 56 |
datatype=DataType.VARCHAR,
|
| 57 |
-
max_length=65535,
|
| 58 |
)
|
| 59 |
schema.add_field(
|
| 60 |
field_name="binary_vector",
|
| 61 |
datatype=DataType.BINARY_VECTOR,
|
| 62 |
dim=dim,
|
| 63 |
)
|
| 64 |
-
|
| 65 |
index_params = client.prepare_index_params()
|
| 66 |
index_params.add_index(
|
| 67 |
field_name="binary_vector",
|
| 68 |
index_name="binary_vector_index",
|
| 69 |
-
index_type="BIN_FLAT",
|
| 70 |
-
metric_type="HAMMING",
|
| 71 |
)
|
| 72 |
-
|
| 73 |
client.create_collection(
|
| 74 |
collection_name=collection_name,
|
| 75 |
schema=schema,
|
| 76 |
index_params=index_params,
|
| 77 |
)
|
| 78 |
print(f"Collection {collection_name} created successfully.")
|
|
|
|
|
|
|
| 79 |
except Exception as e:
|
| 80 |
print(f"Error creating collection: {e}")
|
| 81 |
return None
|
|
|
|
| 32 |
dim: The dimension of the binary vector
|
| 33 |
"""
|
| 34 |
try:
|
| 35 |
+
# Create collection only if it does not exist
|
| 36 |
+
if not client.has_collection(collection_name):
|
| 37 |
+
print(f"Collection {collection_name} not found. Creating it...")
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
schema = client.create_schema(
|
| 40 |
auto_id=True,
|
| 41 |
enable_dynamic_fields=True,
|
| 42 |
)
|
|
|
|
| 43 |
schema.add_field(
|
| 44 |
field_name="id",
|
| 45 |
datatype=DataType.INT64,
|
| 46 |
is_primary=True,
|
| 47 |
auto_id=True,
|
| 48 |
)
|
|
|
|
| 49 |
schema.add_field(
|
| 50 |
field_name="context",
|
| 51 |
datatype=DataType.VARCHAR,
|
| 52 |
+
max_length=65535,
|
| 53 |
)
|
| 54 |
schema.add_field(
|
| 55 |
field_name="binary_vector",
|
| 56 |
datatype=DataType.BINARY_VECTOR,
|
| 57 |
dim=dim,
|
| 58 |
)
|
| 59 |
+
|
| 60 |
index_params = client.prepare_index_params()
|
| 61 |
index_params.add_index(
|
| 62 |
field_name="binary_vector",
|
| 63 |
index_name="binary_vector_index",
|
| 64 |
+
index_type="BIN_FLAT",
|
| 65 |
+
metric_type="HAMMING",
|
| 66 |
)
|
| 67 |
+
|
| 68 |
client.create_collection(
|
| 69 |
collection_name=collection_name,
|
| 70 |
schema=schema,
|
| 71 |
index_params=index_params,
|
| 72 |
)
|
| 73 |
print(f"Collection {collection_name} created successfully.")
|
| 74 |
+
else:
|
| 75 |
+
print(f"Collection {collection_name} already exists. Skipping creation.")
|
| 76 |
except Exception as e:
|
| 77 |
print(f"Error creating collection: {e}")
|
| 78 |
return None
|