serverdaun committed on
Commit
2f81d82
·
1 Parent(s): c7a21f9

Implement collection management in Milvus: drop collection during cleanup and on startup if no documents exist. Update embedding generation to clarify data types and improve error handling. Refactor vector store collection creation logic to avoid unnecessary drops.

Browse files
Files changed (4) hide show
  1. app.py +26 -0
  2. src/embedding_generator.py +4 -4
  3. src/rag_pipeline.py +3 -0
  4. src/vector_store.py +10 -13
app.py CHANGED
@@ -38,6 +38,13 @@ def cleanup_documents():
38
  for f in files:
39
  if os.path.isfile(f):
40
  os.remove(f)
 
 
 
 
 
 
 
41
  print("Cleanup complete.")
42
 
43
 
@@ -45,6 +52,23 @@ def cleanup_documents():
45
  atexit.register(cleanup_documents)
46
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def index_documents(file_list):
49
  """Index documents from a list of files."""
50
  if not file_list:
@@ -113,4 +137,6 @@ with gr.Blocks() as demo:
113
  if __name__ == "__main__":
114
  # Ensure the documents directory exists from the start
115
  os.makedirs(DOCS_DIR, exist_ok=True)
 
 
116
  demo.launch()
 
38
  for f in files:
39
  if os.path.isfile(f):
40
  os.remove(f)
41
+ # Also drop the Milvus collection to avoid stale state between restarts
42
+ try:
43
+ if milvus_client and milvus_client.has_collection(COLLECTION_NAME):
44
+ milvus_client.drop_collection(COLLECTION_NAME)
45
+ print(f"Dropped collection {COLLECTION_NAME} during cleanup.")
46
+ except Exception as e:
47
+ print(f"Error dropping collection during cleanup: {e}")
48
  print("Cleanup complete.")
49
 
50
 
 
52
  atexit.register(cleanup_documents)
53
 
54
 
55
+ def reset_collection_if_no_docs():
56
+ """Drop existing collection on startup if there are no documents on disk."""
57
+ try:
58
+ os.makedirs(DOCS_DIR, exist_ok=True)
59
+ files = glob.glob(os.path.join(DOCS_DIR, "*"))
60
+ has_docs = any(os.path.isfile(f) for f in files)
61
+ if (
62
+ not has_docs
63
+ and milvus_client
64
+ and milvus_client.has_collection(COLLECTION_NAME)
65
+ ):
66
+ milvus_client.drop_collection(COLLECTION_NAME)
67
+ print(f"No documents found. Dropped existing collection {COLLECTION_NAME}.")
68
+ except Exception as e:
69
+ print(f"Error resetting collection on startup: {e}")
70
+
71
+
72
  def index_documents(file_list):
73
  """Index documents from a list of files."""
74
  if not file_list:
 
137
  if __name__ == "__main__":
138
  # Ensure the documents directory exists from the start
139
  os.makedirs(DOCS_DIR, exist_ok=True)
140
+ # Reset collection state if there are no documents at startup
141
+ reset_collection_if_no_docs()
142
  demo.launch()
src/embedding_generator.py CHANGED
@@ -32,7 +32,7 @@ def generate_document_embeddings(
32
  Returns:
33
  A list of document embeddings
34
  """
35
- binary_embeddings = []
36
 
37
  try:
38
  for context in batch_iterate(documents, batch_size=512):
@@ -43,10 +43,10 @@ def generate_document_embeddings(
43
  embeds_array = np.array(batch_embeddings)
44
  binary_embeds = np.where(embeds_array > 0, 1, 0).astype(np.uint8)
45
 
46
- # convert to bytes array
47
  packed_embeds = np.packbits(binary_embeds, axis=1)
48
-
49
- binary_embeddings.extend(packed_embeds)
50
  return binary_embeddings
51
  except Exception as e:
52
  print(f"Error generating document embeddings: {e}")
 
32
  Returns:
33
  A list of document embeddings
34
  """
35
+ binary_embeddings: list[bytes] = []
36
 
37
  try:
38
  for context in batch_iterate(documents, batch_size=512):
 
43
  embeds_array = np.array(batch_embeddings)
44
  binary_embeds = np.where(embeds_array > 0, 1, 0).astype(np.uint8)
45
 
46
+ # convert to bytes per vector
47
  packed_embeds = np.packbits(binary_embeds, axis=1)
48
+ for row in packed_embeds:
49
+ binary_embeddings.append(row.tobytes())
50
  return binary_embeddings
51
  except Exception as e:
52
  print(f"Error generating document embeddings: {e}")
src/rag_pipeline.py CHANGED
@@ -1,6 +1,9 @@
 
1
  from langchain.chat_models import init_chat_model
2
  from langchain_core.messages import HumanMessage
3
 
 
 
4
  from .config import MODEL_NAME, MODEL_PROVIDER, PROMPT, TEMPERATURE
5
 
6
  llm = init_chat_model(
 
1
+ from dotenv import load_dotenv
2
  from langchain.chat_models import init_chat_model
3
  from langchain_core.messages import HumanMessage
4
 
5
+ load_dotenv(override=True)
6
+
7
  from .config import MODEL_NAME, MODEL_PROVIDER, PROMPT, TEMPERATURE
8
 
9
  llm = init_chat_model(
src/vector_store.py CHANGED
@@ -32,50 +32,47 @@ def create_collection_if_not_exists(
32
  dim: The dimension of the binary vector
33
  """
34
  try:
 
 
 
35
 
36
- # Drop collection if it exists
37
- if client.has_collection(collection_name):
38
- print(f"Collection {collection_name} exists, dropping it...")
39
- client.drop_collection(collection_name)
40
-
41
- # Initialize client
42
  schema = client.create_schema(
43
  auto_id=True,
44
  enable_dynamic_fields=True,
45
  )
46
- # Add primary key field
47
  schema.add_field(
48
  field_name="id",
49
  datatype=DataType.INT64,
50
  is_primary=True,
51
  auto_id=True,
52
  )
53
- # Add fields to schema
54
  schema.add_field(
55
  field_name="context",
56
  datatype=DataType.VARCHAR,
57
- max_length=65535, # max length for VARCHAR
58
  )
59
  schema.add_field(
60
  field_name="binary_vector",
61
  datatype=DataType.BINARY_VECTOR,
62
  dim=dim,
63
  )
64
- # Create index params for binary vector
65
  index_params = client.prepare_index_params()
66
  index_params.add_index(
67
  field_name="binary_vector",
68
  index_name="binary_vector_index",
69
- index_type="BIN_FLAT", # Exact search for binary vectors
70
- metric_type="HAMMING", # Hamming distance for binary vectors
71
  )
72
- # Create collection with schema and index
73
  client.create_collection(
74
  collection_name=collection_name,
75
  schema=schema,
76
  index_params=index_params,
77
  )
78
  print(f"Collection {collection_name} created successfully.")
 
 
79
  except Exception as e:
80
  print(f"Error creating collection: {e}")
81
  return None
 
32
  dim: The dimension of the binary vector
33
  """
34
  try:
35
+ # Create collection only if it does not exist
36
+ if not client.has_collection(collection_name):
37
+ print(f"Collection {collection_name} not found. Creating it...")
38
 
 
 
 
 
 
 
39
  schema = client.create_schema(
40
  auto_id=True,
41
  enable_dynamic_fields=True,
42
  )
 
43
  schema.add_field(
44
  field_name="id",
45
  datatype=DataType.INT64,
46
  is_primary=True,
47
  auto_id=True,
48
  )
 
49
  schema.add_field(
50
  field_name="context",
51
  datatype=DataType.VARCHAR,
52
+ max_length=65535,
53
  )
54
  schema.add_field(
55
  field_name="binary_vector",
56
  datatype=DataType.BINARY_VECTOR,
57
  dim=dim,
58
  )
59
+
60
  index_params = client.prepare_index_params()
61
  index_params.add_index(
62
  field_name="binary_vector",
63
  index_name="binary_vector_index",
64
+ index_type="BIN_FLAT",
65
+ metric_type="HAMMING",
66
  )
67
+
68
  client.create_collection(
69
  collection_name=collection_name,
70
  schema=schema,
71
  index_params=index_params,
72
  )
73
  print(f"Collection {collection_name} created successfully.")
74
+ else:
75
+ print(f"Collection {collection_name} already exists. Skipping creation.")
76
  except Exception as e:
77
  print(f"Error creating collection: {e}")
78
  return None