leejunhyeok committed on
Commit
d63f78d
·
verified ·
1 Parent(s): 4ea8a04

Change the order of the n-gram and ratio checks: run the ratio (budget) check first, then the n-gram repetition check

Browse files
Files changed (1) hide show
  1. logit_processors/logit_.py +11 -11
logit_processors/logit_.py CHANGED
@@ -38,24 +38,24 @@ class ThinkLogitsProcessor:
38
  logits: torch.Tensor
39
  ) -> torch.Tensor:
40
  if self.think_end_token not in past_token_ids:
 
 
 
 
 
 
 
 
 
41
 
42
  # ngram
43
- if len(past_token_ids) % self.interval == 0:
44
  # If repetition detected, force </think>
45
  if self.find_repeated_ngrams(past_token_ids, n=CHUNK_SIZE):
46
  # Set all other logits to -inf except for </think>
47
  logits = torch.full_like(logits, torch.finfo(logits.dtype).min)
48
  logits[self.think_end_token] = 1.0
49
- else:
50
- # ratio
51
- tokens_since_think = len(past_token_ids)
52
-
53
- response_budget = max(self.min_answer_budget, int((self.max_len - len(prompt_token_ids)) * (1-self.ratio)))
54
- remaining_budget = self.max_len - len(prompt_token_ids) - response_budget - tokens_since_think
55
-
56
- if 0 >= remaining_budget:
57
- logits = torch.full_like(logits, torch.finfo(logits.dtype).min)
58
- logits[self.think_end_token] = 1.0
59
  return logits
60
 
61
 
 
38
  logits: torch.Tensor
39
  ) -> torch.Tensor:
40
  if self.think_end_token not in past_token_ids:
41
+ # ratio
42
+ tokens_since_think = len(past_token_ids)
43
+
44
+ response_budget = max(self.min_answer_budget, int((self.max_len - len(prompt_token_ids)) * (1-self.ratio)))
45
+ remaining_budget = self.max_len - len(prompt_token_ids) - response_budget - tokens_since_think
46
+
47
+ if 0 >= remaining_budget:
48
+ logits = torch.full_like(logits, torch.finfo(logits.dtype).min)
49
+ logits[self.think_end_token] = 1.0
50
 
51
  # ngram
52
+ elif len(past_token_ids) % self.interval == 0:
53
  # If repetition detected, force </think>
54
  if self.find_repeated_ngrams(past_token_ids, n=CHUNK_SIZE):
55
  # Set all other logits to -inf except for </think>
56
  logits = torch.full_like(logits, torch.finfo(logits.dtype).min)
57
  logits[self.think_end_token] = 1.0
58
+
 
 
 
 
 
 
 
 
 
59
  return logits
60
 
61