0chanly commited on
Commit
d15fa5d
·
verified ·
1 Parent(s): 370c4cf

Fix handler - remove no_repeat_ngram_size parameter that may not be supported by inference endpoints

Browse files
Files changed (1) hide show
  1. handler.py +6 -5
handler.py CHANGED
@@ -1,5 +1,6 @@
1
  """
2
- Custom handler for Constitutional AI models
 
3
  """
4
 
5
  from typing import Dict, List, Any
@@ -45,14 +46,14 @@ class EndpointHandler:
45
  inputs = data.pop("inputs", data)
46
  parameters = data.pop("parameters", {})
47
 
48
- # Set default parameters to match local chatbot
49
  max_new_tokens = parameters.get("max_new_tokens", 180)
50
  temperature = parameters.get("temperature", 0.7)
51
  do_sample = parameters.get("do_sample", True)
52
  top_p = parameters.get("top_p", 0.9)
53
  top_k = parameters.get("top_k", 50)
54
  repetition_penalty = parameters.get("repetition_penalty", 1.2)
55
- no_repeat_ngram_size = parameters.get("no_repeat_ngram_size", 3)
56
 
57
  # Tokenize
58
  input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
@@ -61,7 +62,7 @@ class EndpointHandler:
61
  if torch.cuda.is_available():
62
  input_ids = input_ids.cuda()
63
 
64
- # Generate with parameters matching local chatbot
65
  with torch.no_grad():
66
  outputs = self.model.generate(
67
  input_ids,
@@ -71,7 +72,7 @@ class EndpointHandler:
71
  top_p=top_p,
72
  top_k=top_k,
73
  repetition_penalty=repetition_penalty,
74
- no_repeat_ngram_size=no_repeat_ngram_size,
75
  pad_token_id=self.tokenizer.pad_token_id,
76
  eos_token_id=self.tokenizer.eos_token_id
77
  )
 
1
  """
2
+ Custom handler for Constitutional AI models - Fixed version
3
+ Removed no_repeat_ngram_size which may not be supported
4
  """
5
 
6
  from typing import Dict, List, Any
 
46
  inputs = data.pop("inputs", data)
47
  parameters = data.pop("parameters", {})
48
 
49
+ # Set default parameters to match local chatbot (without no_repeat_ngram_size)
50
  max_new_tokens = parameters.get("max_new_tokens", 180)
51
  temperature = parameters.get("temperature", 0.7)
52
  do_sample = parameters.get("do_sample", True)
53
  top_p = parameters.get("top_p", 0.9)
54
  top_k = parameters.get("top_k", 50)
55
  repetition_penalty = parameters.get("repetition_penalty", 1.2)
56
+ # REMOVED: no_repeat_ngram_size - may not be supported
57
 
58
  # Tokenize
59
  input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
 
62
  if torch.cuda.is_available():
63
  input_ids = input_ids.cuda()
64
 
65
+ # Generate with parameters matching local chatbot (minus unsupported params)
66
  with torch.no_grad():
67
  outputs = self.model.generate(
68
  input_ids,
 
72
  top_p=top_p,
73
  top_k=top_k,
74
  repetition_penalty=repetition_penalty,
75
+ # REMOVED: no_repeat_ngram_size
76
  pad_token_id=self.tokenizer.pad_token_id,
77
  eos_token_id=self.tokenizer.eos_token_id
78
  )