juliaturc commited on
Commit
6600da0
·
1 Parent(s): f2ad04a

Fix Marqo namespace issue. (#45)

Browse files
Files changed (1) hide show
  1. sage/config.py +9 -8
sage/config.py CHANGED
@@ -30,6 +30,7 @@ OPENAI_DEFAULT_EMBEDDING_SIZE = {
30
 
31
  VOYAGE_MAX_CHUNKS_PER_BATCH = 128
32
 
 
33
  def get_voyage_max_tokens_per_batch(model: str) -> int:
34
  """Returns the maximum number of tokens per batch for the Voyage model.
35
  See https://docs.voyageai.com/reference/embeddings-api."""
@@ -39,6 +40,7 @@ def get_voyage_max_tokens_per_batch(model: str) -> int:
39
  return 320_000
40
  return 120_000
41
 
 
42
  def get_voyage_embedding_size(model: str) -> int:
43
  """Returns the embedding size for the Voyage model. See https://docs.voyageai.com/docs/embeddings#model-choices."""
44
  if model == "voyage-3-lite":
@@ -141,10 +143,7 @@ def add_vector_store_args(parser: ArgumentParser) -> Callable:
141
  "encoder in the final retrieval score. A value of 0.0 means BM25 only, 1.0 means embeddings only.",
142
  )
143
  parser.add(
144
- "--retriever-top-k",
145
- default=25,
146
- type=int,
147
- help="The number of top documents to retrieve from the vector store."
148
  )
149
  return validate_vector_store_args
150
 
@@ -274,8 +273,10 @@ def _validate_voyage_embedding_args(args):
274
 
275
  max_tokens = get_voyage_max_tokens_per_batch(args.embedding_model)
276
  if args.tokens_per_chunk * args.chunks_per_batch > max_tokens:
277
- raise ValueError(f"Voyage enforces a limit of {max_tokens} tokens per batch. "
278
- "Reduce either --tokens-per-chunk or --chunks-per-batch.")
 
 
279
 
280
  if not args.embedding_size:
281
  args.embedding_size = get_voyage_embedding_size(args.embedding_model)
@@ -319,8 +320,8 @@ def validate_vector_store_args(args):
319
  if "commit_hash" in args and args.commit_hash:
320
  args.index_namespace += "/" + args.commit_hash
321
  if args.vector_store_provider == "marqo":
322
- # Marqo doesn't allow slashes in the index namespace.
323
- args.index_namespace = args.index_namespace.replace("/", "_")
324
 
325
  if args.vector_store_provider == "marqo":
326
  if not args.marqo_url:
 
30
 
31
  VOYAGE_MAX_CHUNKS_PER_BATCH = 128
32
 
33
+
34
  def get_voyage_max_tokens_per_batch(model: str) -> int:
35
  """Returns the maximum number of tokens per batch for the Voyage model.
36
  See https://docs.voyageai.com/reference/embeddings-api."""
 
40
  return 320_000
41
  return 120_000
42
 
43
+
44
  def get_voyage_embedding_size(model: str) -> int:
45
  """Returns the embedding size for the Voyage model. See https://docs.voyageai.com/docs/embeddings#model-choices."""
46
  if model == "voyage-3-lite":
 
143
  "encoder in the final retrieval score. A value of 0.0 means BM25 only, 1.0 means embeddings only.",
144
  )
145
  parser.add(
146
+ "--retriever-top-k", default=25, type=int, help="The number of top documents to retrieve from the vector store."
 
 
 
147
  )
148
  return validate_vector_store_args
149
 
 
273
 
274
  max_tokens = get_voyage_max_tokens_per_batch(args.embedding_model)
275
  if args.tokens_per_chunk * args.chunks_per_batch > max_tokens:
276
+ raise ValueError(
277
+ f"Voyage enforces a limit of {max_tokens} tokens per batch. "
278
+ "Reduce either --tokens-per-chunk or --chunks-per-batch."
279
+ )
280
 
281
  if not args.embedding_size:
282
  args.embedding_size = get_voyage_embedding_size(args.embedding_model)
 
320
  if "commit_hash" in args and args.commit_hash:
321
  args.index_namespace += "/" + args.commit_hash
322
  if args.vector_store_provider == "marqo":
323
+ # Marqo namespaces must match this pattern: [a-zA-Z_-][a-zA-Z0-9_-]*
324
+ args.index_namespace = re.sub(r"[^a-zA-Z0-9_-]", "_", args.index_namespace)
325
 
326
  if args.vector_store_provider == "marqo":
327
  if not args.marqo_url: