File size: 5,881 Bytes
900e88e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import chromadb
from sentence_transformers import SentenceTransformer, CrossEncoder
# Path of the persistent Chroma store on disk.
CHROMA_DB_DIR = "./chroma_db"
# Name of the dense-vector collection inside the store.
COLLECTION_NAME = "bangdream_dense"
# Dense bi-encoder checkpoint used to embed queries (Chinese BGE base model).
MODEL_NAME = "BAAI/bge-base-zh-v1.5"
# NOTE(review): instantiating the CrossEncoder at import time loads (and may
# download) the reranker model as a module side effect — consider lazy loading.
reranker = CrossEncoder("BAAI/bge-reranker-large")
def load_collection(db_path=CHROMA_DB_DIR, collection_name=COLLECTION_NAME):
    """Open the persistent Chroma store at *db_path* and return the collection.

    The collection is created on first access if it does not exist yet.
    """
    client = chromadb.PersistentClient(path=db_path)
    return client.get_or_create_collection(collection_name)
def load_encoder(model_name=MODEL_NAME):
    """Instantiate and return the SentenceTransformer dense encoder."""
    encoder = SentenceTransformer(model_name)
    return encoder
def encode_query(encoder, query_text):
    """Encode a single query string into a normalized embedding.

    Wraps *query_text* in a one-element list and returns L2-normalized
    embeddings, suitable for cosine-similarity search in Chroma.
    NOTE(review): ``SentenceTransformer.encode_query`` only exists in newer
    sentence-transformers releases (it applies the model's query prompt);
    on older versions this raises AttributeError — confirm the installed
    version, or fall back to ``encoder.encode``.
    """
    return encoder.encode_query([query_text], normalize_embeddings=True)
def dedup_by_chapter_event(reranked_docs, max_per_group=1):
    """Drop near-duplicate hits that share the same chapter/event pair.

    Groups the ``(doc, score, meta)`` triples by the
    ``(chapterTitle, eventName)`` metadata pair and keeps at most
    *max_per_group* entries per group (the first occurrence of a group is
    always kept), preserving input order.
    """
    taken = {}
    kept = []
    for item in reranked_docs:
        meta = item[2]
        group = (meta.get("chapterTitle", ""), meta.get("eventName", ""))
        count = taken.get(group)
        if count is None or count < max_per_group:
            taken[group] = 1 if count is None else count + 1
            kept.append(item)
    return kept
def retrieve_docs(collection, query_vec, top_k=5):
    """Run a dense vector search on *collection*.

    Returns the raw Chroma result dict containing ``metadatas``,
    ``documents`` and ``distances`` for the *top_k* nearest chunks.
    """
    return collection.query(
        query_embeddings=query_vec,
        n_results=top_k,
        include=["metadatas", "documents", "distances"],
    )
def query_rerank(reranker, query, results, top_n=3):
    """Re-score retrieved documents with a cross-encoder and keep the best.

    Parameters
    ----------
    reranker : object exposing ``predict(pairs)`` (e.g. a CrossEncoder)
    query : str — the user query text
    results : Chroma query result dict with ``documents`` and ``metadatas``
    top_n : int — number of highest-scoring documents to return

    Returns
    -------
    list of ``(document, score, metadata)`` triples, sorted by score
    in descending order, truncated to *top_n*.
    """
    docs = results["documents"][0]
    metas = results["metadatas"][0]
    # Score each (query, document) pair jointly with the cross-encoder.
    scores = reranker.predict([(query, doc) for doc in docs])
    # Highest score first.
    ranked = sorted(zip(docs, scores, metas), key=lambda item: item[1], reverse=True)
    return ranked[:top_n]
def pretty_print_results(results):
    """Print each retrieved chunk with its rank, distance and metadata."""
    hits = zip(
        results["documents"][0],
        results["distances"][0],
        results["metadatas"][0],
    )
    for rank, (doc, dist, meta) in enumerate(hits, start=1):
        print(f"Rank {rank} | Distance: {dist:.4f}")
        print(meta)
        print(doc)
        print("-" * 40)
# context expansion: fetch sibling chunks of a hit
def get_all_chunks_in_chapter(collection, chapter_title, event_name=None, story_type=None):
    """Fetch every chunk matching the given chapter/event/story-type filters.

    Builds a Chroma ``where`` clause from the non-empty arguments (wrapped
    in ``$and`` when more than one is present) and returns a list of dicts,
    each holding the chunk text under ``"text"`` merged with its metadata.
    """
    clauses = []
    if chapter_title:
        clauses.append({"chapterTitle": chapter_title})
    if story_type:
        clauses.append({"story_type": story_type})
    if event_name:
        clauses.append({"eventName": event_name})
    if not clauses:
        where = {}
    elif len(clauses) == 1:
        where = clauses[0]
    else:
        where = {"$and": clauses}
    fetched = collection.get(where=where, include=["documents", "metadatas"])
    return [
        {"text": text, **meta}
        for text, meta in zip(fetched["documents"], fetched["metadatas"])
    ]
def find_adjacent_chunks(current_chunk, all_chunks):
    """Locate the chunks immediately before and after *current_chunk*.

    Adjacency is determined purely by index bounds: the previous chunk ends
    exactly one position before this one starts, and the next chunk starts
    exactly one position after this one ends. Returns a
    ``(prev_chunk, next_chunk)`` pair with ``None`` for any missing side;
    if several chunks match, the last one seen wins.
    """
    wanted_prev_end = current_chunk['start_idx'] - 1
    wanted_next_start = current_chunk['end_idx'] + 1
    before, after = None, None
    for candidate in all_chunks:
        if candidate['end_idx'] == wanted_prev_end:
            before = candidate
        if candidate['start_idx'] == wanted_next_start:
            after = candidate
    return before, after
def safe_to_list(x):
    """Coerce *x* into a list of text pieces.

    A string containing a newline is split on newlines; any other string is
    wrapped in a one-element list; non-strings are materialized via list().
    """
    if not isinstance(x, str):
        return list(x)
    return x.split('\n') if '\n' in x else [x]
def expand_with_neighbors(reranked_docs, collection):
    """Grow each reranked hit with its immediately adjacent chunks.

    For every ``(doc, score, meta)`` triple, fetches all chunks of the same
    chapter/event/story-type from *collection*, finds the chunks directly
    before and after the hit (by start_idx/end_idx adjacency), and joins
    prev + hit + next into one newline-separated text.

    Returns a list of ``(expanded_text, score, meta)`` triples in the same
    order as the input.
    """
    expanded_results = []
    for doc, score, meta in reranked_docs:
        chapter_title = meta.get("chapterTitle", "")
        event_name = meta.get("eventName", "")
        story_type = meta.get("story_type", None)
        all_chunks = get_all_chunks_in_chapter(collection, chapter_title, event_name, story_type)
        # assumes meta carries start_idx/end_idx, as required by
        # find_adjacent_chunks — TODO confirm against the indexing pipeline
        prev_chunk, next_chunk = find_adjacent_chunks(meta, all_chunks)
        parts = []
        if prev_chunk:
            parts += safe_to_list(prev_chunk["text"])
        parts += safe_to_list(doc)
        if next_chunk:
            parts += safe_to_list(next_chunk["text"])
        # metadata is passed through as a shallow copy, unchanged
        expanded_results.append(("\n".join(parts), score, {**meta}))
    return expanded_results
"""if __name__ == "__main__":
collection = load_collection()
encoder = load_encoder()
query_text = "乐奈喜欢什么?"
query_vec = encode_query(encoder, query_text)
results = retrieve_docs(collection, query_vec, top_k=50)
reranked = query_rerank(reranker, query_text, results, top_n=20)
deduped = dedup_by_chapter_event(reranked, max_per_group=1)
expanded_results = expand_with_neighbors(deduped[:5], collection)
for doc in expanded_results:
print("===")
print(doc)
print(doc[0])
print("===")"""
|