File size: 2,255 Bytes
900e88e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from openai import OpenAI
from dotenv import load_dotenv
import os
from retriever import load_encoder, load_collection, encode_query, retrieve_docs, query_rerank, expand_with_neighbors, dedup_by_chapter_event
from sentence_transformers import CrossEncoder

# Turn off HuggingFace tokenizers' internal parallelism before any
# sentence-transformers model is loaded (presumably to avoid its
# fork/parallelism warning).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Load variables from a .env file into the process environment, then read
# the OpenAI API key from it.
load_dotenv()
api_key = os.environ.get("OPENAI_API_KEY")

# Shared OpenAI client, created once at import time and used by llm_answer().
client = OpenAI(api_key=api_key)

def build_rag_prompt(query, context):
    """Assemble the user-turn prompt for retrieval-augmented answering.

    The (Chinese) prompt lists the retrieved material first, then the user's
    question, then instructions asking the model to answer concisely and
    accurately using all of the supplied material, and to explain why rather
    than invent an answer when it cannot determine one.

    Args:
        query: The user's question; interpolated verbatim.
        context: Retrieved reference text; interpolated verbatim.

    Returns:
        The complete prompt string.
    """
    instructions = (
        "请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。"
        "如果有多个符合的答案, 可以根据你是否确定而决定是否分别陈述这些答案."
        "如果不能确定答案,请如实说明理由,不要凭空编造。"
    )
    sections = [
        "已知资料如下:",
        f"{context}",
        "",  # blank line separating the material from the question
        f"用户提问:{query}",
        instructions,
    ]
    return "\n".join(sections)

def _format_context(expanded_results):
    """Join the text of every retrieved document into one numbered context block."""
    if not expanded_results:
        return ""
    # Number each document so the model can tell separate sources apart.
    return "\n\n".join(
        f"[{idx}] {doc[0]}" for idx, doc in enumerate(expanded_results, 1)
    )


def llm_answer(query, expanded_results, model_name="gpt-4o"):
    """Answer *query* with the chat model, grounded in the retrieved documents.

    All retrieved documents are passed to the model, not just the top hit.
    The prompt asks the model to consult all of the supplied material, and
    the caller deliberately retrieves several documents.

    Args:
        query: The user's question.
        expanded_results: Retrieved documents as tuples whose first element
            is the document text (the ``__main__`` block unpacks them as
            ``(context, score, meta)``). Only the text is sent to the model.
            May be empty, in which case the prompt carries no material.
        model_name: OpenAI chat model to call.

    Returns:
        The model's reply with surrounding whitespace stripped, or ``""``
        when the reply carries no text content.
    """
    context = _format_context(expanded_results)
    prompt = build_rag_prompt(query, context)
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"},
            {"role": "user", "content": prompt}
        ],
        temperature=0.2,
        max_tokens=512,
    )
    # message.content is Optional in the OpenAI SDK (e.g. refusals), so
    # guard against None before calling .strip().
    content = response.choices[0].message.content
    return (content or "").strip()

if __name__ == "__main__":
    # Load retrieval resources: the document collection, the query encoder,
    # and a cross-encoder used for reranking.
    store = load_collection()
    query_encoder = load_encoder()
    cross_encoder = CrossEncoder("BAAI/bge-reranker-large")

    question = input("please enter your question:")
    print("Thinking...\n...")

    # Retrieval pipeline: top-50 candidates for the encoded query -> rerank
    # down to 20 -> dedup by chapter/event (at most 1 per group) -> keep the
    # best 5 and expand them with neighbouring chunks from the collection.
    candidates = retrieve_docs(store, encode_query(query_encoder, question), top_k=50)
    top_hits = query_rerank(cross_encoder, question, candidates, top_n=20)
    unique_hits = dedup_by_chapter_event(top_hits, max_per_group=1)
    docs = expand_with_neighbors(unique_hits[:5], store)

    reply = llm_answer(question, docs)

    print("\n=== Answer ===")
    print(reply)
    print("\n=== retrieved documents ===")
    for rank, (text, score, meta) in enumerate(docs, start=1):
        header = f"\n--- document {rank} (Score={score:.4f}) ---"
        print(f"{header}\n{text[:200]}...")
        print(meta)