File size: 2,629 Bytes
0dd2dc1
 
 
 
 
 
 
6234335
0dd2dc1
 
a13a62d
1e384db
6234335
0dd2dc1
6234335
0dd2dc1
 
 
 
 
 
7d8273c
a13a62d
 
 
0dd2dc1
a13a62d
0dd2dc1
 
1e384db
7d8273c
 
 
6234335
7d8273c
 
1e384db
6234335
1e384db
 
 
6234335
 
 
 
 
 
 
 
 
 
 
 
1e384db
 
 
0dd2dc1
6234335
 
0dd2dc1
6234335
 
 
 
0dd2dc1
6234335
0dd2dc1
 
 
6234335
 
0dd2dc1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import functools
import os
import sys

from dotenv import load_dotenv

# Make the project root importable so the `src.*` imports below resolve.
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
sys.path.append(root_dir)

from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI 
from langchain_classic.retrievers.document_compressors import CrossEncoderReranker
from langchain_core.messages import SystemMessage
from langchain_classic.retrievers import ContextualCompressionRetriever, MultiQueryRetriever
from src.retrieval.vector_store import get_vector_store
from src.retrieval.retriever import get_retriever
from src.chains.prompt import get_rag_prompt

# Load variables from a .env file (python-dotenv's default lookup) into the
# process environment, so os.getenv("GROQ_API_KEY") in create_rag_chain can
# find the API key.
load_dotenv()

@functools.lru_cache(maxsize=4)
def _load_cross_encoder(model_name):
    """Return a HuggingFace cross-encoder for ``model_name``, loading it at most once.

    Loading re-ranker weights is slow, and ``create_rag_chain`` may be called
    more than once (e.g. with different ``disease_label`` values). The instance
    is therefore cached per model name and reused by every chain built later.
    Sharing one instance assumes it is only used for inference scoring;
    revisit this if that changes.
    """
    # Report the model that is actually being loaded.
    print(f"Memuat model Re-ranker ({model_name})...")
    return HuggingFaceCrossEncoder(model_name=model_name)


def create_rag_chain(
    disease_label=None,
    *,
    k=5,
    top_n=3,
    llm_model="openai/gpt-oss-20b",
    reranker_model="BAAI/bge-reranker-base",
    temperature=0.2,
):
    """Build a streaming RAG chain: multi-query retrieval -> re-ranking -> LLM answer.

    Args:
        disease_label: Forwarded to ``get_retriever`` as ``filter_label``.
            It presumably restricts retrieval to one label, with ``None``
            meaning no filter; verify against ``get_retriever``.
        k: Number of documents the base similarity retriever returns per query.
        top_n: Number of documents kept after cross-encoder re-ranking.
        llm_model: Groq model name. It is used both for answering and for
            generating the multi-query variants.
        reranker_model: HuggingFace cross-encoder model name used for re-ranking.
        temperature: Sampling temperature for the LLM.

    Returns:
        A LangChain runnable. Its input is passed both to the re-ranking
        retriever and to the prompt's ``input`` variable. Its output is the
        LLM answer as a string.
    """
    llm = ChatGroq(
        model=llm_model,
        streaming=True,
        temperature=temperature,
        api_key=os.getenv("GROQ_API_KEY"),
    )

    # 1. Base similarity retriever over the project vector store.
    #    disease_label is passed through as filter_label.
    vs = get_vector_store()
    base_retriever = get_retriever(
        vector_store=vs,
        search_type="similarity",
        k=k,
        filter_label=disease_label,
    )

    # 2. Multi-query expansion: the LLM produces alternative phrasings of the
    #    question and the base retriever is run for each of them.
    mq_retriever = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=llm)

    # 3. Cross-encoder re-ranking keeps only the top_n most relevant documents.
    compressor = CrossEncoderReranker(
        model=_load_cross_encoder(reranker_model),
        top_n=top_n,
    )

    # 4. Wrap retrieval and re-ranking into a single compression retriever.
    rerank_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=mq_retriever,
    )

    prompt = get_rag_prompt()

    def format_docs(docs):
        """Print a debug summary of the re-ranked documents and join their text."""
        print("\n" + "=" * 50)
        # Report how many documents were actually kept; it can be fewer than top_n.
        print(f"🎯 [DEBUG] {len(docs)} DOKUMEN TERBAIK SETELAH DI-RERANK:")
        for i, doc in enumerate(docs, start=1):
            sumber = (
                doc.metadata.get("label")
                or doc.metadata.get("source")
                or "Sumber tidak diketahui"
            )
            # NOTE(review): CrossEncoderReranker does not appear to write a
            # 'relevance_score' into metadata, so this likely always prints
            # 'N/A'. Confirm this before relying on the score.
            skor = doc.metadata.get("relevance_score", "N/A")
            print(f"  [{i}] Topik: {sumber} | Skor Relevansi: {skor}")
        print("=" * 50 + "\n")
        # The final context is the document texts separated by blank lines.
        return "\n\n".join(doc.page_content for doc in docs)

    # The same chain input is used as the retrieval query and as the prompt's
    # "input" variable.
    rag_chain = (
        {"context": rerank_retriever | format_docs, "input": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain