taylorj94 committed on
Commit
b3ebb95
·
verified ·
1 Parent(s): 53c6e74

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -6
handler.py CHANGED
@@ -18,15 +18,18 @@ class EndpointHandler:
18
  def get_allowed_token_ids(self, vocab_list: List[str]) -> set[int]:
19
  """
20
  Generate a set of token IDs for a given list of allowed words.
21
- Includes both plain and space-prefixed forms of each word.
22
  """
23
  allowed_ids = set()
24
  for word in vocab_list:
25
- # Add plain and space-prefixed token IDs
26
- for token_id in self.tokenizer.encode(word, add_special_tokens=False):
27
- allowed_ids.add(token_id)
28
- for token_id in self.tokenizer.encode(" " + word, add_special_tokens=False):
29
- allowed_ids.add(token_id)
 
 
 
30
  return allowed_ids
31
 
32
  def filter_allowed_tokens(self, input_ids: torch.Tensor, scores: np.ndarray, allowed_token_ids: set[int]) -> np.ndarray:
 
18
  def get_allowed_token_ids(self, vocab_list: List[str]) -> set[int]:
19
  """
20
  Generate a set of token IDs for a given list of allowed words.
21
+ Includes plain, space-prefixed, capitalized, and uppercase forms of each word.
22
  """
23
  allowed_ids = set()
24
  for word in vocab_list:
25
+ # Generate all variations: plain, space-prefixed, capitalized, and uppercase
26
+ variations = {word, " " + word, word.capitalize(), " " + word.capitalize(), word.upper(), " " + word.upper()}
27
+
28
+ # Add token IDs for all variations
29
+ for variation in variations:
30
+ for token_id in self.tokenizer.encode(variation, add_special_tokens=False):
31
+ allowed_ids.add(token_id)
32
+
33
  return allowed_ids
34
 
35
  def filter_allowed_tokens(self, input_ids: torch.Tensor, scores: np.ndarray, allowed_token_ids: set[int]) -> np.ndarray: