Spaces:
Running
Running
add reranker
Browse files- app.py +21 -5
- backend/reranker.py +13 -0
app.py
CHANGED
|
@@ -11,6 +11,7 @@ from jinja2 import Environment, FileSystemLoader
|
|
| 11 |
|
| 12 |
from backend.query_llm import generate_hf, generate_openai
|
| 13 |
from backend.semantic_search import retrieve
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
TOP_K = int(os.getenv("TOP_K", 4))
|
|
@@ -34,7 +35,7 @@ def add_text(history, text):
|
|
| 34 |
return history, gr.Textbox(value="", interactive=False)
|
| 35 |
|
| 36 |
|
| 37 |
-
def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encoder,
|
| 38 |
top_k_param = int(top_k_param)
|
| 39 |
query = history[-1][0]
|
| 40 |
|
|
@@ -47,6 +48,11 @@ def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encode
|
|
| 47 |
|
| 48 |
#documents = retrieve(query, TOP_K)
|
| 49 |
documents = retrieve(query, top_k_param, chunk_table, embedding_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
document_time = perf_counter() - document_start
|
|
@@ -121,7 +127,7 @@ with gr.Blocks() as demo:
|
|
| 121 |
)
|
| 122 |
cross_encoder = gr.Radio(
|
| 123 |
choices=[
|
| 124 |
-
"None"
|
| 125 |
"BAAI/bge-reranker-large",
|
| 126 |
"cross-encoder/ms-marco-MiniLM-L-6-v2",
|
| 127 |
],
|
|
@@ -137,20 +143,30 @@ with gr.Blocks() as demo:
|
|
| 137 |
],
|
| 138 |
value="5",
|
| 139 |
label='top-K'
|
| 140 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
prompt_html = gr.HTML()
|
| 144 |
# Turn off interactivity while generating if you click
|
| 145 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
| 146 |
-
bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param], [chatbot, prompt_html])
|
| 147 |
|
| 148 |
# Turn it back on
|
| 149 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
| 150 |
|
| 151 |
# Turn off interactivity while generating if you hit enter
|
| 152 |
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
| 153 |
-
bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param], [chatbot, prompt_html])
|
| 154 |
|
| 155 |
# Turn it back on
|
| 156 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
|
|
|
| 11 |
|
| 12 |
from backend.query_llm import generate_hf, generate_openai
|
| 13 |
from backend.semantic_search import retrieve
|
| 14 |
+
from backend.reranker import rerank_documents
|
| 15 |
|
| 16 |
|
| 17 |
TOP_K = int(os.getenv("TOP_K", 4))
|
|
|
|
| 35 |
return history, gr.Textbox(value="", interactive=False)
|
| 36 |
|
| 37 |
|
| 38 |
+
def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk):
|
| 39 |
top_k_param = int(top_k_param)
|
| 40 |
query = history[-1][0]
|
| 41 |
|
|
|
|
| 48 |
|
| 49 |
#documents = retrieve(query, TOP_K)
|
| 50 |
documents = retrieve(query, top_k_param, chunk_table, embedding_model)
|
| 51 |
+
if cross_encoder != "None" and len(documents) > 1:
|
| 52 |
+
documents = rerank_documents(cross_encoder, documents, query, top_k_rerank=rerank_topk)
|
| 53 |
+
#"cross-encoder/ms-marco-MiniLM-L-6-v2"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
|
| 57 |
|
| 58 |
document_time = perf_counter() - document_start
|
|
|
|
| 127 |
)
|
| 128 |
cross_encoder = gr.Radio(
|
| 129 |
choices=[
|
| 130 |
+
"None",
|
| 131 |
"BAAI/bge-reranker-large",
|
| 132 |
"cross-encoder/ms-marco-MiniLM-L-6-v2",
|
| 133 |
],
|
|
|
|
| 143 |
],
|
| 144 |
value="5",
|
| 145 |
label='top-K'
|
| 146 |
+
)
|
| 147 |
+
rerank_topk = gr.Radio(
|
| 148 |
+
choices=[
|
| 149 |
+
"5",
|
| 150 |
+
"10",
|
| 151 |
+
"20",
|
| 152 |
+
"50",
|
| 153 |
+
],
|
| 154 |
+
value="5",
|
| 155 |
+
label='rerank-top-K'
|
| 156 |
+
)
|
| 157 |
|
| 158 |
|
| 159 |
prompt_html = gr.HTML()
|
| 160 |
# Turn off interactivity while generating if you click
|
| 161 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
| 162 |
+
bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk], [chatbot, prompt_html])
|
| 163 |
|
| 164 |
# Turn it back on
|
| 165 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
| 166 |
|
| 167 |
# Turn off interactivity while generating if you hit enter
|
| 168 |
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
| 169 |
+
bot, [chatbot, api_kind, chunk_table, embedding_model, llm_model, cross_encoder, top_k_param, rerank_topk], [chatbot, prompt_html])
|
| 170 |
|
| 171 |
# Turn it back on
|
| 172 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
backend/reranker.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sentence_transformers import CrossEncoder
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def rerank_documents(ce_model_name, documents, query, top_k_rerank):
|
| 5 |
+
top_k_rerank = int(top_k_rerank)
|
| 6 |
+
pairs = []
|
| 7 |
+
for doc in documents:
|
| 8 |
+
pairs.append((query, doc))
|
| 9 |
+
ce_model = CrossEncoder(ce_model_name, max_length=512)
|
| 10 |
+
scores = ce_model.predict(pairs)
|
| 11 |
+
#sorted_pairs = [(s, x[1]) for s, x in sorted(zip(scores, pairs), key=lambda p: p[0], reverse = True)]
|
| 12 |
+
reranked_docs = [x[1] for _, x in sorted(zip(scores, pairs), key=lambda p: p[0], reverse = True)]
|
| 13 |
+
return reranked_docs[:top_k_rerank]
|