Julia Turc committed on
Commit 90c4308 · 1 Parent(s): fcecd4b

Adapt to Anthropic's new count_tokens API

Files changed (2)
  1. sage/config.py +4 -1
  2. sage/retriever.py +8 -4
sage/config.py CHANGED
@@ -163,7 +163,7 @@ def add_vector_store_args(parser: ArgumentParser) -> Callable:
     parser.add(
         "--llm-retriever",
         action=argparse.BooleanOptionalAction,
-        default=False,
+        default=True,
         help="When set to True, we use an LLM for retrieval: we pass the repository file hierarchy together with the "
         "user query and ask the LLM to choose relevant files solely based on their paths. No indexing will be done, so "
         "all the vector store / embedding arguments will be ignored.",
@@ -358,6 +358,9 @@ def _validate_gemini_embedding_args(args):
 
 def validate_embedding_args(args):
     """Validates the configuration of the batch embedder and sets defaults."""
+    if args.llm_retriever:
+        # When using an LLM to retrieve, we are not running the embedder.
+        return True
     if args.embedding_provider == "openai":
         _validate_openai_embedding_args(args)
     elif args.embedding_provider == "voyage":
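
For reference, argparse.BooleanOptionalAction auto-generates a paired --no-llm-retriever flag, so flipping the default to True still leaves an explicit opt-out. A minimal standalone sketch of that behavior (the plain-argparse setup here is illustrative; sage wires the flag through its own parser.add):

    import argparse

    # Hypothetical standalone parser mirroring the flag definition above.
    parser = argparse.ArgumentParser(prog="sage")
    parser.add_argument(
        "--llm-retriever",
        action=argparse.BooleanOptionalAction,
        default=True,  # the new default introduced by this commit
    )

    print(parser.parse_args([]).llm_retriever)                      # True (default)
    print(parser.parse_args(["--llm-retriever"]).llm_retriever)     # True (explicit opt-in)
    print(parser.parse_args(["--no-llm-retriever"]).llm_retriever)  # False (opt-out)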
sage/retriever.py CHANGED
@@ -91,22 +91,26 @@ class LLMRetriever(BaseRetriever):
         render = LLMRetriever._render_file_hierarchy(self.repo_metadata, include_classes=True, include_methods=True)
         max_tokens = CLAUDE_MODEL_CONTEXT_SIZE - 50_000  # 50,000 tokens for other parts of the prompt.
         client = anthropic.Anthropic()
-        if client.count_tokens(render) > max_tokens:
+
+        def count_tokens(x):
+            count = client.beta.messages.count_tokens(model=CLAUDE_MODEL, messages=[{"role": "user", "content": x}])
+            return count.input_tokens
+
+        if count_tokens(render) > max_tokens:
             logging.info("File hierarchy is too large; excluding methods.")
             render = LLMRetriever._render_file_hierarchy(
                 self.repo_metadata, include_classes=True, include_methods=False
             )
-        if client.count_tokens(render) > max_tokens:
+        if count_tokens(render) > max_tokens:
             logging.info("File hierarchy is still too large; excluding classes.")
             render = LLMRetriever._render_file_hierarchy(
                 self.repo_metadata, include_classes=False, include_methods=False
             )
-        if client.count_tokens(render) > max_tokens:
+        if count_tokens(render) > max_tokens:
             logging.info("File hierarchy is still too large; truncating.")
             tokenizer = anthropic.Tokenizer()
             tokens = tokenizer.tokenize(render)[:max_tokens]
             render = tokenizer.detokenize(tokens)
-        logging.info("Number of tokens in render hierarchy: %d", client.count_tokens(render))
         self.cached_repo_hierarchy = render
         return self.cached_repo_hierarchy
 
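
Older versions of the anthropic Python SDK exposed a local client.count_tokens(text) helper; newer releases (0.39 and up) removed it in favor of the Messages count-tokens endpoint, which takes a model plus a messages list and returns an object whose input_tokens field holds the count, exactly as the count_tokens helper in the diff above does. A minimal standalone sketch of the new call (model name and prompt are illustrative; assumes ANTHROPIC_API_KEY is set in the environment):

    import anthropic

    client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
    count = client.beta.messages.count_tokens(
        model="claude-3-5-sonnet-20241022",  # illustrative model name
        messages=[{"role": "user", "content": "Hello, Claude"}],
    )
    print(count.input_tokens)  # tokens the prompt would consume as model input

Unlike the old helper, each count is now a network round trip rather than local tokenization, which is presumably why the commit also drops the final logging.info call that re-counted the rendered hierarchy.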