GitHub Actions committed on
Commit
ca3f128
·
1 Parent(s): ce302e0

Auto-format code with isort and black

Browse files
benchmarks/retrieval/retrieve.py CHANGED
@@ -33,7 +33,7 @@ def main():
33
  parser.add(
34
  "--logs-dir",
35
  default=None,
36
- help="Path where to output predictions and metrics. Optional, since metrics are also printed to console."
37
  )
38
  parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
39
 
@@ -72,9 +72,7 @@ def main():
72
  # the retrived documents. The key of the score varies depending on the underlying retriever. If there's no
73
  # score, we use 1/(doc_idx+1) since it preserves the order of the documents.
74
  score = doc.metadata.get("score", doc.metadata.get("relevance_score", 1 / (doc_idx + 1)))
75
- retrieved_docs.append(
76
- ScoredDoc(query_id=query_id, doc_id=doc.metadata["file_path"], score=score)
77
- )
78
  # Update the output dictionary with the retrieved documents.
79
  item["retrieved"].append({"file_path": doc.metadata["file_path"], "score": score})
80
 
 
33
  parser.add(
34
  "--logs-dir",
35
  default=None,
36
+ help="Path where to output predictions and metrics. Optional, since metrics are also printed to console.",
37
  )
38
  parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
39
 
 
72
  # the retrived documents. The key of the score varies depending on the underlying retriever. If there's no
73
  # score, we use 1/(doc_idx+1) since it preserves the order of the documents.
74
  score = doc.metadata.get("score", doc.metadata.get("relevance_score", 1 / (doc_idx + 1)))
75
+ retrieved_docs.append(ScoredDoc(query_id=query_id, doc_id=doc.metadata["file_path"], score=score))
 
 
76
  # Update the output dictionary with the retrieved documents.
77
  item["retrieved"].append({"file_path": doc.metadata["file_path"], "score": score})
78
 
benchmarks/retrieval/retrieve_kaggle.py CHANGED
@@ -40,7 +40,9 @@ def main():
40
 
41
  retrieved = retriever.invoke(item["question"])
42
  # Sort by score in descending order.
43
- retrieved = sorted(retrieved, key=lambda doc: doc.metadata.get("score", doc.metadata.get("relevance_score")), reverse=True)
 
 
44
  # Keep top 3, since the Kaggle competition only evaluates the top 3.
45
  retrieved = retrieved[:3]
46
  retrieved_filenames = [doc.metadata["file_path"] for doc in retrieved]
 
40
 
41
  retrieved = retriever.invoke(item["question"])
42
  # Sort by score in descending order.
43
+ retrieved = sorted(
44
+ retrieved, key=lambda doc: doc.metadata.get("score", doc.metadata.get("relevance_score")), reverse=True
45
+ )
46
  # Keep top 3, since the Kaggle competition only evaluates the top 3.
47
  retrieved = retrieved[:3]
48
  retrieved_filenames = [doc.metadata["file_path"] for doc in retrieved]
sage/index.py CHANGED
@@ -42,7 +42,6 @@ def main():
42
  if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
43
  parser.error("When using the marqo embedder, the vector store type must also be marqo.")
44
 
45
-
46
  ######################
47
  # Step 1: Embeddings #
48
  ######################
 
42
  if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
43
  parser.error("When using the marqo embedder, the vector store type must also be marqo.")
44
 
 
45
  ######################
46
  # Step 1: Embeddings #
47
  ######################
sage/vector_store.py CHANGED
@@ -1,13 +1,13 @@
1
  """Vector store abstraction and implementations."""
2
 
3
  import logging
4
- import nltk
5
  import os
6
  from abc import ABC, abstractmethod
7
  from functools import cached_property
8
  from typing import Dict, Generator, List, Optional, Tuple
9
 
10
  import marqo
 
11
  from langchain_community.retrievers import PineconeHybridSearchRetriever
12
  from langchain_community.vectorstores import Marqo
13
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
@@ -22,13 +22,15 @@ from sage.data_manager import DataManager
22
 
23
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
24
 
 
25
  def is_punkt_downloaded():
26
  try:
27
- find('tokenizers/punkt_tab')
28
  return True
29
  except LookupError:
30
  return False
31
 
 
32
  class VectorStore(ABC):
33
  """Abstract class for a vector store."""
34
 
@@ -83,7 +85,7 @@ class PineconeVectorStore(VectorStore):
83
  else:
84
  print("punkt is not downloaded")
85
  # Optionally download it
86
- nltk.download('punkt_tab')
87
  self.bm25_encoder = BM25Encoder()
88
  self.bm25_encoder.load(path=bm25_cache)
89
  else:
 
1
  """Vector store abstraction and implementations."""
2
 
3
  import logging
 
4
  import os
5
  from abc import ABC, abstractmethod
6
  from functools import cached_property
7
  from typing import Dict, Generator, List, Optional, Tuple
8
 
9
  import marqo
10
+ import nltk
11
  from langchain_community.retrievers import PineconeHybridSearchRetriever
12
  from langchain_community.vectorstores import Marqo
13
  from langchain_community.vectorstores import Pinecone as LangChainPinecone
 
22
 
23
  Vector = Tuple[Dict, List[float]] # (metadata, embedding)
24
 
25
+
26
  def is_punkt_downloaded():
27
  try:
28
+ find("tokenizers/punkt_tab")
29
  return True
30
  except LookupError:
31
  return False
32
 
33
+
34
  class VectorStore(ABC):
35
  """Abstract class for a vector store."""
36
 
 
85
  else:
86
  print("punkt is not downloaded")
87
  # Optionally download it
88
+ nltk.download("punkt_tab")
89
  self.bm25_encoder = BM25Encoder()
90
  self.bm25_encoder.load(path=bm25_cache)
91
  else: