busybisi committed
Commit cfb29ca · verified · 1 parent: fbbbc9e

Upload handler.py with huggingface_hub

Files changed (1): handler.py (+9, -6)
handler.py CHANGED
@@ -52,12 +52,12 @@ class EndpointHandler:
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
 
-        # Default generation parameters
+        # Default generation parameters with safe values
         max_new_tokens = parameters.get("max_new_tokens", 512)
-        temperature = parameters.get("temperature", 0.7)
-        top_p = parameters.get("top_p", 0.9)
+        temperature = max(0.1, min(parameters.get("temperature", 0.8), 2.0))  # Clamp between 0.1 and 2.0
+        top_p = max(0.1, min(parameters.get("top_p", 0.95), 1.0))  # Clamp between 0.1 and 1.0
         do_sample = parameters.get("do_sample", True)
-        repetition_penalty = parameters.get("repetition_penalty", 1.1)
+        repetition_penalty = max(1.0, min(parameters.get("repetition_penalty", 1.05), 2.0))  # Clamp between 1.0 and 2.0
 
         # Tokenize input
         input_ids = self.tokenizer(
@@ -67,17 +67,20 @@ class EndpointHandler:
             max_length=self.model.config.max_position_embeddings - max_new_tokens
         ).input_ids.to(self.model.device)
 
-        # Generate response
+        # Generate response with safe parameters
         with torch.no_grad():
             outputs = self.model.generate(
                 input_ids,
                 max_new_tokens=max_new_tokens,
                 temperature=temperature,
                 top_p=top_p,
+                top_k=50,  # Add top_k for stability
                 do_sample=do_sample,
                 repetition_penalty=repetition_penalty,
-                pad_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
                 eos_token_id=self.tokenizer.eos_token_id,
+                bad_words_ids=None,  # Ensure no bad words restriction causing issues
+                min_length=1,  # Ensure at least 1 token is generated
             )
 
         # Decode output
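For quick verification, the clamping this commit introduces can be exercised in isolation. The sketch below is not part of handler.py: `sanitize_parameters` is a hypothetical helper that mirrors the new defaults and bounds from the diff, so out-of-range request parameters can be checked without loading a model.

```python
# Standalone sketch of the parameter clamping introduced in this commit.
# `sanitize_parameters` is a hypothetical helper, not a function in handler.py;
# the defaults and clamp bounds mirror the values in the diff above.

def sanitize_parameters(parameters: dict) -> dict:
    """Apply the same defaults and clamps as the updated handler."""
    return {
        "max_new_tokens": parameters.get("max_new_tokens", 512),
        # Clamp temperature to [0.1, 2.0]
        "temperature": max(0.1, min(parameters.get("temperature", 0.8), 2.0)),
        # Clamp top_p to [0.1, 1.0]
        "top_p": max(0.1, min(parameters.get("top_p", 0.95), 1.0)),
        "do_sample": parameters.get("do_sample", True),
        # Clamp repetition_penalty to [1.0, 2.0]
        "repetition_penalty": max(1.0, min(parameters.get("repetition_penalty", 1.05), 2.0)),
    }

if __name__ == "__main__":
    # Out-of-range values are pulled back into the safe ranges.
    sanitized = sanitize_parameters(
        {"temperature": 5.0, "top_p": 0.0, "repetition_penalty": 0.5}
    )
    assert sanitized["temperature"] == 2.0
    assert sanitized["top_p"] == 0.1
    assert sanitized["repetition_penalty"] == 1.0
    print(sanitized)
```

Clamping rather than rejecting keeps malformed client values (e.g. a zero temperature, which sampling cannot use) from failing the whole request; the pad_token_id fallback in the diff likewise only substitutes eos_token_id when the tokenizer defines no pad token.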