Spaces:
Running
Running
add different chunk tables and emb models
Browse files- app.py +5 -2
- backend/semantic_search.py +5 -3
app.py
CHANGED
|
@@ -44,7 +44,9 @@ def bot(history, api_kind, chunk_table, embedding_model, llm_model, cross_encoder
|
|
| 44 |
# Retrieve documents relevant to query
|
| 45 |
document_start = perf_counter()
|
| 46 |
|
| 47 |
-
documents = retrieve(query, TOP_K)
|
|
|
|
|
|
|
| 48 |
|
| 49 |
document_time = perf_counter() - document_start
|
| 50 |
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
|
@@ -118,10 +120,11 @@ with gr.Blocks() as demo:
|
|
| 118 |
)
|
| 119 |
cross_encoder = gr.Radio(
|
| 120 |
choices=[
|
|
|
|
| 121 |
"BAAI/bge-reranker-large",
|
| 122 |
"cross-encoder/ms-marco-MiniLM-L-6-v2",
|
| 123 |
],
|
| 124 |
-
value="
|
| 125 |
label='Cross-encoder model'
|
| 126 |
)
|
| 127 |
top_k_param = gr.Radio(
|
|
|
|
| 44 |
# Retrieve documents relevant to query
|
| 45 |
document_start = perf_counter()
|
| 46 |
|
| 47 |
+
#documents = retrieve(query, TOP_K)
|
| 48 |
+
documents = retrieve(query, top_k_param, chunk_table, embedding_model)
|
| 49 |
+
|
| 50 |
|
| 51 |
document_time = perf_counter() - document_start
|
| 52 |
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
|
|
|
| 120 |
)
|
| 121 |
cross_encoder = gr.Radio(
|
| 122 |
choices=[
|
| 123 |
+
"None",
|
| 124 |
"BAAI/bge-reranker-large",
|
| 125 |
"cross-encoder/ms-marco-MiniLM-L-6-v2",
|
| 126 |
],
|
| 127 |
+
value="None",
|
| 128 |
label='Cross-encoder model'
|
| 129 |
)
|
| 130 |
top_k_param = gr.Radio(
|
backend/semantic_search.py
CHANGED
|
@@ -6,15 +6,17 @@ from sentence_transformers import SentenceTransformer
|
|
| 6 |
|
| 7 |
db = lancedb.connect(".lancedb")
|
| 8 |
|
| 9 |
-
TABLE = db.open_table(os.getenv("TABLE_NAME"))
|
| 10 |
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
|
| 11 |
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
| 12 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
| 13 |
|
| 14 |
-
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
| 15 |
|
| 16 |
|
| 17 |
-
def retrieve(query, k):
|
|
|
|
|
|
|
| 18 |
query_vec = retriever.encode(query)
|
| 19 |
try:
|
| 20 |
documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
|
|
|
|
| 6 |
|
| 7 |
db = lancedb.connect(".lancedb")
|
| 8 |
|
| 9 |
+
#TABLE = db.open_table(os.getenv("TABLE_NAME"))
|
| 10 |
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
|
| 11 |
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
|
| 12 |
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
|
| 13 |
|
| 14 |
+
#retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
|
| 15 |
|
| 16 |
|
| 17 |
+
def retrieve(query, k, table_name, emb_name):
|
| 18 |
+
TABLE = db.open_table(table_name)
|
| 19 |
+
retriever = SentenceTransformer(emb_name)
|
| 20 |
query_vec = retriever.encode(query)
|
| 21 |
try:
|
| 22 |
documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
|