rbbist committed on
Commit
2b0fbff
·
verified ·
1 Parent(s): c369c5f

Update chromadb_semantic_search_for_dataset.py

Browse files
chromadb_semantic_search_for_dataset.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import sqlite3
2
  import chromadb
3
  from chromadb.utils import embedding_functions
@@ -9,9 +10,9 @@ DB_PATH = "2080_data.db"
9
  CHROMA_COLLECTION_NAME = "my_collection"
10
 
11
  # Truncation / summary settings
12
- MAX_CHUNK_CHARS = 1500 # Reduced for performance
13
- SUMMARY_MAX_LENGTH = 100 # Reduced tokens/words budget
14
- COMBINED_CONTEXT_MAX_CHARS = 1500 # Reduced total chars for answer model
15
 
16
  # --- Load data from SQLite ---
17
  try:
@@ -22,20 +23,16 @@ try:
22
  FROM cases
23
  """)
24
  rows = cursor.fetchall()
25
- print("SQLite rows loaded:", len(rows))
26
- if rows:
27
- print("Sample row:", rows[0])
28
  except sqlite3.Error as e:
29
  print(f"SQLite error: {e}")
30
  raise
31
 
32
- # --- Setup ChromaDB ---
33
  try:
34
  chroma_client = chromadb.Client()
35
  sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
36
  model_name="paraphrase-multilingual-mpnet-base-v2"
37
  )
38
- print("Embedding model loaded:", sentence_transformer_ef is not None)
39
  collection = chroma_client.get_or_create_collection(
40
  name=CHROMA_COLLECTION_NAME,
41
  embedding_function=sentence_transformer_ef
@@ -44,13 +41,14 @@ except Exception as e:
44
  print(f"ChromaDB setup error: {e}")
45
  raise
46
 
47
- # --- Load DB rows into ChromaDB collection ---
48
  documents = []
49
  metadatas = []
50
  ids = []
51
 
52
  for i, row in enumerate(rows):
53
  link, decision_no, year, mudda_type, subject, nibedak, vipakshi, prakaran, thahar = row
 
54
  thahar_text = (thahar or "")[:MAX_CHUNK_CHARS]
55
  prakaran_text = (prakaran or "")[:MAX_CHUNK_CHARS]
56
  case_text = f"{mudda_type} {subject} {nibedak} {vipakshi} {prakaran_text} {thahar_text}"
@@ -67,14 +65,17 @@ for i, row in enumerate(rows):
67
  })
68
  ids.append(str(i))
69
 
 
70
  try:
71
  if len(documents) > 0:
72
  collection.add(documents=documents, metadatas=metadatas, ids=ids)
73
- print("ChromaDB collection size:", collection.count())
74
  except Exception as e:
 
75
  print(f"Warning while adding to ChromaDB: {e}")
76
 
77
- def semantic_search(query: str, n_results: int = 2):
 
78
  """
79
  Returns:
80
  - formatted_text: user-facing Markdown/plaintext summary of top results
@@ -82,9 +83,6 @@ def semantic_search(query: str, n_results: int = 2):
82
  - combined_context: concatenated text of top docs (UNSUMMARIZED, truncated per doc)
83
  """
84
  start = time.time()
85
- if not query.strip():
86
- return "Error: Query cannot be empty.", [], ""
87
-
88
  results = collection.query(
89
  query_texts=[query],
90
  n_results=n_results,
@@ -93,14 +91,11 @@ def semantic_search(query: str, n_results: int = 2):
93
 
94
  docs = results.get("documents", [[]])[0]
95
  metas = results.get("metadatas", [[]])[0]
96
- distances = results.get("distances", [[]])[0]
97
- print("Semantic search results:", len(docs))
98
-
99
- if not docs:
100
- return "No results found for query.", [], ""
101
-
102
  top_docs = []
 
103
  for doc, meta, dist in zip(docs, metas, distances):
 
104
  try:
105
  similarity = 1.0 - float(dist)
106
  except Exception:
@@ -111,11 +106,12 @@ def semantic_search(query: str, n_results: int = 2):
111
  "similarity": similarity
112
  })
113
 
 
114
  lines = []
115
  for i, item in enumerate(top_docs, start=1):
116
  m = item["metadata"]
117
  sim_str = f"{item['similarity']:.4f}" if item["similarity"] is not None else "N/A"
118
- snippet = (item["document"][:300] + "...") if len(item["document"]) > 300 else item["document"]
119
  lines.append(
120
  f"🔹 Case {i}\n"
121
  f" 📄 मुद्दाको किसिम: {m.get('mudda_type','')}\n"
@@ -130,12 +126,24 @@ def semantic_search(query: str, n_results: int = 2):
130
  )
131
 
132
  formatted_text = "\n\n".join(lines)
133
- combined_items = [f"[Case {i}] {item['document'][:MAX_CHUNK_CHARS]}" for i, item in enumerate(top_docs, start=1)]
 
 
 
 
 
 
 
134
  combined_context = "\n\n".join(combined_items)
135
  elapsed = time.time() - start
136
- print("Semantic search elapsed:", elapsed)
137
  return formatted_text, top_docs, combined_context
138
 
 
 
 
 
 
139
  def build_compact_context(summaries: List[str]) -> str:
140
  """
141
  Given a list of per-case summaries, concatenate them while keeping
@@ -145,6 +153,7 @@ def build_compact_context(summaries: List[str]) -> str:
145
  total = 0
146
  for i, s in enumerate(summaries, start=1):
147
  if total + len(s) + 10 > COMBINED_CONTEXT_MAX_CHARS:
 
148
  remaining = COMBINED_CONTEXT_MAX_CHARS - total - 10
149
  if remaining <= 0:
150
  break
 
1
+ # chromadb_semantic_search_for_dataset.py
2
  import sqlite3
3
  import chromadb
4
  from chromadb.utils import embedding_functions
 
10
  CHROMA_COLLECTION_NAME = "my_collection"
11
 
12
  # Truncation / summary settings
13
+ MAX_CHUNK_CHARS = 2000 # truncate each full case to this before summarizing
14
+ SUMMARY_MAX_LENGTH = 150 # tokens/words budget for each per-case summary (ADDED THIS LINE)
15
+ COMBINED_CONTEXT_MAX_CHARS = 3000 # total chars to send to the answer model
16
 
17
  # --- Load data from SQLite ---
18
  try:
 
23
  FROM cases
24
  """)
25
  rows = cursor.fetchall()
 
 
 
26
  except sqlite3.Error as e:
27
  print(f"SQLite error: {e}")
28
  raise
29
 
30
+ # --- Setup ChromaDB (in-memory client; assumes embeddings will be computed at startup) ---
31
  try:
32
  chroma_client = chromadb.Client()
33
  sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
34
  model_name="paraphrase-multilingual-mpnet-base-v2"
35
  )
 
36
  collection = chroma_client.get_or_create_collection(
37
  name=CHROMA_COLLECTION_NAME,
38
  embedding_function=sentence_transformer_ef
 
41
  print(f"ChromaDB setup error: {e}")
42
  raise
43
 
44
+ # --- Load DB rows into ChromaDB collection (one-time per start) ---
45
  documents = []
46
  metadatas = []
47
  ids = []
48
 
49
  for i, row in enumerate(rows):
50
  link, decision_no, year, mudda_type, subject, nibedak, vipakshi, prakaran, thahar = row
51
+ # Build a single text blob for embedding; truncate the large field
52
  thahar_text = (thahar or "")[:MAX_CHUNK_CHARS]
53
  prakaran_text = (prakaran or "")[:MAX_CHUNK_CHARS]
54
  case_text = f"{mudda_type} {subject} {nibedak} {vipakshi} {prakaran_text} {thahar_text}"
 
65
  })
66
  ids.append(str(i))
67
 
68
+ # Add to collection (if collection already has items this may raise duplicates; you can adjust)
69
  try:
70
  if len(documents) > 0:
71
  collection.add(documents=documents, metadatas=metadatas, ids=ids)
72
+ print(f"Added {len(documents)} documents to ChromaDB collection")
73
  except Exception as e:
74
+ # If collection already contains these ids, you may see errors; ignore or handle as needed.
75
  print(f"Warning while adding to ChromaDB: {e}")
76
 
77
+ # --- Semantic search function (returns nicely formatted top N + raw top docs) ---
78
+ def semantic_search(query: str, n_results: int = 3):
79
  """
80
  Returns:
81
  - formatted_text: user-facing Markdown/plaintext summary of top results
 
83
  - combined_context: concatenated text of top docs (UNSUMMARIZED, truncated per doc)
84
  """
85
  start = time.time()
 
 
 
86
  results = collection.query(
87
  query_texts=[query],
88
  n_results=n_results,
 
91
 
92
  docs = results.get("documents", [[]])[0]
93
  metas = results.get("metadatas", [[]])[0]
94
+ distances = results.get("distances", [[]])[0] # distances (Chroma uses 1 - cosine if using cosine)
 
 
 
 
 
95
  top_docs = []
96
+
97
  for doc, meta, dist in zip(docs, metas, distances):
98
+ # Convert distance -> cosine similarity (approx): cosine = 1 - distance
99
  try:
100
  similarity = 1.0 - float(dist)
101
  except Exception:
 
106
  "similarity": similarity
107
  })
108
 
109
+ # Build a formatted summary for display
110
  lines = []
111
  for i, item in enumerate(top_docs, start=1):
112
  m = item["metadata"]
113
  sim_str = f"{item['similarity']:.4f}" if item["similarity"] is not None else "N/A"
114
+ snippet = (item["document"][:400] + "...") if len(item["document"]) > 400 else item["document"]
115
  lines.append(
116
  f"🔹 Case {i}\n"
117
  f" 📄 मुद्दाको किसिम: {m.get('mudda_type','')}\n"
 
126
  )
127
 
128
  formatted_text = "\n\n".join(lines)
129
+
130
+ # Build combined_context (truncated per doc) for summarization/answering
131
+ combined_items = []
132
+ for i, item in enumerate(top_docs, start=1):
133
+ d = item["document"]
134
+ # ensure we don't exceed MAX_CHUNK_CHARS per doc (we already truncated at insertion)
135
+ combined_items.append(f"[Case {i}] {d[:MAX_CHUNK_CHARS]}")
136
+
137
  combined_context = "\n\n".join(combined_items)
138
  elapsed = time.time() - start
139
+ print(f"Semantic search completed in {elapsed:.2f}s")
140
  return formatted_text, top_docs, combined_context
141
 
142
+
143
+ # --- Summarization + RAG preparation ---
144
+ # We'll create summarizer and answerer pipelines in app.py (to avoid TF/torch duplicate loading),
145
+ # but provide helper that trims the combined context to a length budget.
146
+
147
  def build_compact_context(summaries: List[str]) -> str:
148
  """
149
  Given a list of per-case summaries, concatenate them while keeping
 
153
  total = 0
154
  for i, s in enumerate(summaries, start=1):
155
  if total + len(s) + 10 > COMBINED_CONTEXT_MAX_CHARS:
156
+ # take partial from summary if needed
157
  remaining = COMBINED_CONTEXT_MAX_CHARS - total - 10
158
  if remaining <= 0:
159
  break