File size: 3,489 Bytes
9dadf57
 
 
 
 
 
 
 
 
 
 
 
 
bd4d787
9dadf57
 
 
 
 
 
 
 
 
 
 
 
 
4b62abe
bd4d787
9dadf57
 
 
 
 
 
 
 
7d41565
 
 
 
 
9dadf57
7d41565
9dadf57
7d41565
 
 
865167d
7d41565
 
9dadf57
7d41565
9dadf57
865167d
7d41565
 
 
 
865167d
9dadf57
 
7d41565
865167d
 
 
 
 
9dadf57
7d41565
9dadf57
7d41565
9dadf57
 
2068d15
 
5242310
 
 
2068d15
 
5242310
 
 
2068d15
5242310
2068d15
 
 
 
5242310
 
 
2068d15
 
 
 
 
 
 
 
 
5242310
2068d15
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import json
import torch
from pathlib import Path
from dotenv import load_dotenv
from operator import itemgetter
from typing import List, Dict
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from huggingface_hub import InferenceClient
from src.retrieval.reranker import HybridReranker
from src.generation.prompt_templates import RAG_PROMPT_TEMPLATE

# Load environment variables from a local .env file (HF token, model id).
load_dotenv()

# Resolve paths relative to the project root so the module works regardless
# of the current working directory. NOTE(review): parents[1] assumes this
# file sits exactly one directory below the project root — confirm layout.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
VECTOR_STORE_PATH = str(PROJECT_ROOT / "data" / "vector_store_faiss")
EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
# LLM repo can be overridden via HUGGINGFACE_MODEL; defaults to Mixtral.
LLM_REPO_ID = os.getenv("HUGGINGFACE_MODEL", "mistralai/Mixtral-8x7B-Instruct-v0.1")
HF_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")

# Fail fast at import time if the API token is missing — every request
# below requires it.
if not HF_TOKEN:
    raise ValueError("token api not found")

# Shared inference client for the hosted LLM (module-level singleton).
client = InferenceClient(model=LLM_REPO_ID, token=HF_TOKEN)

# Prompt template with {context} and {query} placeholders (see RAG_PROMPT_TEMPLATE).
prompt_template = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

def format_docs(docs: List) -> str:
    """Join retrieved documents into one context string for the prompt.

    Args:
        docs: Retrieved documents. Each element must expose a
            ``page_content`` attribute (LangChain ``Document``-like objects).
            The previous ``List[Dict]`` annotation was inaccurate — plain
            dicts have no ``page_content`` attribute.

    Returns:
        The documents' contents separated by blank lines; empty string
        for an empty list.
    """
    return "\n\n".join(doc.page_content for doc in docs)

def generate_answer_from_context(input_dict: Dict) -> str:
    """Generate an answer for a query from retrieved context documents.

    Args:
        input_dict: Mapping with keys ``"context"`` (retrieved documents,
            each exposing ``page_content``) and ``"query"`` (user question).

    Returns:
        The cleaned model answer, or an error string if the API call fails
        (this function deliberately never raises — it is a chain boundary).
    """
    context_docs = input_dict["context"]
    query_text = input_dict["query"]
    formatted_context = format_docs(context_docs)

    prompt_value = prompt_template.invoke({
        "context": formatted_context,
        "query": query_text
    })
    # BUG FIX: str(prompt_value) returned the PromptValue repr (e.g.
    # text='...'), not the formatted prompt itself. to_string() yields the
    # actual prompt text, which also makes the old text="..."/text='...'
    # repr-stripping workaround obsolete, so it has been removed.
    final_prompt_text = prompt_value.to_string()

    try:
        response = client.chat_completion(
            messages=[{"role": "user", "content": final_prompt_text}],
            max_tokens=300,
            temperature=0.1
        )

        raw_answer = response.choices[0].message.content
        clean_answer = raw_answer.strip()

        # Drop a leading "Resposta:" label if the model echoes it back.
        if clean_answer.startswith("Resposta:"):
            clean_answer = clean_answer.split("Resposta:", 1)[1]

        # Truncate at markers that indicate the model started hallucinating
        # a follow-up question.
        stop_phrases = ["<PERGUNTA>", "Pergunta:"]
        for phrase in stop_phrases:
            if phrase in clean_answer:
                clean_answer = clean_answer.split(phrase)[0]

        return clean_answer.strip()

    except Exception as e:
        # Boundary handler: surface the failure as a string so the chain
        # keeps producing output instead of crashing.
        print(f"error to call Huggingface API: {e}")
        return f"error to call llm: {e}"

def get_rag_chain(cache_dir: str = "/app/huggingface_cache"):
    """Build the end-to-end RAG chain: embed, retrieve, rerank, generate.

    Args:
        cache_dir: Directory for HuggingFace model caches. Previously
            hard-coded inside the function; the default preserves the
            original behavior while allowing callers to override it.

    Returns:
        A LangChain runnable mapping a query string to a dict with
        ``"source_chunks"`` (reranked documents) and ``"answer"``
        (the LLM-generated text).
    """
    # Prefer GPU when available; both the embedder and reranker use this.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    embeddings_model = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        cache_folder=cache_dir
    )

    # SECURITY NOTE: FAISS.load_local deserializes a pickle file;
    # allow_dangerous_deserialization is safe only because the store is
    # built by this project — never point this at an untrusted path.
    vector_store = FAISS.load_local(
        VECTOR_STORE_PATH, embeddings_model, allow_dangerous_deserialization=True
    )

    hybrid_reranker = HybridReranker(vector_store=vector_store, device=device, cache_dir=cache_dir)

    retrieval_chain = lambda query: hybrid_reranker.retrieve_and_rerank(query)

    # The dict literal is coerced by LangChain into a parallel runnable:
    # "context" gets the reranked docs, "query" passes the input through.
    rag_chain = {
        "context": retrieval_chain,
        "query": RunnablePassthrough()
    } | RunnableParallel({
        "source_chunks": itemgetter("context"),
        "answer": generate_answer_from_context
    })

    print("pipeline ready")
    return rag_chain