File size: 5,881 Bytes
900e88e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import chromadb
from sentence_transformers import SentenceTransformer, CrossEncoder

# On-disk location of the persistent Chroma database.
CHROMA_DB_DIR = "./chroma_db"
# Name of the collection holding the dense-embedded story chunks.
COLLECTION_NAME = "bangdream_dense"
# Dense bi-encoder used to embed queries.
# NOTE(review): presumably this must match the encoder used when the
# collection was indexed — verify against the indexing script.
MODEL_NAME = "BAAI/bge-base-zh-v1.5"

# Cross-encoder re-ranker, loaded once at import time (module-level side
# effect: downloads/loads model weights on first import).
reranker = CrossEncoder("BAAI/bge-reranker-large")


def load_collection(db_path=CHROMA_DB_DIR, collection_name=COLLECTION_NAME):
    """Open the persistent Chroma DB at *db_path* and return the collection
    named *collection_name*, creating it if it does not yet exist."""
    db_client = chromadb.PersistentClient(path=db_path)
    return db_client.get_or_create_collection(collection_name)

def load_encoder(model_name=MODEL_NAME):
    """Instantiate and return the dense sentence-embedding model."""
    encoder = SentenceTransformer(model_name)
    return encoder

def encode_query(encoder, query_text):
    """Encode query text into normalized embedding."""
    # NOTE(review): SentenceTransformer.encode_query only exists in newer
    # sentence-transformers releases (older versions expose encode()) —
    # confirm the pinned library version supports it.
    return encoder.encode_query([query_text], normalize_embeddings=True)

def dedup_by_chapter_event(reranked_docs, max_per_group=1):
    """Keep at most *max_per_group* entries per (chapterTitle, eventName) pair.

    Input order is preserved; each item is a (doc, score, meta) triple whose
    group key is read from the metadata dict (missing keys default to "").
    """
    group_counts = {}
    kept = []
    for item in reranked_docs:
        meta = item[2]
        group = (meta.get("chapterTitle", ""), meta.get("eventName", ""))
        if group_counts.get(group, 0) < max_per_group:
            group_counts[group] = group_counts.get(group, 0) + 1
            kept.append(item)
    return kept

def retrieve_docs(collection, query_vec, top_k=5):
    """Run a dense-vector query against *collection* and return the raw
    Chroma result dict (metadatas, documents and distances included)."""
    return collection.query(
        query_embeddings=query_vec,
        n_results=top_k,
        include=["metadatas", "documents", "distances"],
    )

def query_rerank(reranker, query, results, top_n=3):
    """Re-rank the first query's retrieved documents with a cross-encoder.

    Args:
        reranker: object with a ``predict(pairs)`` method (e.g. CrossEncoder)
            returning one relevance score per (query, doc) pair.
        query: the user query text.
        results: Chroma query result dict; only index 0 of ``documents`` and
            ``metadatas`` is used (single-query result).
        top_n: number of top-scoring triples to return.

    Returns:
        List of (doc, score, meta) triples sorted by score, highest first,
        truncated to *top_n*.
    """
    docs = results["documents"][0]
    metas = results["metadatas"][0]

    # Score every (query, document) pair with the cross-encoder.
    scores = reranker.predict([(query, doc) for doc in docs])

    # Sort by score descending and keep only the top_n triples.
    # (Removed: a dead commented-out print block that was left as a
    # free-floating triple-quoted string expression mid-function.)
    ranked = sorted(zip(docs, scores, metas), key=lambda item: item[1], reverse=True)
    return ranked[:top_n]

def pretty_print_results(results):
    """Print each retrieved document with its rank, distance and metadata."""
    triples = zip(results["documents"][0], results["distances"][0], results["metadatas"][0])
    for rank, (doc, dist, meta) in enumerate(triples, start=1):
        print(f"Rank {rank} | Distance: {dist:.4f}")
        print(meta)
        print(doc)
        print("-" * 40)

# Context expansion: fetch every chunk belonging to a chapter/event.
def get_all_chunks_in_chapter(collection, chapter_title, event_name=None, story_type=None):
    """Fetch all chunks whose metadata matches the given chapter/event/type.

    Builds a Chroma ``where`` filter from the non-empty arguments; multiple
    conditions are combined with ``$and``.

    Args:
        collection: Chroma collection exposing ``get(where=..., include=...)``.
        chapter_title: chapter title to match (skipped when falsy).
        event_name: optional event name to match.
        story_type: optional story type to match.

    Returns:
        List of dicts — each chunk's metadata plus a ``"text"`` key holding
        the document body.
    """
    clauses = []
    if chapter_title:
        clauses.append({"chapterTitle": chapter_title})
    if story_type:
        clauses.append({"story_type": story_type})
    if event_name:
        clauses.append({"eventName": event_name})
    if not clauses:
        # Bug fix: Chroma rejects an empty where dict ({}); None means
        # "no filter" and returns every chunk.
        filter_dict = None
    elif len(clauses) == 1:
        filter_dict = clauses[0]
    else:
        filter_dict = {"$and": clauses}
    results = collection.get(where=filter_dict, include=["documents", "metadatas"])
    return [
        {"text": doc, **meta}
        for doc, meta in zip(results["documents"], results["metadatas"])
    ]

def find_adjacent_chunks(current_chunk, all_chunks):
    """Return the (previous, next) chunks whose index ranges directly touch
    *current_chunk*'s ``start_idx``/``end_idx``; ``None`` where absent."""
    prev_chunk = None
    next_chunk = None
    wanted_end = current_chunk['start_idx'] - 1
    wanted_start = current_chunk['end_idx'] + 1
    for candidate in all_chunks:
        if candidate['end_idx'] == wanted_end:
            prev_chunk = candidate
        if candidate['start_idx'] == wanted_start:
            next_chunk = candidate
    return prev_chunk, next_chunk

def safe_to_list(x):
    """Coerce *x* to a list: split a multi-line string on newlines, wrap a
    single-line string in a one-element list, otherwise use ``list(x)``."""
    if not isinstance(x, str):
        return list(x)
    if '\n' in x:
        return x.split('\n')
    return [x]

def expand_with_neighbors(reranked_docs, collection):
    """Widen each reranked hit with its immediately adjacent chunks.

    For every (doc, score, meta) triple, fetches all chunks of the same
    chapter/event/story-type and glues the chunk just before and just after
    the hit (by start_idx/end_idx adjacency) around the hit text.

    Args:
        reranked_docs: iterable of (doc, score, meta) triples.
        collection: Chroma collection used to look up sibling chunks.

    Returns:
        List of (expanded_text, score, meta) triples; expanded_text is the
        newline-joined neighbor/hit/neighbor lines.
    """
    expanded_results = []
    for doc, score, meta in reranked_docs:
        all_chunks = get_all_chunks_in_chapter(
            collection,
            meta.get("chapterTitle", ""),
            meta.get("eventName", ""),
            meta.get("story_type", None),
        )
        # NOTE(review): find_adjacent_chunks indexes meta['start_idx'] and
        # meta['end_idx'] directly — confirm upstream metadata always sets
        # both keys, otherwise this raises KeyError.
        prev_chunk, next_chunk = find_adjacent_chunks(meta, all_chunks)

        # Assemble neighbor text before and after the hit; safe_to_list
        # normalizes strings (possibly multi-line) into line lists.
        lines = []
        if prev_chunk:
            lines += safe_to_list(prev_chunk["text"])
        lines += safe_to_list(doc)
        if next_chunk:
            lines += safe_to_list(next_chunk["text"])

        expanded_results.append(("\n".join(lines), score, {**meta}))
    return expanded_results

"""if __name__ == "__main__":
    collection = load_collection()
    encoder = load_encoder()

    query_text = "乐奈喜欢什么?"
    query_vec = encode_query(encoder, query_text)
    results = retrieve_docs(collection, query_vec, top_k=50)
    reranked = query_rerank(reranker, query_text, results, top_n=20)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:5], collection)

    for doc in expanded_results:
        print("===")
        print(doc)
        print(doc[0])
        print("===")"""