0chanly commited on
Commit
370c4cf
·
verified ·
1 Parent(s): 8408dc0

Update generation parameters to match local chatbot (max_tokens=180, repetition_penalty=1.2, top_p=0.9, top_k=50)

Browse files
Files changed (1) hide show
  1. handler.py +10 -4
handler.py CHANGED
@@ -45,11 +45,14 @@ class EndpointHandler:
45
  inputs = data.pop("inputs", data)
46
  parameters = data.pop("parameters", {})
47
 
48
- # Set default parameters
49
- max_new_tokens = parameters.get("max_new_tokens", 200)
50
  temperature = parameters.get("temperature", 0.7)
51
  do_sample = parameters.get("do_sample", True)
52
- top_p = parameters.get("top_p", 0.95)
 
 
 
53
 
54
  # Tokenize
55
  input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
@@ -58,7 +61,7 @@ class EndpointHandler:
58
  if torch.cuda.is_available():
59
  input_ids = input_ids.cuda()
60
 
61
- # Generate
62
  with torch.no_grad():
63
  outputs = self.model.generate(
64
  input_ids,
@@ -66,6 +69,9 @@ class EndpointHandler:
66
  temperature=temperature,
67
  do_sample=do_sample,
68
  top_p=top_p,
 
 
 
69
  pad_token_id=self.tokenizer.pad_token_id,
70
  eos_token_id=self.tokenizer.eos_token_id
71
  )
 
45
  inputs = data.pop("inputs", data)
46
  parameters = data.pop("parameters", {})
47
 
48
+ # Set default parameters to match local chatbot
49
+ max_new_tokens = parameters.get("max_new_tokens", 180)
50
  temperature = parameters.get("temperature", 0.7)
51
  do_sample = parameters.get("do_sample", True)
52
+ top_p = parameters.get("top_p", 0.9)
53
+ top_k = parameters.get("top_k", 50)
54
+ repetition_penalty = parameters.get("repetition_penalty", 1.2)
55
+ no_repeat_ngram_size = parameters.get("no_repeat_ngram_size", 3)
56
 
57
  # Tokenize
58
  input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
 
61
  if torch.cuda.is_available():
62
  input_ids = input_ids.cuda()
63
 
64
+ # Generate with parameters matching local chatbot
65
  with torch.no_grad():
66
  outputs = self.model.generate(
67
  input_ids,
 
69
  temperature=temperature,
70
  do_sample=do_sample,
71
  top_p=top_p,
72
+ top_k=top_k,
73
+ repetition_penalty=repetition_penalty,
74
+ no_repeat_ngram_size=no_repeat_ngram_size,
75
  pad_token_id=self.tokenizer.pad_token_id,
76
  eos_token_id=self.tokenizer.eos_token_id
77
  )