Spaces:
Sleeping
Sleeping
change reranker
Browse files- rag_app/rag_2.py +41 -10
rag_app/rag_2.py
CHANGED
|
@@ -9,13 +9,39 @@ from llama_index.core.query_engine import RetrieverQueryEngine
|
|
| 9 |
from llama_index.core import StorageContext, load_index_from_storage
|
| 10 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 11 |
from llama_index.core.postprocessor import LLMRerank
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
llm = LlamaCPP(
|
| 14 |
model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
|
| 15 |
temperature=0.1,
|
| 16 |
max_new_tokens=256,
|
| 17 |
-
context_window=16384
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
embedding_model = HuggingFaceEmbedding(
|
| 20 |
model_name="models/all-MiniLM-L6-v2"
|
| 21 |
)
|
|
@@ -34,11 +60,15 @@ def check_if_exists():
|
|
| 34 |
|
| 35 |
def precompute_index(data_folder='data'):
|
| 36 |
documents = SimpleDirectoryReader(data_folder).load_data()
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
index.storage_context.persist(persist_dir='models/precomputed_index')
|
| 39 |
bm25_retriever = BM25Retriever.from_defaults(
|
| 40 |
-
nodes=
|
| 41 |
-
similarity_top_k=5
|
|
|
|
| 42 |
)
|
| 43 |
bm25_retriever.persist("models/bm25_retriever")
|
| 44 |
|
|
@@ -56,20 +86,21 @@ def answer_question(query):
|
|
| 56 |
|
| 57 |
retriever = QueryFusionRetriever(
|
| 58 |
[
|
| 59 |
-
index.as_retriever(similarity_top_k=5),
|
| 60 |
bm25_retriever,
|
| 61 |
],
|
| 62 |
llm=llm,
|
| 63 |
num_queries=1,
|
| 64 |
similarity_top_k=5,
|
|
|
|
| 65 |
)
|
| 66 |
-
reranker =
|
| 67 |
-
|
| 68 |
-
top_n=5
|
| 69 |
)
|
| 70 |
keyword_query_engine = RetrieverQueryEngine(
|
| 71 |
retriever=retriever,
|
| 72 |
-
node_postprocessors=[reranker]
|
| 73 |
)
|
| 74 |
|
| 75 |
if is_harmful(query):
|
|
|
|
| 9 |
from llama_index.core import StorageContext, load_index_from_storage
|
| 10 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 11 |
from llama_index.core.postprocessor import LLMRerank
|
| 12 |
+
from llama_index.core.node_parser import TokenTextSplitter
|
| 13 |
+
from transformers import AutoTokenizer
|
| 14 |
+
from llama_index.core.postprocessor import SentenceTransformerRerank
|
| 15 |
+
|
| 16 |
+
_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def messages_to_prompt(messages):
|
| 20 |
+
messages = [{"role": m.role.value, "content": m.content} for m in messages]
|
| 21 |
+
prompt = _tokenizer.apply_chat_template(
|
| 22 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 23 |
+
)
|
| 24 |
+
return prompt
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def completion_to_prompt(completion):
|
| 28 |
+
messages = [{"role": "user", "content": completion}]
|
| 29 |
+
prompt = _tokenizer.apply_chat_template(
|
| 30 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 31 |
+
)
|
| 32 |
+
return prompt
|
| 33 |
+
|
| 34 |
|
| 35 |
llm = LlamaCPP(
|
| 36 |
model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
|
| 37 |
temperature=0.1,
|
| 38 |
max_new_tokens=256,
|
| 39 |
+
context_window=16384,
|
| 40 |
+
model_kwargs={"n_gpu_layers":-1},
|
| 41 |
+
messages_to_prompt=messages_to_prompt,
|
| 42 |
+
completion_to_prompt=completion_to_prompt)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
embedding_model = HuggingFaceEmbedding(
|
| 46 |
model_name="models/all-MiniLM-L6-v2"
|
| 47 |
)
|
|
|
|
| 60 |
|
| 61 |
def precompute_index(data_folder='data'):
|
| 62 |
documents = SimpleDirectoryReader(data_folder).load_data()
|
| 63 |
+
splitter = TokenTextSplitter(chunk_size=400, chunk_overlap=50)
|
| 64 |
+
nodes = splitter.get_nodes_from_documents(documents)
|
| 65 |
+
index = VectorStoreIndex(nodes, verbose=True)
|
| 66 |
+
# index = VectorStoreIndex.from_documents(documents)
|
| 67 |
index.storage_context.persist(persist_dir='models/precomputed_index')
|
| 68 |
bm25_retriever = BM25Retriever.from_defaults(
|
| 69 |
+
nodes=nodes,
|
| 70 |
+
similarity_top_k=5,
|
| 71 |
+
verbose=True
|
| 72 |
)
|
| 73 |
bm25_retriever.persist("models/bm25_retriever")
|
| 74 |
|
|
|
|
| 86 |
|
| 87 |
retriever = QueryFusionRetriever(
|
| 88 |
[
|
| 89 |
+
index.as_retriever(similarity_top_k=5, verbose=True),
|
| 90 |
bm25_retriever,
|
| 91 |
],
|
| 92 |
llm=llm,
|
| 93 |
num_queries=1,
|
| 94 |
similarity_top_k=5,
|
| 95 |
+
verbose=True
|
| 96 |
)
|
| 97 |
+
reranker = SentenceTransformerRerank(
|
| 98 |
+
model="cross-encoder/ms-marco-MiniLM-L-2-v2",
|
| 99 |
+
top_n=5
|
| 100 |
)
|
| 101 |
keyword_query_engine = RetrieverQueryEngine(
|
| 102 |
retriever=retriever,
|
| 103 |
+
node_postprocessors=[reranker],
|
| 104 |
)
|
| 105 |
|
| 106 |
if is_harmful(query):
|