taylorj94 committed on
Commit
430bd60
·
verified ·
1 Parent(s): 23aac07

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +8 -7
handler.py CHANGED
@@ -2,7 +2,7 @@ from typing import Any, List, Dict
2
  from llama_cpp import Llama
3
  import numpy as np
4
  import torch
5
- from transformers import AutoTokenizer
6
 
7
  class EndpointHandler:
8
  def __init__(self, path=""):
@@ -49,15 +49,16 @@ class EndpointHandler:
49
  if not inputs:
50
  raise ValueError("The 'inputs' field is required.")
51
 
52
- logits_processor = None
53
-
54
  if vocab_list:
55
  # Define allowed tokens dynamically
56
  allowed_token_ids = self.get_allowed_token_ids(vocab_list)
57
 
58
- # Define the logits processor if vocab_list is provided
59
- def logits_processor(input_ids, scores):
60
- return self.filter_allowed_tokens(input_ids, scores, allowed_token_ids)
 
61
 
62
  # Tokenize input
63
  input_ids = torch.tensor([self.tokenizer.encode(inputs, add_special_tokens=False)])
@@ -68,7 +69,7 @@ class EndpointHandler:
68
  {"role": "user", "content": inputs}
69
  ],
70
  max_tokens=parameters.get("max_length", 30),
71
- logits_processor=logits_processor,
72
  temperature=parameters.get("temperature", 1),
73
  repeat_penalty=parameters.get("repeat_penalty", 1.0)
74
  )
 
2
  from llama_cpp import Llama
3
  import numpy as np
4
  import torch
5
+ from transformers import AutoTokenizer, LogitsProcessorList
6
 
7
  class EndpointHandler:
8
  def __init__(self, path=""):
 
49
  if not inputs:
50
  raise ValueError("The 'inputs' field is required.")
51
 
52
+ # Prepare logits processor
53
+ logits_processors = None
54
  if vocab_list:
55
  # Define allowed tokens dynamically
56
  allowed_token_ids = self.get_allowed_token_ids(vocab_list)
57
 
58
+ # Create LogitsProcessorList with filtering function
59
+ logits_processors = LogitsProcessorList([
60
+ lambda input_ids, scores: self.filter_allowed_tokens(input_ids, scores, allowed_token_ids)
61
+ ])
62
 
63
  # Tokenize input
64
  input_ids = torch.tensor([self.tokenizer.encode(inputs, add_special_tokens=False)])
 
69
  {"role": "user", "content": inputs}
70
  ],
71
  max_tokens=parameters.get("max_length", 30),
72
+ logits_processor=logits_processors, # Pass the LogitsProcessorList here
73
  temperature=parameters.get("temperature", 1),
74
  repeat_penalty=parameters.get("repeat_penalty", 1.0)
75
  )