import os
import json
import torch
from pathlib import Path
from dotenv import load_dotenv
from operator import itemgetter
from typing import List, Dict
from langchain_core.documents import Document
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from huggingface_hub import InferenceClient
from src.retrieval.reranker import HybridReranker
from src.generation.prompt_templates import RAG_PROMPT_TEMPLATE
load_dotenv()
PROJECT_ROOT = Path(__file__).resolve().parents[1]
VECTOR_STORE_PATH = str(PROJECT_ROOT / "data" / "vector_store_faiss")
EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
LLM_REPO_ID = os.getenv("HUGGINGFACE_MODEL", "mistralai/Mixtral-8x7B-Instruct-v0.1")
HF_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
if not HF_TOKEN:
    raise ValueError("HUGGINGFACE_API_TOKEN not found in environment variables.")
client = InferenceClient(model=LLM_REPO_ID, token=HF_TOKEN)
prompt_template = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
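# Hypothetical sketch of the template this module expects. RAG_PROMPT_TEMPLATE
# is imported from src.generation.prompt_templates and is invoked below with
# "context" and "query" keys, so it must expose those two placeholders; the
# answer-cleaning logic further suggests Portuguese markers such as <PERGUNTA>
# and "Resposta:". The exact wording below is an assumption, not the real file:
#
#   RAG_PROMPT_TEMPLATE = """Responda usando apenas o contexto abaixo.
#   <CONTEXTO>{context}</CONTEXTO>
#   <PERGUNTA>{query}</PERGUNTA>
#   Resposta:"""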
def format_docs(docs: List[Document]) -> str:
    """Join the page contents of the retrieved documents into one context block."""
    return "\n\n".join(doc.page_content for doc in docs)
def generate_answer_from_context(input_dict: Dict) -> str:
    """Build the final prompt from the retrieved context and call the LLM."""
    context_docs = input_dict["context"]
    query_text = input_dict["query"]
    formatted_context = format_docs(context_docs)
    prompt_value = prompt_template.invoke({
        "context": formatted_context,
        "query": query_text
    })
    # Use to_string() rather than str(): str() on a StringPromptValue yields
    # its repr (text='...'), not the raw prompt text.
    final_prompt_text = prompt_value.to_string()
    try:
        response = client.chat_completion(
            messages=[{"role": "user", "content": final_prompt_text}],
            max_tokens=300,
            temperature=0.1
        )
        clean_answer = response.choices[0].message.content.strip()
        # Drop a leading "Resposta:" ("Answer:") label if the model echoes it.
        if clean_answer.startswith("Resposta:"):
            clean_answer = clean_answer.split("Resposta:", 1)[1]
        # Truncate at any hallucinated follow-up question markers.
        stop_phrases = ["<PERGUNTA>", "Pergunta:"]
        for phrase in stop_phrases:
            if phrase in clean_answer:
                clean_answer = clean_answer.split(phrase)[0]
        return clean_answer.strip()
    except Exception as e:
        print(f"Error calling the Hugging Face API: {e}")
        return f"Error calling the LLM: {e}"
def get_rag_chain():
    """Assemble the retrieval + generation chain."""
    # Model cache location used by the deployment container.
    cache_dir = "/app/huggingface_cache"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    embeddings_model = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": True},
        cache_folder=cache_dir
    )
    vector_store = FAISS.load_local(
        VECTOR_STORE_PATH, embeddings_model, allow_dangerous_deserialization=True
    )
    hybrid_reranker = HybridReranker(vector_store=vector_store, device=device, cache_dir=cache_dir)
    # Wrap the reranker call so it composes with the rest of the chain.
    retrieval_chain = RunnableLambda(hybrid_reranker.retrieve_and_rerank)
    # The input query fans out to retrieval and a passthrough; the parallel
    # step then returns both the source chunks and the generated answer.
    rag_chain = {
        "context": retrieval_chain,
        "query": RunnablePassthrough()
    } | RunnableParallel({
        "source_chunks": itemgetter("context"),
        "answer": generate_answer_from_context
    })
    print("RAG pipeline ready.")
    return rag_chain
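# Minimal usage sketch, assuming the module is run directly and that
# retrieve_and_rerank returns Document objects; the sample question is
# illustrative only.
if __name__ == "__main__":
    chain = get_rag_chain()
    result = chain.invoke("Qual é o objetivo do projeto?")
    print("Answer:", result["answer"])
    for doc in result["source_chunks"]:
        print("Source:", doc.page_content[:80], "...")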