Rifqi Hafizuddin committed on
Commit
110ee34
·
1 Parent(s): bd2b1d9

[NOTICKET] fix-revert string change

Browse files
Files changed (1) hide show
  1. src/rag/retrievers/schema.py +10 -10
src/rag/retrievers/schema.py CHANGED
@@ -52,11 +52,11 @@ class SchemaRetriever(BaseRetriever):
52
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
53
 
54
  if operator == "<#>":
55
- score_sql = "(lpe.embedding <#> :emb::vector) * -1"
56
  elif operator == "<->":
57
- score_sql = "1.0 / (1.0 + (lpe.embedding <-> :emb::vector))"
58
  else:
59
- score_sql = "1.0 - (lpe.embedding <=> :emb::vector)"
60
 
61
  sql = text(f"""
62
  SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
@@ -65,12 +65,12 @@ class SchemaRetriever(BaseRetriever):
65
  WHERE lpc.name = 'document_embeddings'
66
  AND lpe.cmetadata->>'user_id' = :user_id
67
  AND lpe.cmetadata->>'source_type' = 'database'
68
- ORDER BY lpe.embedding {operator} :emb::vector ASC
69
  LIMIT :k
70
  """)
71
 
72
  async with _pgvector_engine.connect() as conn:
73
- result = await conn.execute(sql, {"user_id": user_id, "k": k * 4, "emb": emb_str})
74
  rows = result.fetchall()
75
 
76
  return [
@@ -90,11 +90,11 @@ class SchemaRetriever(BaseRetriever):
90
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
91
 
92
  if operator == "<#>":
93
- score_sql = "(lpe.embedding <#> :emb::vector) * -1"
94
  elif operator == "<->":
95
- score_sql = "1.0 / (1.0 + (lpe.embedding <-> :emb::vector))"
96
  else:
97
- score_sql = "1.0 - (lpe.embedding <=> :emb::vector)"
98
 
99
  sql = text(f"""
100
  SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
@@ -105,12 +105,12 @@ class SchemaRetriever(BaseRetriever):
105
  AND lpe.cmetadata->>'source_type' = 'document'
106
  AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
107
  OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
108
- ORDER BY lpe.embedding {operator} :emb::vector ASC
109
  LIMIT :k
110
  """)
111
 
112
  async with _pgvector_engine.connect() as conn:
113
- result = await conn.execute(sql, {"user_id": user_id, "k": k * 4, "emb": emb_str})
114
  rows = result.fetchall()
115
 
116
  results = []
 
52
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
53
 
54
  if operator == "<#>":
55
+ score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
56
  elif operator == "<->":
57
+ score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
58
  else:
59
+ score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
60
 
61
  sql = text(f"""
62
  SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
 
65
  WHERE lpc.name = 'document_embeddings'
66
  AND lpe.cmetadata->>'user_id' = :user_id
67
  AND lpe.cmetadata->>'source_type' = 'database'
68
+ ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
69
  LIMIT :k
70
  """)
71
 
72
  async with _pgvector_engine.connect() as conn:
73
+ result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
74
  rows = result.fetchall()
75
 
76
  return [
 
90
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
91
 
92
  if operator == "<#>":
93
+ score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
94
  elif operator == "<->":
95
+ score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
96
  else:
97
+ score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
98
 
99
  sql = text(f"""
100
  SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
 
105
  AND lpe.cmetadata->>'source_type' = 'document'
106
  AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
107
  OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
108
+ ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
109
  LIMIT :k
110
  """)
111
 
112
  async with _pgvector_engine.connect() as conn:
113
+ result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
114
  rows = result.fetchall()
115
 
116
  results = []