Rifqi Hafizuddin committed on
Commit ·
110ee34
1
Parent(s): bd2b1d9
[NOTICKET] fix-revert string change
Browse files - src/rag/retrievers/schema.py +10 -10
src/rag/retrievers/schema.py
CHANGED
|
@@ -52,11 +52,11 @@ class SchemaRetriever(BaseRetriever):
|
|
| 52 |
emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
| 53 |
|
| 54 |
if operator == "<#>":
|
| 55 |
-
score_sql = "(lpe.embedding <#> :
|
| 56 |
elif operator == "<->":
|
| 57 |
-
score_sql = "1.0 / (1.0 + (lpe.embedding <-> :
|
| 58 |
else:
|
| 59 |
-
score_sql = "1.0 - (lpe.embedding <=> :
|
| 60 |
|
| 61 |
sql = text(f"""
|
| 62 |
SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
|
|
@@ -65,12 +65,12 @@ class SchemaRetriever(BaseRetriever):
|
|
| 65 |
WHERE lpc.name = 'document_embeddings'
|
| 66 |
AND lpe.cmetadata->>'user_id' = :user_id
|
| 67 |
AND lpe.cmetadata->>'source_type' = 'database'
|
| 68 |
-
ORDER BY lpe.embedding {operator} :
|
| 69 |
LIMIT :k
|
| 70 |
""")
|
| 71 |
|
| 72 |
async with _pgvector_engine.connect() as conn:
|
| 73 |
-
result = await conn.execute(sql, {"user_id": user_id, "k": k * 4
|
| 74 |
rows = result.fetchall()
|
| 75 |
|
| 76 |
return [
|
|
@@ -90,11 +90,11 @@ class SchemaRetriever(BaseRetriever):
|
|
| 90 |
emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
| 91 |
|
| 92 |
if operator == "<#>":
|
| 93 |
-
score_sql = "(lpe.embedding <#> :
|
| 94 |
elif operator == "<->":
|
| 95 |
-
score_sql = "1.0 / (1.0 + (lpe.embedding <-> :
|
| 96 |
else:
|
| 97 |
-
score_sql = "1.0 - (lpe.embedding <=> :
|
| 98 |
|
| 99 |
sql = text(f"""
|
| 100 |
SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
|
|
@@ -105,12 +105,12 @@ class SchemaRetriever(BaseRetriever):
|
|
| 105 |
AND lpe.cmetadata->>'source_type' = 'document'
|
| 106 |
AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
|
| 107 |
OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
|
| 108 |
-
ORDER BY lpe.embedding {operator} :
|
| 109 |
LIMIT :k
|
| 110 |
""")
|
| 111 |
|
| 112 |
async with _pgvector_engine.connect() as conn:
|
| 113 |
-
result = await conn.execute(sql, {"user_id": user_id, "k": k * 4
|
| 114 |
rows = result.fetchall()
|
| 115 |
|
| 116 |
results = []
|
|
|
|
| 52 |
emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
| 53 |
|
| 54 |
if operator == "<#>":
|
| 55 |
+
score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
|
| 56 |
elif operator == "<->":
|
| 57 |
+
score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
|
| 58 |
else:
|
| 59 |
+
score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
|
| 60 |
|
| 61 |
sql = text(f"""
|
| 62 |
SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
|
|
|
|
| 65 |
WHERE lpc.name = 'document_embeddings'
|
| 66 |
AND lpe.cmetadata->>'user_id' = :user_id
|
| 67 |
AND lpe.cmetadata->>'source_type' = 'database'
|
| 68 |
+
ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
|
| 69 |
LIMIT :k
|
| 70 |
""")
|
| 71 |
|
| 72 |
async with _pgvector_engine.connect() as conn:
|
| 73 |
+
result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
|
| 74 |
rows = result.fetchall()
|
| 75 |
|
| 76 |
return [
|
|
|
|
| 90 |
emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
| 91 |
|
| 92 |
if operator == "<#>":
|
| 93 |
+
score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
|
| 94 |
elif operator == "<->":
|
| 95 |
+
score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
|
| 96 |
else:
|
| 97 |
+
score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
|
| 98 |
|
| 99 |
sql = text(f"""
|
| 100 |
SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
|
|
|
|
| 105 |
AND lpe.cmetadata->>'source_type' = 'document'
|
| 106 |
AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
|
| 107 |
OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
|
| 108 |
+
ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
|
| 109 |
LIMIT :k
|
| 110 |
""")
|
| 111 |
|
| 112 |
async with _pgvector_engine.connect() as conn:
|
| 113 |
+
result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
|
| 114 |
rows = result.fetchall()
|
| 115 |
|
| 116 |
results = []
|