wbrooks committed on
Commit
49257b2
·
1 Parent(s): f4c877a

use valid columns for the result

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app.py +7 -8
.gitignore CHANGED
@@ -1,3 +1,5 @@
1
  # pixi environments
2
  .pixi/*
3
  !.pixi/config.toml
 
 
 
1
  # pixi environments
2
  .pixi/*
3
  !.pixi/config.toml
4
+
5
+ */.DS_Store
app.py CHANGED
@@ -1,7 +1,7 @@
1
  from fastapi import FastAPI, Query
2
  from fastapi.responses import JSONResponse
3
- from src.embeddings_search import create_embeddings_search_function_from_embeddings_df
4
- from src.tfidf_search import create_tfidf_search_function
5
 
6
  import polars as pl
7
  #from jinja2 import Template
@@ -13,6 +13,7 @@ path_prefix = "/Users/wes/Google Drive/Shared drives/datalab/projects/2025_coul_
13
  block_embeddings_df_path = "block_embeddings/block-embeddings.parquet"
14
  doc_tfidf_df_path = "block_tfidf/TF-IDF-doc-text.parquet"
15
  tfidf_vectorizer_path = "block_tfidf/tfidf_vectorizer_doc_text.joblib"
 
16
 
17
  sbert_query_docs = create_embeddings_search_function_from_embeddings_df(
18
  model_name = "sentence-transformers/all-MiniLM-L6-v2",
@@ -22,9 +23,8 @@ tfidf_query_docs = create_tfidf_search_function(
22
  dtm_df_path = doc_tfidf_df_path,
23
  vectorizer_path = tfidf_vectorizer_path,
24
  model_name = "facebook/fasttext-en-vectors")
25
- doc_embeddings_df = (pl.read_parquet(block_embeddings_df_path)
26
- .unique(subset = "file", keep="first")
27
- .with_columns(("/" + pl.col('file').str.strip_prefix(path_prefix)).alias("tail_path")))
28
 
29
 
30
  app = FastAPI()
@@ -41,16 +41,15 @@ def search(q: str = Query(..., description="Search query")):
41
  res_sbert = sbert_query_docs(q)
42
 
43
  joined = (res_sbert.join(res_tfidf, on='file', how = 'inner')
44
- .join(doc_embeddings_df, left_on="file", right_on = "tail_path", how="inner")
45
  .with_columns(
46
  pl.format('<a href="https://drive.google.com/file/d/{}/view" target="_blank" rel="noopener">{}</a>',
47
  pl.col("id"),
48
- pl.col("name")).alias('link')))
49
 
50
  res_combined = joined.with_columns(
51
  (0.7 * pl.col("rank-sbert") + 0.3 * pl.col("rank-tfidf")).alias("rank-combined"),
52
  #pl.col("file").str.strip_prefix(path_prefix).alias("file"),
53
- pl.col("link").str.strip_prefix(path_prefix).alias("link"),
54
  ).sort("rank-combined").with_columns(
55
  (20.0 / pl.col('rank-combined')).round(2).alias('confidence')
56
  ).select(['link', 'confidence'])
 
1
  from fastapi import FastAPI, Query
2
  from fastapi.responses import JSONResponse
3
+ from src.coul_search.embeddings_search import create_embeddings_search_function_from_embeddings_df
4
+ from src.coul_search.tfidf_search import create_tfidf_search_function
5
 
6
  import polars as pl
7
  #from jinja2 import Template
 
13
  block_embeddings_df_path = "block_embeddings/block-embeddings.parquet"
14
  doc_tfidf_df_path = "block_tfidf/TF-IDF-doc-text.parquet"
15
  tfidf_vectorizer_path = "block_tfidf/tfidf_vectorizer_doc_text.joblib"
16
+ googledrive_metadata_path = "coul_files.csv"
17
 
18
  sbert_query_docs = create_embeddings_search_function_from_embeddings_df(
19
  model_name = "sentence-transformers/all-MiniLM-L6-v2",
 
23
  dtm_df_path = doc_tfidf_df_path,
24
  vectorizer_path = tfidf_vectorizer_path,
25
  model_name = "facebook/fasttext-en-vectors")
26
+ coul_files_df = (pl.read_csv(googledrive_metadata_path)
27
+ .with_columns(pl.col("path").str.strip_prefix("/").alias("path")))
 
28
 
29
 
30
  app = FastAPI()
 
41
  res_sbert = sbert_query_docs(q)
42
 
43
  joined = (res_sbert.join(res_tfidf, on='file', how = 'inner')
44
+ .join(coul_files_df, left_on="file", right_on = "path", how="inner")
45
  .with_columns(
46
  pl.format('<a href="https://drive.google.com/file/d/{}/view" target="_blank" rel="noopener">{}</a>',
47
  pl.col("id"),
48
+ pl.col("file")).alias('link')))
49
 
50
  res_combined = joined.with_columns(
51
  (0.7 * pl.col("rank-sbert") + 0.3 * pl.col("rank-tfidf")).alias("rank-combined"),
52
  #pl.col("file").str.strip_prefix(path_prefix).alias("file"),
 
53
  ).sort("rank-combined").with_columns(
54
  (20.0 / pl.col('rank-combined')).round(2).alias('confidence')
55
  ).select(['link', 'confidence'])