busybisi commited on
Commit
e904f37
·
verified ·
1 Parent(s): cfb29ca

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +8 -13
handler.py CHANGED
@@ -52,12 +52,12 @@ class EndpointHandler:
52
  inputs = data.pop("inputs", data)
53
  parameters = data.pop("parameters", {})
54
 
55
- # Default generation parameters with safe values
56
  max_new_tokens = parameters.get("max_new_tokens", 512)
57
- temperature = max(0.1, min(parameters.get("temperature", 0.8), 2.0)) # Clamp between 0.1 and 2.0
58
- top_p = max(0.1, min(parameters.get("top_p", 0.95), 1.0)) # Clamp between 0.1 and 1.0
59
- do_sample = parameters.get("do_sample", True)
60
- repetition_penalty = max(1.0, min(parameters.get("repetition_penalty", 1.05), 2.0)) # Clamp between 1.0 and 2.0
61
 
62
  # Tokenize input
63
  input_ids = self.tokenizer(
@@ -67,20 +67,15 @@ class EndpointHandler:
67
  max_length=self.model.config.max_position_embeddings - max_new_tokens
68
  ).input_ids.to(self.model.device)
69
 
70
- # Generate response with safe parameters
71
  with torch.no_grad():
72
  outputs = self.model.generate(
73
  input_ids,
74
  max_new_tokens=max_new_tokens,
75
- temperature=temperature,
76
- top_p=top_p,
77
- top_k=50, # Add top_k for stability
78
- do_sample=do_sample,
79
- repetition_penalty=repetition_penalty,
80
  pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
81
  eos_token_id=self.tokenizer.eos_token_id,
82
- bad_words_ids=None, # Ensure no bad words restriction causing issues
83
- min_length=1, # Ensure at least 1 token is generated
84
  )
85
 
86
  # Decode output
 
52
  inputs = data.pop("inputs", data)
53
  parameters = data.pop("parameters", {})
54
 
55
+ # Default generation parameters - use greedy decoding to avoid sampling issues
56
  max_new_tokens = parameters.get("max_new_tokens", 512)
57
+
58
+ # Use greedy decoding (do_sample=False) to avoid probability tensor issues
59
+ # This is more stable for models with potential embedding issues
60
+ do_sample = False # Force greedy decoding
61
 
62
  # Tokenize input
63
  input_ids = self.tokenizer(
 
67
  max_length=self.model.config.max_position_embeddings - max_new_tokens
68
  ).input_ids.to(self.model.device)
69
 
70
+ # Generate response with greedy decoding (no sampling)
71
  with torch.no_grad():
72
  outputs = self.model.generate(
73
  input_ids,
74
  max_new_tokens=max_new_tokens,
75
+ do_sample=False, # Greedy decoding - most stable
76
+ num_beams=1, # No beam search for speed
 
 
 
77
  pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
78
  eos_token_id=self.tokenizer.eos_token_id,
 
 
79
  )
80
 
81
  # Decode output