ejschwartz committed on
Commit
85aece1
·
1 Parent(s): f722db4

Refactor start token handling in compute_entropy function

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -46,11 +46,16 @@ def compute_entropy(code: str):
46
  if attention_mask is not None:
47
  attention_mask = attention_mask.to(device)
48
 
49
- # Prepend BOS if not already present so the first real token gets a predicted probability
50
- bos_id = TOKENIZER.bos_token_id
51
- if bos_id is not None and input_ids[0, 0].item() != bos_id:
52
- bos_tensor = torch.full((1, 1), bos_id, dtype=input_ids.dtype, device=device)
53
- input_ids = torch.cat([bos_tensor, input_ids], dim=1)
 
 
 
 
 
54
  if attention_mask is not None:
55
  attention_mask = torch.cat(
56
  [torch.ones(1, 1, dtype=attention_mask.dtype, device=device), attention_mask], dim=1
 
46
  if attention_mask is not None:
47
  attention_mask = attention_mask.to(device)
48
 
49
+ # Prepend a start token so the first real token gets a predicted probability.
50
+ # Prefer bos_token_id, fall back to eos/pad (pad is always set above).
51
+ start_id = next(
52
+ (getattr(TOKENIZER, a) for a in ("bos_token_id", "eos_token_id", "pad_token_id")
53
+ if getattr(TOKENIZER, a, None) is not None),
54
+ None,
55
+ )
56
+ if start_id is not None and input_ids[0, 0].item() != start_id:
57
+ start_tensor = torch.full((1, 1), start_id, dtype=input_ids.dtype, device=device)
58
+ input_ids = torch.cat([start_tensor, input_ids], dim=1)
59
  if attention_mask is not None:
60
  attention_mask = torch.cat(
61
  [torch.ones(1, 1, dtype=attention_mask.dtype, device=device), attention_mask], dim=1