Spaces:
Sleeping
Sleeping
Commit ·
85aece1
1
Parent(s): f722db4
Refactor start token handling in compute_entropy function
Browse files
app.py
CHANGED
|
@@ -46,11 +46,16 @@ def compute_entropy(code: str):
|
|
| 46 |
if attention_mask is not None:
|
| 47 |
attention_mask = attention_mask.to(device)
|
| 48 |
|
| 49 |
-
# Prepend
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
if attention_mask is not None:
|
| 55 |
attention_mask = torch.cat(
|
| 56 |
[torch.ones(1, 1, dtype=attention_mask.dtype, device=device), attention_mask], dim=1
|
|
|
|
| 46 |
if attention_mask is not None:
|
| 47 |
attention_mask = attention_mask.to(device)
|
| 48 |
|
| 49 |
+
# Prepend a start token so the first real token gets a predicted probability.
|
| 50 |
+
# Prefer bos_token_id, fall back to eos/pad (pad is always set above).
|
| 51 |
+
start_id = next(
|
| 52 |
+
(getattr(TOKENIZER, a) for a in ("bos_token_id", "eos_token_id", "pad_token_id")
|
| 53 |
+
if getattr(TOKENIZER, a, None) is not None),
|
| 54 |
+
None,
|
| 55 |
+
)
|
| 56 |
+
if start_id is not None and input_ids[0, 0].item() != start_id:
|
| 57 |
+
start_tensor = torch.full((1, 1), start_id, dtype=input_ids.dtype, device=device)
|
| 58 |
+
input_ids = torch.cat([start_tensor, input_ids], dim=1)
|
| 59 |
if attention_mask is not None:
|
| 60 |
attention_mask = torch.cat(
|
| 61 |
[torch.ones(1, 1, dtype=attention_mask.dtype, device=device), attention_mask], dim=1
|