Update chromadb_semantic_search_for_dataset.py
Browse files
chromadb_semantic_search_for_dataset.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import sqlite3
|
| 2 |
import chromadb
|
| 3 |
from chromadb.utils import embedding_functions
|
|
@@ -9,9 +10,9 @@ DB_PATH = "2080_data.db"
|
|
| 9 |
CHROMA_COLLECTION_NAME = "my_collection"
|
| 10 |
|
| 11 |
# Truncation / summary settings
|
| 12 |
-
MAX_CHUNK_CHARS =
|
| 13 |
-
SUMMARY_MAX_LENGTH =
|
| 14 |
-
COMBINED_CONTEXT_MAX_CHARS =
|
| 15 |
|
| 16 |
# --- Load data from SQLite ---
|
| 17 |
try:
|
|
@@ -22,20 +23,16 @@ try:
|
|
| 22 |
FROM cases
|
| 23 |
""")
|
| 24 |
rows = cursor.fetchall()
|
| 25 |
-
print("SQLite rows loaded:", len(rows))
|
| 26 |
-
if rows:
|
| 27 |
-
print("Sample row:", rows[0])
|
| 28 |
except sqlite3.Error as e:
|
| 29 |
print(f"SQLite error: {e}")
|
| 30 |
raise
|
| 31 |
|
| 32 |
-
# --- Setup ChromaDB ---
|
| 33 |
try:
|
| 34 |
chroma_client = chromadb.Client()
|
| 35 |
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
|
| 36 |
model_name="paraphrase-multilingual-mpnet-base-v2"
|
| 37 |
)
|
| 38 |
-
print("Embedding model loaded:", sentence_transformer_ef is not None)
|
| 39 |
collection = chroma_client.get_or_create_collection(
|
| 40 |
name=CHROMA_COLLECTION_NAME,
|
| 41 |
embedding_function=sentence_transformer_ef
|
|
@@ -44,13 +41,14 @@ except Exception as e:
|
|
| 44 |
print(f"ChromaDB setup error: {e}")
|
| 45 |
raise
|
| 46 |
|
| 47 |
-
# --- Load DB rows into ChromaDB collection ---
|
| 48 |
documents = []
|
| 49 |
metadatas = []
|
| 50 |
ids = []
|
| 51 |
|
| 52 |
for i, row in enumerate(rows):
|
| 53 |
link, decision_no, year, mudda_type, subject, nibedak, vipakshi, prakaran, thahar = row
|
|
|
|
| 54 |
thahar_text = (thahar or "")[:MAX_CHUNK_CHARS]
|
| 55 |
prakaran_text = (prakaran or "")[:MAX_CHUNK_CHARS]
|
| 56 |
case_text = f"{mudda_type} {subject} {nibedak} {vipakshi} {prakaran_text} {thahar_text}"
|
|
@@ -67,14 +65,17 @@ for i, row in enumerate(rows):
|
|
| 67 |
})
|
| 68 |
ids.append(str(i))
|
| 69 |
|
|
|
|
| 70 |
try:
|
| 71 |
if len(documents) > 0:
|
| 72 |
collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
| 73 |
-
print("
|
| 74 |
except Exception as e:
|
|
|
|
| 75 |
print(f"Warning while adding to ChromaDB: {e}")
|
| 76 |
|
| 77 |
-
|
|
|
|
| 78 |
"""
|
| 79 |
Returns:
|
| 80 |
- formatted_text: user-facing Markdown/plaintext summary of top results
|
|
@@ -82,9 +83,6 @@ def semantic_search(query: str, n_results: int = 2):
|
|
| 82 |
- combined_context: concatenated text of top docs (UNSUMMARIZED, truncated per doc)
|
| 83 |
"""
|
| 84 |
start = time.time()
|
| 85 |
-
if not query.strip():
|
| 86 |
-
return "Error: Query cannot be empty.", [], ""
|
| 87 |
-
|
| 88 |
results = collection.query(
|
| 89 |
query_texts=[query],
|
| 90 |
n_results=n_results,
|
|
@@ -93,14 +91,11 @@ def semantic_search(query: str, n_results: int = 2):
|
|
| 93 |
|
| 94 |
docs = results.get("documents", [[]])[0]
|
| 95 |
metas = results.get("metadatas", [[]])[0]
|
| 96 |
-
distances = results.get("distances", [[]])[0]
|
| 97 |
-
print("Semantic search results:", len(docs))
|
| 98 |
-
|
| 99 |
-
if not docs:
|
| 100 |
-
return "No results found for query.", [], ""
|
| 101 |
-
|
| 102 |
top_docs = []
|
|
|
|
| 103 |
for doc, meta, dist in zip(docs, metas, distances):
|
|
|
|
| 104 |
try:
|
| 105 |
similarity = 1.0 - float(dist)
|
| 106 |
except Exception:
|
|
@@ -111,11 +106,12 @@ def semantic_search(query: str, n_results: int = 2):
|
|
| 111 |
"similarity": similarity
|
| 112 |
})
|
| 113 |
|
|
|
|
| 114 |
lines = []
|
| 115 |
for i, item in enumerate(top_docs, start=1):
|
| 116 |
m = item["metadata"]
|
| 117 |
sim_str = f"{item['similarity']:.4f}" if item["similarity"] is not None else "N/A"
|
| 118 |
-
snippet = (item["document"][:
|
| 119 |
lines.append(
|
| 120 |
f"🔹 Case {i}\n"
|
| 121 |
f" 📄 मुद्दाको किसिम: {m.get('mudda_type','')}\n"
|
|
@@ -130,12 +126,24 @@ def semantic_search(query: str, n_results: int = 2):
|
|
| 130 |
)
|
| 131 |
|
| 132 |
formatted_text = "\n\n".join(lines)
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
combined_context = "\n\n".join(combined_items)
|
| 135 |
elapsed = time.time() - start
|
| 136 |
-
print("Semantic search elapsed:"
|
| 137 |
return formatted_text, top_docs, combined_context
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
def build_compact_context(summaries: List[str]) -> str:
|
| 140 |
"""
|
| 141 |
Given a list of per-case summaries, concatenate them while keeping
|
|
@@ -145,6 +153,7 @@ def build_compact_context(summaries: List[str]) -> str:
|
|
| 145 |
total = 0
|
| 146 |
for i, s in enumerate(summaries, start=1):
|
| 147 |
if total + len(s) + 10 > COMBINED_CONTEXT_MAX_CHARS:
|
|
|
|
| 148 |
remaining = COMBINED_CONTEXT_MAX_CHARS - total - 10
|
| 149 |
if remaining <= 0:
|
| 150 |
break
|
|
|
|
| 1 |
+
# chromadb_semantic_search_for_dataset.py
|
| 2 |
import sqlite3
|
| 3 |
import chromadb
|
| 4 |
from chromadb.utils import embedding_functions
|
|
|
|
| 10 |
CHROMA_COLLECTION_NAME = "my_collection"
|
| 11 |
|
| 12 |
# Truncation / summary settings
|
| 13 |
+
MAX_CHUNK_CHARS = 2000 # truncate each full case to this before summarizing
|
| 14 |
+
SUMMARY_MAX_LENGTH = 150 # length budget (tokens/words) for each per-case summary
|
| 15 |
+
COMBINED_CONTEXT_MAX_CHARS = 3000 # total chars to send to the answer model
|
| 16 |
|
| 17 |
# --- Load data from SQLite ---
|
| 18 |
try:
|
|
|
|
| 23 |
FROM cases
|
| 24 |
""")
|
| 25 |
rows = cursor.fetchall()
|
|
|
|
|
|
|
|
|
|
| 26 |
except sqlite3.Error as e:
|
| 27 |
print(f"SQLite error: {e}")
|
| 28 |
raise
|
| 29 |
|
| 30 |
+
# --- Setup ChromaDB (in-memory client; assumes embeddings will be computed at startup) ---
|
| 31 |
try:
|
| 32 |
chroma_client = chromadb.Client()
|
| 33 |
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
|
| 34 |
model_name="paraphrase-multilingual-mpnet-base-v2"
|
| 35 |
)
|
|
|
|
| 36 |
collection = chroma_client.get_or_create_collection(
|
| 37 |
name=CHROMA_COLLECTION_NAME,
|
| 38 |
embedding_function=sentence_transformer_ef
|
|
|
|
| 41 |
print(f"ChromaDB setup error: {e}")
|
| 42 |
raise
|
| 43 |
|
| 44 |
+
# --- Load DB rows into ChromaDB collection (one-time per start) ---
|
| 45 |
documents = []
|
| 46 |
metadatas = []
|
| 47 |
ids = []
|
| 48 |
|
| 49 |
for i, row in enumerate(rows):
|
| 50 |
link, decision_no, year, mudda_type, subject, nibedak, vipakshi, prakaran, thahar = row
|
| 51 |
+
# Build a single text blob for embedding; truncate the large field
|
| 52 |
thahar_text = (thahar or "")[:MAX_CHUNK_CHARS]
|
| 53 |
prakaran_text = (prakaran or "")[:MAX_CHUNK_CHARS]
|
| 54 |
case_text = f"{mudda_type} {subject} {nibedak} {vipakshi} {prakaran_text} {thahar_text}"
|
|
|
|
| 65 |
})
|
| 66 |
ids.append(str(i))
|
| 67 |
|
| 68 |
+
# Add all rows to the collection in one call; if the collection already holds these ids the call may fail with a duplicate-id error, which is caught and reported as a warning.
|
| 69 |
try:
|
| 70 |
if len(documents) > 0:
|
| 71 |
collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
| 72 |
+
print(f"Added {len(documents)} documents to ChromaDB collection")
|
| 73 |
except Exception as e:
|
| 74 |
+
# If collection already contains these ids, you may see errors; ignore or handle as needed.
|
| 75 |
print(f"Warning while adding to ChromaDB: {e}")
|
| 76 |
|
| 77 |
+
# --- Semantic search function (returns nicely formatted top N + raw top docs) ---
|
| 78 |
+
def semantic_search(query: str, n_results: int = 3):
|
| 79 |
"""
|
| 80 |
Returns:
|
| 81 |
- formatted_text: user-facing Markdown/plaintext summary of top results
|
|
|
|
| 83 |
- combined_context: concatenated text of top docs (UNSUMMARIZED, truncated per doc)
|
| 84 |
"""
|
| 85 |
start = time.time()
|
|
|
|
|
|
|
|
|
|
| 86 |
results = collection.query(
|
| 87 |
query_texts=[query],
|
| 88 |
n_results=n_results,
|
|
|
|
| 91 |
|
| 92 |
docs = results.get("documents", [[]])[0]
|
| 93 |
metas = results.get("metadatas", [[]])[0]
|
| 94 |
+
distances = results.get("distances", [[]])[0] # per-result distances; equal to (1 - cosine similarity) only if the collection uses the cosine space -- TODO confirm the collection's distance setting
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
top_docs = []
|
| 96 |
+
|
| 97 |
for doc, meta, dist in zip(docs, metas, distances):
|
| 98 |
+
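# Convert distance -> similarity score as 1 - distance (a true cosine similarity only when the collection uses cosine distance; otherwise treat it as a rough relevance score)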
|
| 99 |
try:
|
| 100 |
similarity = 1.0 - float(dist)
|
| 101 |
except Exception:
|
|
|
|
| 106 |
"similarity": similarity
|
| 107 |
})
|
| 108 |
|
| 109 |
+
# Build a formatted summary for display
|
| 110 |
lines = []
|
| 111 |
for i, item in enumerate(top_docs, start=1):
|
| 112 |
m = item["metadata"]
|
| 113 |
sim_str = f"{item['similarity']:.4f}" if item["similarity"] is not None else "N/A"
|
| 114 |
+
snippet = (item["document"][:400] + "...") if len(item["document"]) > 400 else item["document"]
|
| 115 |
lines.append(
|
| 116 |
f"🔹 Case {i}\n"
|
| 117 |
f" 📄 मुद्दाको किसिम: {m.get('mudda_type','')}\n"
|
|
|
|
| 126 |
)
|
| 127 |
|
| 128 |
formatted_text = "\n\n".join(lines)
|
| 129 |
+
|
| 130 |
+
# Build combined_context (truncated per doc) for summarization/answering
|
| 131 |
+
combined_items = []
|
| 132 |
+
for i, item in enumerate(top_docs, start=1):
|
| 133 |
+
d = item["document"]
|
| 134 |
+
# Cap each doc at MAX_CHUNK_CHARS: at insertion only the prakaran and thahar fields were truncated individually, so the combined document text can presumably still exceed this limit.
|
| 135 |
+
combined_items.append(f"[Case {i}] {d[:MAX_CHUNK_CHARS]}")
|
| 136 |
+
|
| 137 |
combined_context = "\n\n".join(combined_items)
|
| 138 |
elapsed = time.time() - start
|
| 139 |
+
print(f"Semantic search completed in {elapsed:.2f}s")
|
| 140 |
return formatted_text, top_docs, combined_context
|
| 141 |
|
| 142 |
+
|
| 143 |
+
# --- Summarization + RAG preparation ---
|
| 144 |
+
# We'll create summarizer and answerer pipelines in app.py (to avoid TF/torch duplicate loading),
|
| 145 |
+
# but provide a helper that trims the combined context to a length budget.
|
| 146 |
+
|
| 147 |
def build_compact_context(summaries: List[str]) -> str:
|
| 148 |
"""
|
| 149 |
Given a list of per-case summaries, concatenate them while keeping
|
|
|
|
| 153 |
total = 0
|
| 154 |
for i, s in enumerate(summaries, start=1):
|
| 155 |
if total + len(s) + 10 > COMBINED_CONTEXT_MAX_CHARS:
|
| 156 |
+
# take partial from summary if needed
|
| 157 |
remaining = COMBINED_CONTEXT_MAX_CHARS - total - 10
|
| 158 |
if remaining <= 0:
|
| 159 |
break
|