Spaces:
Running
Running
GitHub Actions committed on
Commit ·
ca3f128
1
Parent(s): ce302e0
Auto-format code with isort and black
Browse files
- benchmarks/retrieval/retrieve.py +2 -4
- benchmarks/retrieval/retrieve_kaggle.py +3 -1
- sage/index.py +0 -1
- sage/vector_store.py +5 -3
benchmarks/retrieval/retrieve.py
CHANGED
|
@@ -33,7 +33,7 @@ def main():
|
|
| 33 |
parser.add(
|
| 34 |
"--logs-dir",
|
| 35 |
default=None,
|
| 36 |
-
help="Path where to output predictions and metrics. Optional, since metrics are also printed to console."
|
| 37 |
)
|
| 38 |
parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
|
| 39 |
|
|
@@ -72,9 +72,7 @@ def main():
|
|
| 72 |
# the retrived documents. The key of the score varies depending on the underlying retriever. If there's no
|
| 73 |
# score, we use 1/(doc_idx+1) since it preserves the order of the documents.
|
| 74 |
score = doc.metadata.get("score", doc.metadata.get("relevance_score", 1 / (doc_idx + 1)))
|
| 75 |
-
retrieved_docs.append(
|
| 76 |
-
ScoredDoc(query_id=query_id, doc_id=doc.metadata["file_path"], score=score)
|
| 77 |
-
)
|
| 78 |
# Update the output dictionary with the retrieved documents.
|
| 79 |
item["retrieved"].append({"file_path": doc.metadata["file_path"], "score": score})
|
| 80 |
|
|
|
|
| 33 |
parser.add(
|
| 34 |
"--logs-dir",
|
| 35 |
default=None,
|
| 36 |
+
help="Path where to output predictions and metrics. Optional, since metrics are also printed to console.",
|
| 37 |
)
|
| 38 |
parser.add("--max-instances", default=None, type=int, help="Maximum number of instances to process.")
|
| 39 |
|
|
|
|
| 72 |
# the retrived documents. The key of the score varies depending on the underlying retriever. If there's no
|
| 73 |
# score, we use 1/(doc_idx+1) since it preserves the order of the documents.
|
| 74 |
score = doc.metadata.get("score", doc.metadata.get("relevance_score", 1 / (doc_idx + 1)))
|
| 75 |
+
retrieved_docs.append(ScoredDoc(query_id=query_id, doc_id=doc.metadata["file_path"], score=score))
|
|
|
|
|
|
|
| 76 |
# Update the output dictionary with the retrieved documents.
|
| 77 |
item["retrieved"].append({"file_path": doc.metadata["file_path"], "score": score})
|
| 78 |
|
benchmarks/retrieval/retrieve_kaggle.py
CHANGED
|
@@ -40,7 +40,9 @@ def main():
|
|
| 40 |
|
| 41 |
retrieved = retriever.invoke(item["question"])
|
| 42 |
# Sort by score in descending order.
|
| 43 |
-
retrieved = sorted(retrieved, key=lambda doc: doc.metadata.get("score", doc.metadata.get("relevance_score")), reverse=True)
|
|
|
|
|
|
|
| 44 |
# Keep top 3, since the Kaggle competition only evaluates the top 3.
|
| 45 |
retrieved = retrieved[:3]
|
| 46 |
retrieved_filenames = [doc.metadata["file_path"] for doc in retrieved]
|
|
|
|
| 40 |
|
| 41 |
retrieved = retriever.invoke(item["question"])
|
| 42 |
# Sort by score in descending order.
|
| 43 |
+
retrieved = sorted(
|
| 44 |
+
retrieved, key=lambda doc: doc.metadata.get("score", doc.metadata.get("relevance_score")), reverse=True
|
| 45 |
+
)
|
| 46 |
# Keep top 3, since the Kaggle competition only evaluates the top 3.
|
| 47 |
retrieved = retrieved[:3]
|
| 48 |
retrieved_filenames = [doc.metadata["file_path"] for doc in retrieved]
|
sage/index.py
CHANGED
|
@@ -42,7 +42,6 @@ def main():
|
|
| 42 |
if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
|
| 43 |
parser.error("When using the marqo embedder, the vector store type must also be marqo.")
|
| 44 |
|
| 45 |
-
|
| 46 |
######################
|
| 47 |
# Step 1: Embeddings #
|
| 48 |
######################
|
|
|
|
| 42 |
if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
|
| 43 |
parser.error("When using the marqo embedder, the vector store type must also be marqo.")
|
| 44 |
|
|
|
|
| 45 |
######################
|
| 46 |
# Step 1: Embeddings #
|
| 47 |
######################
|
sage/vector_store.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
"""Vector store abstraction and implementations."""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
-
import nltk
|
| 5 |
import os
|
| 6 |
from abc import ABC, abstractmethod
|
| 7 |
from functools import cached_property
|
| 8 |
from typing import Dict, Generator, List, Optional, Tuple
|
| 9 |
|
| 10 |
import marqo
|
|
|
|
| 11 |
from langchain_community.retrievers import PineconeHybridSearchRetriever
|
| 12 |
from langchain_community.vectorstores import Marqo
|
| 13 |
from langchain_community.vectorstores import Pinecone as LangChainPinecone
|
|
@@ -22,13 +22,15 @@ from sage.data_manager import DataManager
|
|
| 22 |
|
| 23 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
| 24 |
|
|
|
|
| 25 |
def is_punkt_downloaded():
|
| 26 |
try:
|
| 27 |
-
find('tokenizers/punkt_tab')
|
| 28 |
return True
|
| 29 |
except LookupError:
|
| 30 |
return False
|
| 31 |
|
|
|
|
| 32 |
class VectorStore(ABC):
|
| 33 |
"""Abstract class for a vector store."""
|
| 34 |
|
|
@@ -83,7 +85,7 @@ class PineconeVectorStore(VectorStore):
|
|
| 83 |
else:
|
| 84 |
print("punkt is not downloaded")
|
| 85 |
# Optionally download it
|
| 86 |
-
nltk.download('punkt_tab')
|
| 87 |
self.bm25_encoder = BM25Encoder()
|
| 88 |
self.bm25_encoder.load(path=bm25_cache)
|
| 89 |
else:
|
|
|
|
| 1 |
"""Vector store abstraction and implementations."""
|
| 2 |
|
| 3 |
import logging
|
|
|
|
| 4 |
import os
|
| 5 |
from abc import ABC, abstractmethod
|
| 6 |
from functools import cached_property
|
| 7 |
from typing import Dict, Generator, List, Optional, Tuple
|
| 8 |
|
| 9 |
import marqo
|
| 10 |
+
import nltk
|
| 11 |
from langchain_community.retrievers import PineconeHybridSearchRetriever
|
| 12 |
from langchain_community.vectorstores import Marqo
|
| 13 |
from langchain_community.vectorstores import Pinecone as LangChainPinecone
|
|
|
|
| 22 |
|
| 23 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
| 24 |
|
| 25 |
+
|
| 26 |
def is_punkt_downloaded():
|
| 27 |
try:
|
| 28 |
+
find("tokenizers/punkt_tab")
|
| 29 |
return True
|
| 30 |
except LookupError:
|
| 31 |
return False
|
| 32 |
|
| 33 |
+
|
| 34 |
class VectorStore(ABC):
|
| 35 |
"""Abstract class for a vector store."""
|
| 36 |
|
|
|
|
| 85 |
else:
|
| 86 |
print("punkt is not downloaded")
|
| 87 |
# Optionally download it
|
| 88 |
+
nltk.download("punkt_tab")
|
| 89 |
self.bm25_encoder = BM25Encoder()
|
| 90 |
self.bm25_encoder.load(path=bm25_cache)
|
| 91 |
else:
|