Spaces:
Running
Running
Commit ·
1d7752e
1
Parent(s): b52538e
Solutions and notebook
Browse files- forbidden_solution.py +142 -0
- solution.py +80 -54
- tokenizers_and_decoding.ipynb +984 -0
forbidden_solution.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import LogitsProcessor, AutoModelForCausalLM, AutoTokenizer
|
| 3 |
+
|
| 4 |
+
# --- EXERCISE 1: La disparition (No 'e' or 'E) ---
|
| 5 |
+
|
| 6 |
+
# --- Logits Processor to forbid specific tokens ---
|
| 7 |
+
class ForbidTokensLogitsProcessor(LogitsProcessor):
    """Mask out a fixed set of token ids at every decoding step.

    Each forbidden token's logit is overwritten with -inf so it can never be
    selected, regardless of the decoding strategy (greedy, beams, sampling).
    """

    def __init__(self, forbidden_token_ids):
        # Materialise the ids once; the same set is applied at every step.
        self.forbidden_token_ids = list(forbidden_token_ids)

    def __call__(self, input_ids, scores):
        # scores has shape (batch, vocab_size); ban the ids across all rows.
        banned = self.forbidden_token_ids
        scores[:, banned] = float('-inf')
        return scores
|
| 15 |
+
|
| 16 |
+
class LaDisparition:
    """Generate text that never contains the letter 'e' or 'E', via model.generate()."""

    def __init__(self, model, tokenizer, debug=False):
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug

        def _is_forbidden(token_id):
            # Ban tokens whose decoded text contains 'e'/'E', plus any token
            # with non-ASCII characters (accented letters such as 'é' could
            # hide an 'e' after normalisation).
            text = tokenizer.decode([token_id])
            return 'e' in text.lower() or any(ord(ch) >= 128 for ch in text)

        # Pre-compute the banned ids once over the whole vocabulary.
        self.forbidden_token_ids = {
            tid for tid in range(len(tokenizer.get_vocab())) if _is_forbidden(tid)
        }
        self.processor = ForbidTokensLogitsProcessor(self.forbidden_token_ids)

    def __call__(self, prompt, max_tokens=30, beam_width=5):
        # Option 1: we use self.tokenizer to tokenize the prompt (no chat template).
        encoded = self.tokenizer(
            prompt, return_tensors="pt", return_attention_mask=True
        ).to(self.model.device)

        # Deterministic beam search with the banned ids masked at each step.
        generated = self.model.generate(
            **encoded,
            max_new_tokens=max_tokens,
            num_beams=beam_width,
            logits_processor=[self.processor],
            do_sample=False
        )
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# --- EXERCISE 2: The Toulouse Sequence ---
|
| 48 |
+
|
| 49 |
+
class ForbidToulousePrefixLogitsProcessor(LogitsProcessor):
    """
    Prevent the word "Toulouse" (case-insensitive) from being generated.

    At each decoding step we look at the word currently being written — the
    suffix of the decoded text since the last non-alphabetical character —
    and mask every candidate token that would either:
      * turn that word into a prefix of "toulouse" of length >= 4, or
      * spell out (or overshoot) the complete word "toulouse" in one step.
    The second check is required because a single multi-character token can
    jump from a short, still-allowed prefix (e.g. "Tou") straight past the
    length-4 barrier (e.g. to "Toulouses"); such a word is not a prefix of
    "toulouse", so the prefix test alone would let the forbidden word through.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.forbidden_word = "toulouse"  # compared case-insensitively
        self.min_prefix_len = 4           # prefixes shorter than this remain legal

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Process every row of the batch (beam search / batched generation),
        # not just row 0 — otherwise other beams are left unconstrained.
        for batch_idx in range(input_ids.shape[0]):
            decoded_sequence = self.tokenizer.decode(input_ids[batch_idx])

            # The "current word" is the suffix after the last non-alpha char
            # (the whole string when no separator exists).
            last_separator_idx = -1
            for i in range(len(decoded_sequence) - 1, -1, -1):
                if not decoded_sequence[i].isalpha():
                    last_separator_idx = i
                    break
            current_word_prefix = decoded_sequence[last_separator_idx + 1:]

            # No word in progress -> no candidate can start breaking the rule yet.
            if not current_word_prefix:
                continue

            # Re-encode only the word in progress, not the whole sequence.
            current_word_ids = self.tokenizer.encode(current_word_prefix, add_special_tokens=False)

            # Test every candidate next token.
            for token_id in range(scores.shape[1]):
                hypothetical = self.tokenizer.decode(current_word_ids + [token_id]).lower()
                is_long_prefix = (len(hypothetical) >= self.min_prefix_len
                                  and self.forbidden_word.startswith(hypothetical))
                completes_word = hypothetical.startswith(self.forbidden_word)
                if is_long_prefix or completes_word:
                    scores[batch_idx, token_id] = float('-inf')

        return scores
|
| 100 |
+
|
| 101 |
+
class ToulouseSequence:
    """Generate text that never contains the word 'Toulouse', via model.generate()."""

    def __init__(self, model, tokenizer, debug=False):
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug
        # Per-step masking of any token that would spell out "Toulouse".
        self.processor = ForbidToulousePrefixLogitsProcessor(self.tokenizer)

    def __call__(self, prompt, max_tokens=100):
        # Option 2: we use self.tokenizer.apply_chat_template to tokenize the prompt
        chat = [{"role": "user", "content": prompt}]
        input_ids = self.tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        n_prompt_tokens = input_ids.shape[1]

        # No padding is involved, so the attention mask is simply all ones.
        result = self.model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_tokens,
            logits_processor=[self.processor],
            do_sample=False
        )

        # Drop the prompt tokens and decode only the continuation.
        continuation = result[0][n_prompt_tokens:]
        return self.tokenizer.decode(continuation, skip_special_tokens=True).strip()
|
| 132 |
+
|
| 133 |
+
if __name__ == "__main__":
    # NOTE: This block is for testing only. The evaluation server provides model and tokenizer.
    # Local smoke test: load a small chat model and run both exercises once.
    checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype=torch.float32, device_map="auto")

    print("Ex 1 (No 'e'):", LaDisparition(model, tokenizer)("Who are you?"))
    print("Ex 2 (No 'Toulouse'):", ToulouseSequence(model, tokenizer)("Where is Toulouse?"))
|
solution.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
@@ -11,7 +11,7 @@ class LaDisparition:
|
|
| 11 |
You need to manually adjust the logits to forbid tokens containing 'e' or 'E'.
|
| 12 |
REQUIREMENT: Do NOT use model.generate().
|
| 13 |
"""
|
| 14 |
-
def __init__(self, model, tokenizer, debug=False):
|
| 15 |
self.model = model
|
| 16 |
self.tokenizer = tokenizer
|
| 17 |
self.debug = debug
|
|
@@ -22,15 +22,12 @@ class LaDisparition:
|
|
| 22 |
for token_id in range(len(vocab)):
|
| 23 |
# Decode the token to see what it actually produces
|
| 24 |
decoded = self.tokenizer.decode([token_id])
|
| 25 |
-
# Forbid if contains 'e'/'E' or contains non-ASCII (which might hide 'e'
|
| 26 |
if 'e' in decoded.lower() or not all(ord(c) < 128 for c in decoded):
|
| 27 |
-
self.forbidden_token_ids.add(token_id)
|
| 28 |
-
|
| 29 |
-
# Warning: The evaluation server uses a different model and tokenizer than the template. Do not hard-code Token IDs. Use self.tokenizer.get_vocab() or self.tokenizer.encode() to find the IDs relevant to the current model.
|
| 30 |
-
|
| 31 |
|
| 32 |
-
def __call__(self, prompt, max_tokens=
|
| 33 |
-
#
|
| 34 |
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
|
| 35 |
input_ids = inputs["input_ids"]
|
| 36 |
prompt_len = input_ids.shape[1]
|
|
@@ -53,10 +50,7 @@ class LaDisparition:
|
|
| 53 |
# Create mask for forbidden tokens
|
| 54 |
forbidden_mask = torch.zeros_like(logits, dtype=torch.bool)
|
| 55 |
forbidden_mask[list(self.forbidden_token_ids)] = True
|
| 56 |
-
|
| 57 |
-
# Set forbidden tokens to a very negative value (safe for float16)
|
| 58 |
-
logits[forbidden_mask] = torch.finfo(logits.dtype).min / 2
|
| 59 |
-
|
| 60 |
# Convert to log probabilities
|
| 61 |
log_probs = F.log_softmax(logits, dim=-1)
|
| 62 |
|
|
@@ -73,7 +67,6 @@ class LaDisparition:
|
|
| 73 |
|
| 74 |
for token_id, token_log_prob in zip(top_indices.tolist(), top_log_probs.tolist()):
|
| 75 |
if token_id == self.tokenizer.eos_token_id:
|
| 76 |
-
# Add as candidate with bonus for finishing
|
| 77 |
candidates.append((seq, log_prob + token_log_prob))
|
| 78 |
else:
|
| 79 |
candidates.append((seq + [token_id], log_prob + token_log_prob))
|
|
@@ -103,38 +96,69 @@ class ToulouseSequence:
|
|
| 103 |
"""
|
| 104 |
Generate text without ever using the word 'Toulouse'.
|
| 105 |
For this, you must use model() directly: model(input_ids) yields logits.
|
| 106 |
-
|
| 107 |
REQUIREMENT: Do NOT use model.generate().
|
| 108 |
"""
|
| 109 |
def __init__(self, model, tokenizer, debug=False):
|
| 110 |
self.model = model
|
| 111 |
self.tokenizer = tokenizer
|
| 112 |
self.debug = debug
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
def __call__(self, prompt, max_tokens=20):
|
| 131 |
-
#
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
prompt_length =
|
| 135 |
|
| 136 |
-
# Generate tokens one by one
|
| 137 |
-
seq =
|
| 138 |
|
| 139 |
for step in range(max_tokens):
|
| 140 |
input_tensor = torch.tensor([seq], device=self.model.device)
|
|
@@ -144,16 +168,19 @@ class ToulouseSequence:
|
|
| 144 |
outputs = self.model(input_tensor)
|
| 145 |
logits = outputs.logits[0, -1, :].clone()
|
| 146 |
|
| 147 |
-
#
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
| 150 |
|
| 151 |
-
#
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
#
|
| 155 |
-
|
| 156 |
-
|
|
|
|
| 157 |
seq.append(next_token)
|
| 158 |
|
| 159 |
# Extract only the generated tokens (skip the input prompt tokens)
|
|
@@ -161,14 +188,13 @@ class ToulouseSequence:
|
|
| 161 |
generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 162 |
return generated_text.strip()
|
| 163 |
|
| 164 |
-
|
| 165 |
# NOTE: This block is for testing only. The evaluation server provides model and tokenizer.
|
| 166 |
-
#
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
# print("Ex 2 (No 'Toulouse'):", toulouse_sequence_generator("Where is the headquarters of Airbus located?"))
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
| 11 |
You need to manually adjust the logits to forbid tokens containing 'e' or 'E'.
|
| 12 |
REQUIREMENT: Do NOT use model.generate().
|
| 13 |
"""
|
| 14 |
+
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, debug: bool = False):
|
| 15 |
self.model = model
|
| 16 |
self.tokenizer = tokenizer
|
| 17 |
self.debug = debug
|
|
|
|
| 22 |
for token_id in range(len(vocab)):
|
| 23 |
# Decode the token to see what it actually produces
|
| 24 |
decoded = self.tokenizer.decode([token_id])
|
| 25 |
+
# Forbid if contains 'e'/'E' or contains non-ASCII (which might hide 'e')
|
| 26 |
if 'e' in decoded.lower() or not all(ord(c) < 128 for c in decoded):
|
| 27 |
+
self.forbidden_token_ids.add(token_id)
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
def __call__(self, prompt, max_tokens=20, beam_width=5):
|
| 30 |
+
# Option 1: we use self.tokenizer to tokenize the prompt
|
| 31 |
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
|
| 32 |
input_ids = inputs["input_ids"]
|
| 33 |
prompt_len = input_ids.shape[1]
|
|
|
|
| 50 |
# Create mask for forbidden tokens
|
| 51 |
forbidden_mask = torch.zeros_like(logits, dtype=torch.bool)
|
| 52 |
forbidden_mask[list(self.forbidden_token_ids)] = True
|
| 53 |
+
|
|
|
|
|
|
|
|
|
|
| 54 |
# Convert to log probabilities
|
| 55 |
log_probs = F.log_softmax(logits, dim=-1)
|
| 56 |
|
|
|
|
| 67 |
|
| 68 |
for token_id, token_log_prob in zip(top_indices.tolist(), top_log_probs.tolist()):
|
| 69 |
if token_id == self.tokenizer.eos_token_id:
|
|
|
|
| 70 |
candidates.append((seq, log_prob + token_log_prob))
|
| 71 |
else:
|
| 72 |
candidates.append((seq + [token_id], log_prob + token_log_prob))
|
|
|
|
| 96 |
"""
|
| 97 |
Generate text without ever using the word 'Toulouse'.
|
| 98 |
For this, you must use model() directly: model(input_ids) yields logits.
|
| 99 |
+
We mask out all tokens that if added would lead to a prefix of "Toulouse" of length at least 4.
|
| 100 |
REQUIREMENT: Do NOT use model.generate().
|
| 101 |
"""
|
| 102 |
def __init__(self, model, tokenizer, debug=False):
|
| 103 |
self.model = model
|
| 104 |
self.tokenizer = tokenizer
|
| 105 |
self.debug = debug
|
| 106 |
+
self.forbidden_word = "Toulouse"
|
| 107 |
+
self.min_prefix_len = 4
|
| 108 |
+
|
| 109 |
+
def _get_current_word_prefix(self, decoded_sequence: str) -> str:
|
| 110 |
+
"""Find the suffix since the last non-alphabetical character."""
|
| 111 |
+
last_separator_idx = -1
|
| 112 |
+
for i in range(len(decoded_sequence) - 1, -1, -1):
|
| 113 |
+
if not decoded_sequence[i].isalpha():
|
| 114 |
+
last_separator_idx = i
|
| 115 |
+
break
|
| 116 |
|
| 117 |
+
if last_separator_idx != -1:
|
| 118 |
+
return decoded_sequence[last_separator_idx + 1:]
|
| 119 |
+
else:
|
| 120 |
+
return decoded_sequence
|
| 121 |
+
|
| 122 |
+
def _get_forbidden_mask(self, seq: List[int]) -> torch.Tensor:
    """
    Create a mask for tokens that would create a forbidden prefix of 'Toulouse'.

    A candidate token is forbidden when appending it to the word currently
    being generated would either:
      * produce a prefix of 'Toulouse' of length >= self.min_prefix_len, or
      * spell out (or overshoot) the complete word in one step — a single
        multi-character token can jump from a short allowed prefix (e.g.
        "Tou") straight past the length barrier (e.g. to "Toulouses"), which
        the prefix test alone would miss.
    Returns a boolean tensor where True means the token should be forbidden.
    """
    vocab_size = len(self.tokenizer.get_vocab())
    forbidden_mask = torch.zeros(vocab_size, dtype=torch.bool, device=self.model.device)

    # Decode the current sequence and isolate the word in progress.
    decoded_sequence = self.tokenizer.decode(seq)
    current_word_prefix = self._get_current_word_prefix(decoded_sequence)

    # No word in progress -> nothing can become 'Toulouse' with one more token.
    if not current_word_prefix:
        return forbidden_mask

    current_word_ids = self.tokenizer.encode(current_word_prefix, add_special_tokens=False)
    target = self.forbidden_word.lower()

    # Test every candidate next token (case-insensitively).
    for token_id in range(vocab_size):
        hypothetical_word = self.tokenizer.decode(current_word_ids + [token_id]).lower()
        is_long_prefix = (len(hypothetical_word) >= self.min_prefix_len
                          and target.startswith(hypothetical_word))
        completes_word = hypothetical_word.startswith(target)
        if is_long_prefix or completes_word:
            forbidden_mask[token_id] = True

    return forbidden_mask
|
| 153 |
|
| 154 |
def __call__(self, prompt, max_tokens=20):
|
| 155 |
+
# Option 2: we use self.tokenizer.apply_chat_template to tokenize the prompt
|
| 156 |
+
message = [{"role": "user", "content": prompt}]
|
| 157 |
+
inputs = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt").to(self.model.device)
|
| 158 |
+
prompt_length = inputs.shape[1]
|
| 159 |
|
| 160 |
+
# Generate tokens one by one
|
| 161 |
+
seq = inputs[0].tolist()
|
| 162 |
|
| 163 |
for step in range(max_tokens):
|
| 164 |
input_tensor = torch.tensor([seq], device=self.model.device)
|
|
|
|
| 168 |
outputs = self.model(input_tensor)
|
| 169 |
logits = outputs.logits[0, -1, :].clone()
|
| 170 |
|
| 171 |
+
# Get forbidden mask based on current word prefix
|
| 172 |
+
forbidden_mask = self._get_forbidden_mask(seq)
|
| 173 |
+
|
| 174 |
+
# Apply the mask: set forbidden tokens to -inf
|
| 175 |
+
logits[forbidden_mask] = float('-inf')
|
| 176 |
|
| 177 |
+
# Greedy decoding
|
| 178 |
+
next_token = torch.argmax(logits).item()
|
| 179 |
+
|
| 180 |
+
# Stop if EOS token
|
| 181 |
+
if next_token == self.tokenizer.eos_token_id:
|
| 182 |
+
break
|
| 183 |
+
|
| 184 |
seq.append(next_token)
|
| 185 |
|
| 186 |
# Extract only the generated tokens (skip the input prompt tokens)
|
|
|
|
| 188 |
generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 189 |
return generated_text.strip()
|
| 190 |
|
| 191 |
+
if __name__ == "__main__":
|
| 192 |
# NOTE: This block is for testing only. The evaluation server provides model and tokenizer.
|
| 193 |
+
# SETUP
|
| 194 |
+
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
| 195 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 196 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
|
| 197 |
+
la_disparition_generator = LaDisparition(model, tokenizer)
|
| 198 |
+
print("Ex 1 (No 'e'):", la_disparition_generator("Who are you?"))
|
| 199 |
+
toulouse_sequence_generator = ToulouseSequence(model, tokenizer, debug=True)
|
| 200 |
+
print("Ex 2 (No 'Toulouse'):", toulouse_sequence_generator("Where is Toulouse?"))
|
|
|
tokenizers_and_decoding.ipynb
ADDED
|
@@ -0,0 +1,984 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "5815a5fe",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"# Install dependencies if needed\n",
|
| 11 |
+
"# !pip install transformers torch"
|
| 12 |
+
]
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"cell_type": "code",
|
| 16 |
+
"execution_count": 2,
|
| 17 |
+
"id": "e4246fec",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"outputs": [],
|
| 20 |
+
"source": [
|
| 21 |
+
"import torch\n",
|
| 22 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM "
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "markdown",
|
| 27 |
+
"id": "0730b798",
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"source": [
|
| 30 |
+
"---\n",
|
| 31 |
+
"# Part 1: Tokenizers\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"Tokenizers convert text into numerical representations (token IDs) that models can process."
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "code",
|
| 38 |
+
"execution_count": 3,
|
| 39 |
+
"id": "ba944d7f",
|
| 40 |
+
"metadata": {},
|
| 41 |
+
"outputs": [
|
| 42 |
+
{
|
| 43 |
+
"name": "stdout",
|
| 44 |
+
"output_type": "stream",
|
| 45 |
+
"text": [
|
| 46 |
+
"Vocabulary size: 32000\n"
|
| 47 |
+
]
|
| 48 |
+
}
|
| 49 |
+
],
|
| 50 |
+
"source": [
|
| 51 |
+
"# Load a tokenizer\n",
|
| 52 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\")\n",
|
| 53 |
+
"print(f\"Vocabulary size: {len(tokenizer)}\")"
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"cell_type": "markdown",
|
| 58 |
+
"id": "9a3e5f5a",
|
| 59 |
+
"metadata": {},
|
| 60 |
+
"source": [
|
| 61 |
+
"## 1.1 Basic Encoding & Decoding\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"- `tokenizer.encode()` converts text → token IDs\n",
|
| 64 |
+
"- `tokenizer.decode()` converts token IDs → text"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "code",
|
| 69 |
+
"execution_count": 4,
|
| 70 |
+
"id": "8d9e1208",
|
| 71 |
+
"metadata": {},
|
| 72 |
+
"outputs": [
|
| 73 |
+
{
|
| 74 |
+
"name": "stdout",
|
| 75 |
+
"output_type": "stream",
|
| 76 |
+
"text": [
|
| 77 |
+
"Text: 'Hello, how are you?'\n",
|
| 78 |
+
"Token IDs: [1, 15043, 29892, 920, 526, 366, 29973]\n",
|
| 79 |
+
"Number of tokens: 7\n"
|
| 80 |
+
]
|
| 81 |
+
}
|
| 82 |
+
],
|
| 83 |
+
"source": [
|
| 84 |
+
"text = \"Hello, how are you?\"\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"# Encode: text -> token IDs\n",
|
| 87 |
+
"token_ids = tokenizer.encode(text)\n",
|
| 88 |
+
"print(f\"Text: '{text}'\")\n",
|
| 89 |
+
"print(f\"Token IDs: {token_ids}\")\n",
|
| 90 |
+
"print(f\"Number of tokens: {len(token_ids)}\")"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": 5,
|
| 96 |
+
"id": "506574b3",
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"outputs": [
|
| 99 |
+
{
|
| 100 |
+
"name": "stdout",
|
| 101 |
+
"output_type": "stream",
|
| 102 |
+
"text": [
|
| 103 |
+
"Decoded: '<s> Hello, how are you?'\n",
|
| 104 |
+
"Decoded (no special tokens): 'Hello, how are you?'\n"
|
| 105 |
+
]
|
| 106 |
+
}
|
| 107 |
+
],
|
| 108 |
+
"source": [
|
| 109 |
+
"# Decode: token IDs -> text\n",
|
| 110 |
+
"decoded_text = tokenizer.decode(token_ids)\n",
|
| 111 |
+
"print(f\"Decoded: '{decoded_text}'\")\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"# Skip special tokens (like <s>, </s>)\n",
|
| 114 |
+
"decoded_clean = tokenizer.decode(token_ids, skip_special_tokens=True)\n",
|
| 115 |
+
"print(f\"Decoded (no special tokens): '{decoded_clean}'\")"
|
| 116 |
+
]
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"cell_type": "code",
|
| 120 |
+
"execution_count": 6,
|
| 121 |
+
"id": "278fd636",
|
| 122 |
+
"metadata": {},
|
| 123 |
+
"outputs": [
|
| 124 |
+
{
|
| 125 |
+
"name": "stdout",
|
| 126 |
+
"output_type": "stream",
|
| 127 |
+
"text": [
|
| 128 |
+
" 1 -> '<s>'\n",
|
| 129 |
+
" 15043 -> '▁Hello'\n",
|
| 130 |
+
" 29892 -> ','\n",
|
| 131 |
+
" 920 -> '▁how'\n",
|
| 132 |
+
" 526 -> '▁are'\n",
|
| 133 |
+
" 366 -> '▁you'\n",
|
| 134 |
+
" 29973 -> '?'\n"
|
| 135 |
+
]
|
| 136 |
+
}
|
| 137 |
+
],
|
| 138 |
+
"source": [
|
| 139 |
+
"# Look at individual tokens\n",
|
| 140 |
+
"tokens = tokenizer.convert_ids_to_tokens(token_ids)\n",
|
| 141 |
+
"for tid, tok in zip(token_ids, tokens):\n",
|
| 142 |
+
" print(f\" {tid:5d} -> '{tok}'\")"
|
| 143 |
+
]
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"cell_type": "markdown",
|
| 147 |
+
"id": "aaafadbc",
|
| 148 |
+
"metadata": {},
|
| 149 |
+
"source": [
|
| 150 |
+
"### Key insight: Subword tokenization\n",
|
| 151 |
+
"\n",
|
| 152 |
+
"Words are split into subwords. Common words stay whole, rare words are broken into pieces."
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"cell_type": "code",
|
| 157 |
+
"execution_count": 7,
|
| 158 |
+
"id": "4a7a590d",
|
| 159 |
+
"metadata": {},
|
| 160 |
+
"outputs": [
|
| 161 |
+
{
|
| 162 |
+
"name": "stdout",
|
| 163 |
+
"output_type": "stream",
|
| 164 |
+
"text": [
|
| 165 |
+
"'cat' -> 1 token(s): ['▁cat']\n",
|
| 166 |
+
"'running' -> 1 token(s): ['▁running']\n",
|
| 167 |
+
"'internationalization' -> 2 token(s): ['▁international', 'ization']\n",
|
| 168 |
+
"'TinyLlama' -> 5 token(s): ['▁T', 'iny', 'L', 'l', 'ama']\n"
|
| 169 |
+
]
|
| 170 |
+
}
|
| 171 |
+
],
|
| 172 |
+
"source": [
|
| 173 |
+
"# Compare tokenization of common vs rare words\n",
|
| 174 |
+
"words = [\"cat\", \"running\", \"internationalization\", \"TinyLlama\"]\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"for word in words:\n",
|
| 177 |
+
" ids = tokenizer.encode(word, add_special_tokens=False)\n",
|
| 178 |
+
" tokens = tokenizer.convert_ids_to_tokens(ids)\n",
|
| 179 |
+
" print(f\"'{word}' -> {len(ids)} token(s): {tokens}\")"
|
| 180 |
+
]
|
| 181 |
+
},
|
| 182 |
+
{
|
| 183 |
+
"cell_type": "markdown",
|
| 184 |
+
"id": "7376eda8",
|
| 185 |
+
"metadata": {},
|
| 186 |
+
"source": [
|
| 187 |
+
"## 1.2 Batching with Padding\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"When processing multiple texts, they need the same length. Padding adds special tokens to shorter sequences."
|
| 190 |
+
]
|
| 191 |
+
},
|
| 192 |
+
{
|
| 193 |
+
"cell_type": "code",
|
| 194 |
+
"execution_count": 8,
|
| 195 |
+
"id": "a4086c5d",
|
| 196 |
+
"metadata": {},
|
| 197 |
+
"outputs": [
|
| 198 |
+
{
|
| 199 |
+
"name": "stdout",
|
| 200 |
+
"output_type": "stream",
|
| 201 |
+
"text": [
|
| 202 |
+
"'Hello!' -> 3 tokens\n",
|
| 203 |
+
"'How are you doing today?' -> 7 tokens\n",
|
| 204 |
+
"'I am fine.' -> 5 tokens\n"
|
| 205 |
+
]
|
| 206 |
+
}
|
| 207 |
+
],
|
| 208 |
+
"source": [
|
| 209 |
+
"# Multiple texts of different lengths\n",
|
| 210 |
+
"texts = [\n",
|
| 211 |
+
" \"Hello!\",\n",
|
| 212 |
+
" \"How are you doing today?\",\n",
|
| 213 |
+
" \"I am fine.\"\n",
|
| 214 |
+
"]\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"# Without padding - different lengths\n",
|
| 217 |
+
"for text in texts:\n",
|
| 218 |
+
" ids = tokenizer.encode(text)\n",
|
| 219 |
+
" print(f\"'{text}' -> {len(ids)} tokens\")"
|
| 220 |
+
]
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"cell_type": "code",
|
| 224 |
+
"execution_count": null,
|
| 225 |
+
"id": "247725b1",
|
| 226 |
+
"metadata": {},
|
| 227 |
+
"outputs": [
|
| 228 |
+
{
|
| 229 |
+
"name": "stdout",
|
| 230 |
+
"output_type": "stream",
|
| 231 |
+
"text": [
|
| 232 |
+
"Input IDs shape: torch.Size([3, 7])\n",
|
| 233 |
+
"\n",
|
| 234 |
+
"Padded sequences:\n",
|
| 235 |
+
" 0: [1, 15043, 29991, 2, 2, 2, 2]\n",
|
| 236 |
+
" 1: [1, 1128, 526, 366, 2599, 9826, 29973]\n",
|
| 237 |
+
" 2: [1, 306, 626, 2691, 29889, 2, 2]\n",
|
| 238 |
+
"\n",
|
| 239 |
+
"Attention mask (1=real token, 0=padding):\n",
|
| 240 |
+
" 0: [1, 1, 1, 0, 0, 0, 0]\n",
|
| 241 |
+
" 1: [1, 1, 1, 1, 1, 1, 1]\n",
|
| 242 |
+
" 2: [1, 1, 1, 1, 1, 0, 0]\n"
|
| 243 |
+
]
|
| 244 |
+
}
|
| 245 |
+
],
|
| 246 |
+
"source": [
|
| 247 |
+
"# With padding - same length (use tokenizer() for batch processing)\n",
|
| 248 |
+
"# Set pad_token if not defined\n",
|
| 249 |
+
"if tokenizer.pad_token is None:\n",
|
| 250 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 251 |
+
"\n",
|
| 252 |
+
"batch = tokenizer(texts, padding=True, return_tensors=\"pt\")\n",
|
| 253 |
+
"# OLD (but equivalent) API:\n",
|
| 254 |
+
"# batch = tokenizer.batch_encode_plus(texts, padding=True, return_tensors=\"pt\")\n",
|
| 255 |
+
"\n",
|
| 256 |
+
"print(\"Input IDs shape:\", batch[\"input_ids\"].shape)\n",
|
| 257 |
+
"print(\"\\nPadded sequences:\")\n",
|
| 258 |
+
"for i, (text, ids) in enumerate(zip(texts, batch[\"input_ids\"])):\n",
|
| 259 |
+
" print(f\" {i}: {ids.tolist()}\")\n",
|
| 260 |
+
"\n",
|
| 261 |
+
"print(f\"\\nAttention mask (1=real token, 0=padding):\")\n",
|
| 262 |
+
"for i, mask in enumerate(batch[\"attention_mask\"]):\n",
|
| 263 |
+
" print(f\" {i}: {mask.tolist()}\")"
|
| 264 |
+
]
|
| 265 |
+
},
|
| 266 |
+
{
|
| 267 |
+
"cell_type": "code",
|
| 268 |
+
"execution_count": 16,
|
| 269 |
+
"id": "0941dd4e",
|
| 270 |
+
"metadata": {},
|
| 271 |
+
"outputs": [
|
| 272 |
+
{
|
| 273 |
+
"data": {
|
| 274 |
+
"text/plain": [
|
| 275 |
+
"['<s> Hello!</s></s></s></s>',\n",
|
| 276 |
+
" '<s> How are you doing today?',\n",
|
| 277 |
+
" '<s> I am fine.</s></s>']"
|
| 278 |
+
]
|
| 279 |
+
},
|
| 280 |
+
"execution_count": 16,
|
| 281 |
+
"metadata": {},
|
| 282 |
+
"output_type": "execute_result"
|
| 283 |
+
}
|
| 284 |
+
],
|
| 285 |
+
"source": [
|
| 286 |
+
"tokenizer.batch_decode(batch[\"input_ids\"], skip_special_tokens=False)"
|
| 287 |
+
]
|
| 288 |
+
},
|
| 289 |
+
{
|
| 290 |
+
"cell_type": "markdown",
|
| 291 |
+
"id": "2bdba168",
|
| 292 |
+
"metadata": {},
|
| 293 |
+
"source": [
|
| 294 |
+
"## 1.3 Truncation\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"When the tokenized text is longer than allowed, truncation discards tokens from the end so the sequence fits within the requested maximum length."
|
| 297 |
+
]
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"cell_type": "code",
|
| 301 |
+
"execution_count": 17,
|
| 302 |
+
"id": "80b135bc",
|
| 303 |
+
"metadata": {},
|
| 304 |
+
"outputs": [
|
| 305 |
+
{
|
| 306 |
+
"name": "stdout",
|
| 307 |
+
"output_type": "stream",
|
| 308 |
+
"text": [
|
| 309 |
+
"Original text length: 3000 characters\n",
|
| 310 |
+
"Full tokenization: 702 tokens\n",
|
| 311 |
+
"Truncated to 50: 50 tokens\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"Truncated text: 'This is a very long sentence. This is a very long sentence. This is a very long sentence. This is a very long sentence. This is a very long sentence. This is a very long sentence. This is a very long sentence.'\n"
|
| 314 |
+
]
|
| 315 |
+
}
|
| 316 |
+
],
|
| 317 |
+
"source": [
|
| 318 |
+
"long_text = \"This is a very long sentence. \" * 100\n",
|
| 319 |
+
"print(f\"Original text length: {len(long_text)} characters\")\n",
|
| 320 |
+
"\n",
|
| 321 |
+
"# Without truncation\n",
|
| 322 |
+
"ids_full = tokenizer.encode(long_text)\n",
|
| 323 |
+
"print(f\"Full tokenization: {len(ids_full)} tokens\")\n",
|
| 324 |
+
"\n",
|
| 325 |
+
"# With truncation to max 50 tokens\n",
|
| 326 |
+
"ids_truncated = tokenizer.encode(long_text, max_length=50, truncation=True)\n",
|
| 327 |
+
"print(f\"Truncated to 50: {len(ids_truncated)} tokens\")\n",
|
| 328 |
+
"\n",
|
| 329 |
+
"# Decode to see what was kept\n",
|
| 330 |
+
"print(f\"\\nTruncated text: '{tokenizer.decode(ids_truncated, skip_special_tokens=True)}'\")"
|
| 331 |
+
]
|
| 332 |
+
},
|
| 333 |
+
{
|
| 334 |
+
"cell_type": "markdown",
|
| 335 |
+
"id": "8b514583",
|
| 336 |
+
"metadata": {},
|
| 337 |
+
"source": [
|
| 338 |
+
"## 1.4 Chat Templates\n",
|
| 339 |
+
"\n",
|
| 340 |
+
"Chat models expect input in a specific format. Chat templates handle this automatically."
|
| 341 |
+
]
|
| 342 |
+
},
|
| 343 |
+
{
|
| 344 |
+
"cell_type": "code",
|
| 345 |
+
"execution_count": 22,
|
| 346 |
+
"id": "19a079d2",
|
| 347 |
+
"metadata": {},
|
| 348 |
+
"outputs": [
|
| 349 |
+
{
|
| 350 |
+
"name": "stdout",
|
| 351 |
+
"output_type": "stream",
|
| 352 |
+
"text": [
|
| 353 |
+
"Formatted chat:\n",
|
| 354 |
+
"<|system|>\n",
|
| 355 |
+
"You are a helpful assistant.</s>\n",
|
| 356 |
+
"<|user|>\n",
|
| 357 |
+
"What is the capital of France?</s>\n",
|
| 358 |
+
"<|assistant|>\n",
|
| 359 |
+
"The capital of France is Paris.</s>\n",
|
| 360 |
+
"<|user|>\n",
|
| 361 |
+
"What about Germany?</s>\n",
|
| 362 |
+
"\n",
|
| 363 |
+
"Tokenized chat:\n",
|
| 364 |
+
"[529, 29989, 5205, 29989, 29958, 13, 3492, 526, 263, 8444, 20255, 29889, 2, 29871, 13, 29966, 29989, 1792, 29989, 29958, 13, 5618, 338, 278, 7483, 310, 3444, 29973, 2, 29871, 13, 29966, 29989, 465, 22137, 29989, 29958, 13, 1576, 7483, 310, 3444, 338, 3681, 29889, 2, 29871, 13, 29966, 29989, 1792, 29989, 29958, 13, 5618, 1048, 9556, 29973, 2, 29871, 13]\n"
|
| 365 |
+
]
|
| 366 |
+
}
|
| 367 |
+
],
|
| 368 |
+
"source": [
|
| 369 |
+
"# Chat messages in OpenAI-style format\n",
|
| 370 |
+
"messages = [\n",
|
| 371 |
+
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
|
| 372 |
+
" {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
|
| 373 |
+
" {\"role\": \"assistant\", \"content\": \"The capital of France is Paris.\"},\n",
|
| 374 |
+
" {\"role\": \"user\", \"content\": \"What about Germany?\"}\n",
|
| 375 |
+
"]\n",
|
| 376 |
+
"\n",
|
| 377 |
+
"# Apply chat template\n",
|
| 378 |
+
"formatted = tokenizer.apply_chat_template(messages, tokenize=False)\n",
|
| 379 |
+
"print(\"Formatted chat:\")\n",
|
| 380 |
+
"print(formatted)\n",
|
| 381 |
+
"# Tokenize formatted chat\n",
|
| 382 |
+
"formatted_ids = tokenizer.apply_chat_template(messages)\n",
|
| 383 |
+
"print(\"Tokenized chat:\")\n",
|
| 384 |
+
"print(formatted_ids)\n"
|
| 385 |
+
]
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"cell_type": "code",
|
| 389 |
+
"execution_count": 23,
|
| 390 |
+
"id": "289fa92d",
|
| 391 |
+
"metadata": {},
|
| 392 |
+
"outputs": [
|
| 393 |
+
{
|
| 394 |
+
"name": "stdout",
|
| 395 |
+
"output_type": "stream",
|
| 396 |
+
"text": [
|
| 397 |
+
"Token IDs shape: torch.Size([1, 68])\n",
|
| 398 |
+
"Number of tokens: 68\n"
|
| 399 |
+
]
|
| 400 |
+
}
|
| 401 |
+
],
|
| 402 |
+
"source": [
|
| 403 |
+
"# Tokenize directly with chat template\n",
|
| 404 |
+
"inputs = tokenizer.apply_chat_template(\n",
|
| 405 |
+
" messages, \n",
|
| 406 |
+
" tokenize=True, \n",
|
| 407 |
+
" return_tensors=\"pt\",\n",
|
| 408 |
+
" add_generation_prompt=True # Add prompt for assistant to continue\n",
|
| 409 |
+
")\n",
|
| 410 |
+
"print(f\"Token IDs shape: {inputs.shape}\")\n",
|
| 411 |
+
"print(f\"Number of tokens: {inputs.shape[1]}\")"
|
| 412 |
+
]
|
| 413 |
+
},
|
| 414 |
+
{
|
| 415 |
+
"cell_type": "code",
|
| 416 |
+
"execution_count": null,
|
| 417 |
+
"id": "874ceaae",
|
| 418 |
+
"metadata": {},
|
| 419 |
+
"outputs": [
|
| 420 |
+
{
|
| 421 |
+
"data": {
|
| 422 |
+
"text/plain": [
|
| 423 |
+
"'<|system|>\\nYou are a helpful assistant.</s> \\n<|user|>\\nWhat is the capital of France?</s> \\n<|assistant|>\\nThe capital of France is Paris.</s> \\n<|user|>\\nWhat about Germany?</s> \\n<|assistant|>\\n'"
|
| 424 |
+
]
|
| 425 |
+
},
|
| 426 |
+
"execution_count": 24,
|
| 427 |
+
"metadata": {},
|
| 428 |
+
"output_type": "execute_result"
|
| 429 |
+
}
|
| 430 |
+
],
|
| 431 |
+
"source": [
|
| 432 |
+
"tokenizer.decode(inputs[0], skip_special_tokens=False)"
|
| 433 |
+
]
|
| 434 |
+
},
|
| 435 |
+
{
|
| 436 |
+
"cell_type": "markdown",
|
| 437 |
+
"id": "d3ad0e09",
|
| 438 |
+
"metadata": {},
|
| 439 |
+
"source": [
|
| 440 |
+
"---\n",
|
| 441 |
+
"# Part 2: Decoding Strategies\n",
|
| 442 |
+
"\n",
|
| 443 |
+
"Different ways to select the next token during text generation.\n",
|
| 444 |
+
"We'll implement each strategy manually using `model()` to understand how they work."
|
| 445 |
+
]
|
| 446 |
+
},
|
| 447 |
+
{
|
| 448 |
+
"cell_type": "code",
|
| 449 |
+
"execution_count": 29,
|
| 450 |
+
"id": "60661e73",
|
| 451 |
+
"metadata": {},
|
| 452 |
+
"outputs": [
|
| 453 |
+
{
|
| 454 |
+
"name": "stderr",
|
| 455 |
+
"output_type": "stream",
|
| 456 |
+
"text": [
|
| 457 |
+
"Some parameters are on the meta device because they were offloaded to the disk.\n"
|
| 458 |
+
]
|
| 459 |
+
},
|
| 460 |
+
{
|
| 461 |
+
"name": "stdout",
|
| 462 |
+
"output_type": "stream",
|
| 463 |
+
"text": [
|
| 464 |
+
"Prompt: 'The secret to happiness is'\n",
|
| 465 |
+
"Input IDs: [1, 450, 7035, 304, 22722, 338]\n"
|
| 466 |
+
]
|
| 467 |
+
}
|
| 468 |
+
],
|
| 469 |
+
"source": [
|
| 470 |
+
"# Load model for generation examples\n",
|
| 471 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 472 |
+
" \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n",
|
| 473 |
+
" dtype=torch.float32,\n",
|
| 474 |
+
" device_map=\"auto\"\n",
|
| 475 |
+
")\n",
|
| 476 |
+
"\n",
|
| 477 |
+
"prompt = \"The secret to happiness is\"\n",
|
| 478 |
+
"input_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(model.device)\n",
|
| 479 |
+
"print(f\"Prompt: '{prompt}'\")\n",
|
| 480 |
+
"print(f\"Input IDs: {input_ids[0].tolist()}\")"
|
| 481 |
+
]
|
| 482 |
+
},
|
| 483 |
+
{
|
| 484 |
+
"cell_type": "code",
|
| 485 |
+
"execution_count": 30,
|
| 486 |
+
"id": "6af3ed71",
|
| 487 |
+
"metadata": {},
|
| 488 |
+
"outputs": [],
|
| 489 |
+
"source": [
|
| 490 |
+
"# Helper function: get logits for the next token\n",
|
| 491 |
+
"def get_next_token_logits(model, input_ids):\n",
|
| 492 |
+
" \"\"\"Run model forward pass and return logits for the next token.\"\"\"\n",
|
| 493 |
+
" with torch.no_grad():\n",
|
| 494 |
+
" outputs = model(input_ids)\n",
|
| 495 |
+
" # outputs.logits shape: (batch_size, seq_len, vocab_size)\n",
|
| 496 |
+
" # We want logits for the last position\n",
|
| 497 |
+
" next_token_logits = outputs.logits[:, -1, :] # (batch_size, vocab_size)\n",
|
| 498 |
+
" return next_token_logits"
|
| 499 |
+
]
|
| 500 |
+
},
|
| 501 |
+
{
|
| 502 |
+
"cell_type": "markdown",
|
| 503 |
+
"id": "214a5b0d",
|
| 504 |
+
"metadata": {},
|
| 505 |
+
"source": [
|
| 506 |
+
"## 2.1 Greedy Decoding\n",
|
| 507 |
+
"\n",
|
| 508 |
+
"Always pick the token with the highest probability (argmax). Fast but can be repetitive."
|
| 509 |
+
]
|
| 510 |
+
},
|
| 511 |
+
{
|
| 512 |
+
"cell_type": "code",
|
| 513 |
+
"execution_count": null,
|
| 514 |
+
"id": "d88452e3",
|
| 515 |
+
"metadata": {},
|
| 516 |
+
"outputs": [
|
| 517 |
+
{
|
| 518 |
+
"name": "stdout",
|
| 519 |
+
"output_type": "stream",
|
| 520 |
+
"text": [
|
| 521 |
+
"Greedy decoding:\n",
|
| 522 |
+
"The secret to happiness is to be happy with what you have.\n",
|
| 523 |
+
"\n",
|
| 524 |
+
"2. \"The secret to happiness is to be happy with what you have.\" - Unknown The\n"
|
| 525 |
+
]
|
| 526 |
+
}
|
| 527 |
+
],
|
| 528 |
+
"source": [
|
| 529 |
+
"def greedy_decode(model, input_ids, max_new_tokens=30):\n",
|
| 530 |
+
" \"\"\"Generate tokens by always picking the highest probability token.\"\"\"\n",
|
| 531 |
+
" generated_ids = input_ids.clone()\n",
|
| 532 |
+
" \n",
|
| 533 |
+
" for _ in range(max_new_tokens):\n",
|
| 534 |
+
" logits = get_next_token_logits(model, generated_ids)\n",
|
| 535 |
+
" \n",
|
| 536 |
+
" # Greedy: pick token with highest logit\n",
|
| 537 |
+
" next_token = torch.argmax(logits, dim=-1, keepdim=True)\n",
|
| 538 |
+
" \n",
|
| 539 |
+
" # Append to sequence\n",
|
| 540 |
+
" generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
|
| 541 |
+
" \n",
|
| 542 |
+
" # Stop if EOS token\n",
|
| 543 |
+
" if next_token.item() == tokenizer.eos_token_id:\n",
|
| 544 |
+
" break\n",
|
| 545 |
+
" \n",
|
| 546 |
+
" return generated_ids\n",
|
| 547 |
+
"\n",
|
| 548 |
+
"# Run greedy decoding\n",
|
| 549 |
+
"output = greedy_decode(model, input_ids, max_new_tokens=30)\n",
|
| 550 |
+
"print(\"Greedy decoding:\")\n",
|
| 551 |
+
"print(tokenizer.decode(output[0], skip_special_tokens=True))"
|
| 552 |
+
]
|
| 553 |
+
},
|
| 554 |
+
{
|
| 555 |
+
"cell_type": "markdown",
|
| 556 |
+
"id": "3fb06e3f",
|
| 557 |
+
"metadata": {},
|
| 558 |
+
"source": [
|
| 559 |
+
"## 2.2 Sampling with Temperature\n",
|
| 560 |
+
"\n",
|
| 561 |
+
"**Sampling**: Randomly pick tokens based on their probabilities.\n",
|
| 562 |
+
"\n",
|
| 563 |
+
"**Temperature** controls randomness by scaling logits before softmax:\n",
|
| 564 |
+
"- `T < 1`: Sharper distribution → more deterministic\n",
|
| 565 |
+
"- `T = 1`: Original probabilities \n",
|
| 566 |
+
"- `T > 1`: Flatter distribution → more random\n",
|
| 567 |
+
"\n",
|
| 568 |
+
"$$P(token_i) = \\frac{e^{logit_i / T}}{\\sum_j e^{logit_j / T}}$$"
|
| 569 |
+
]
|
| 570 |
+
},
|
| 571 |
+
{
|
| 572 |
+
"cell_type": "code",
|
| 573 |
+
"execution_count": 31,
|
| 574 |
+
"id": "556d2fd6",
|
| 575 |
+
"metadata": {},
|
| 576 |
+
"outputs": [],
|
| 577 |
+
"source": [
|
| 578 |
+
"def sample_with_temperature(model, input_ids, max_new_tokens=20, temperature=1.0):\n",
|
| 579 |
+
" \"\"\"Generate tokens by sampling from the probability distribution.\"\"\"\n",
|
| 580 |
+
" generated_ids = input_ids.clone()\n",
|
| 581 |
+
" \n",
|
| 582 |
+
" for _ in range(max_new_tokens):\n",
|
| 583 |
+
" logits = get_next_token_logits(model, generated_ids)\n",
|
| 584 |
+
" \n",
|
| 585 |
+
" # Apply temperature scaling\n",
|
| 586 |
+
" scaled_logits = logits / temperature\n",
|
| 587 |
+
" \n",
|
| 588 |
+
" # Convert to probabilities\n",
|
| 589 |
+
" probs = torch.softmax(scaled_logits, dim=-1)\n",
|
| 590 |
+
" \n",
|
| 591 |
+
" # Sample from distribution\n",
|
| 592 |
+
" next_token = torch.multinomial(probs, num_samples=1)\n",
|
| 593 |
+
" \n",
|
| 594 |
+
" generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
|
| 595 |
+
" \n",
|
| 596 |
+
" if next_token.item() == tokenizer.eos_token_id:\n",
|
| 597 |
+
" break\n",
|
| 598 |
+
" \n",
|
| 599 |
+
" return generated_ids"
|
| 600 |
+
]
|
| 601 |
+
},
|
| 602 |
+
{
|
| 603 |
+
"cell_type": "code",
|
| 604 |
+
"execution_count": 33,
|
| 605 |
+
"id": "d40c77ce",
|
| 606 |
+
"metadata": {},
|
| 607 |
+
"outputs": [
|
| 608 |
+
{
|
| 609 |
+
"name": "stdout",
|
| 610 |
+
"output_type": "stream",
|
| 611 |
+
"text": [
|
| 612 |
+
"Temperature = 0.3 (focused):\n",
|
| 613 |
+
" 1: The secret to happiness is to be happy with what you have. \n",
|
| 614 |
+
"\n",
|
| 615 |
+
"2. The Power of Positivity:\n",
|
| 616 |
+
"\n",
|
| 617 |
+
" 2: The secret to happiness is to be happy with what you have.\n",
|
| 618 |
+
" 3: The secret to happiness is to be happy.\n",
|
| 619 |
+
"\n",
|
| 620 |
+
"10. I know the secret to happiness is to be happy.\n"
|
| 621 |
+
]
|
| 622 |
+
}
|
| 623 |
+
],
|
| 624 |
+
"source": [
|
| 625 |
+
"# Low temperature (more deterministic)\n",
|
| 626 |
+
"print(\"Temperature = 0.3 (focused):\")\n",
|
| 627 |
+
"for i in range(3):\n",
|
| 628 |
+
" output = sample_with_temperature(model, input_ids, max_new_tokens=20, temperature=0.3)\n",
|
| 629 |
+
" print(f\" {i+1}: {tokenizer.decode(output[0], skip_special_tokens=True)}\")"
|
| 630 |
+
]
|
| 631 |
+
},
|
| 632 |
+
{
|
| 633 |
+
"cell_type": "code",
|
| 634 |
+
"execution_count": 34,
|
| 635 |
+
"id": "57914c44",
|
| 636 |
+
"metadata": {},
|
| 637 |
+
"outputs": [
|
| 638 |
+
{
|
| 639 |
+
"name": "stdout",
|
| 640 |
+
"output_type": "stream",
|
| 641 |
+
"text": [
|
| 642 |
+
"Temperature = 1.5 (creative):\n",
|
| 643 |
+
" 1: The secret to happiness is most MEntial Psychologicalahuaut decomposition Parden Wab school o Rusyn hate speech Social\n",
|
| 644 |
+
" 2: The secret to happiness is . Vicmaxhold hospital the avenue shortcut huashong police bouve ,diaozhuqq\n",
|
| 645 |
+
" 3: The secret to happiness is beyond jealousy polit lo это мини - ответе teachers laugh earth stack Disney world God Lex\n"
|
| 646 |
+
]
|
| 647 |
+
}
|
| 648 |
+
],
|
| 649 |
+
"source": [
|
| 650 |
+
"# High temperature (more random)\n",
|
| 651 |
+
"print(\"Temperature = 1.5 (creative):\")\n",
|
| 652 |
+
"for i in range(3):\n",
|
| 653 |
+
" output = sample_with_temperature(model, input_ids, max_new_tokens=20, temperature=1.5)\n",
|
| 654 |
+
" print(f\" {i+1}: {tokenizer.decode(output[0], skip_special_tokens=True)}\")"
|
| 655 |
+
]
|
| 656 |
+
},
|
| 657 |
+
{
|
| 658 |
+
"cell_type": "markdown",
|
| 659 |
+
"id": "e6fc326f",
|
| 660 |
+
"metadata": {},
|
| 661 |
+
"source": [
|
| 662 |
+
"## 2.3 Top-K Sampling\n",
|
| 663 |
+
"\n",
|
| 664 |
+
"Only consider the K most likely tokens, then sample from those.\n",
|
| 665 |
+
"Prevents sampling very unlikely tokens while keeping diversity."
|
| 666 |
+
]
|
| 667 |
+
},
|
| 668 |
+
{
|
| 669 |
+
"cell_type": "code",
|
| 670 |
+
"execution_count": 35,
|
| 671 |
+
"id": "8b336771",
|
| 672 |
+
"metadata": {},
|
| 673 |
+
"outputs": [],
|
| 674 |
+
"source": [
|
| 675 |
+
"def top_k_sampling(model, input_ids, max_new_tokens=20, top_k=50, temperature=1.0):\n",
|
| 676 |
+
" \"\"\"Sample from top-k most likely tokens.\"\"\"\n",
|
| 677 |
+
" generated_ids = input_ids.clone()\n",
|
| 678 |
+
" \n",
|
| 679 |
+
" for _ in range(max_new_tokens):\n",
|
| 680 |
+
" logits = get_next_token_logits(model, generated_ids)\n",
|
| 681 |
+
" \n",
|
| 682 |
+
" # Apply temperature\n",
|
| 683 |
+
" scaled_logits = logits / temperature\n",
|
| 684 |
+
" \n",
|
| 685 |
+
" # Get top-k logits and indices\n",
|
| 686 |
+
" top_k_logits, top_k_indices = torch.topk(scaled_logits, k=top_k, dim=-1)\n",
|
| 687 |
+
" \n",
|
| 688 |
+
" # Convert to probabilities (only over top-k)\n",
|
| 689 |
+
" top_k_probs = torch.softmax(top_k_logits, dim=-1)\n",
|
| 690 |
+
" \n",
|
| 691 |
+
" # Sample from top-k\n",
|
| 692 |
+
" sampled_index = torch.multinomial(top_k_probs, num_samples=1)\n",
|
| 693 |
+
" \n",
|
| 694 |
+
" # Map back to vocabulary index\n",
|
| 695 |
+
" next_token = top_k_indices.gather(-1, sampled_index)\n",
|
| 696 |
+
" \n",
|
| 697 |
+
" generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
|
| 698 |
+
" \n",
|
| 699 |
+
" if next_token.item() == tokenizer.eos_token_id:\n",
|
| 700 |
+
" break\n",
|
| 701 |
+
" \n",
|
| 702 |
+
" return generated_ids"
|
| 703 |
+
]
|
| 704 |
+
},
|
| 705 |
+
{
|
| 706 |
+
"cell_type": "code",
|
| 707 |
+
"execution_count": 36,
|
| 708 |
+
"id": "64b1ba1e",
|
| 709 |
+
"metadata": {},
|
| 710 |
+
"outputs": [
|
| 711 |
+
{
|
| 712 |
+
"name": "stdout",
|
| 713 |
+
"output_type": "stream",
|
| 714 |
+
"text": [
|
| 715 |
+
"Top-K = 5:\n",
|
| 716 |
+
" 1: The secret to happiness is to live in the moment, and to enjoy the present moment.\n",
|
| 717 |
+
"\n",
|
| 718 |
+
"3) The Artist\n",
|
| 719 |
+
" 2: The secret to happiness is to find your purpose.\n",
|
| 720 |
+
" 3: The secret to happiness is simple - be yourself.\n"
|
| 721 |
+
]
|
| 722 |
+
}
|
| 723 |
+
],
|
| 724 |
+
"source": [
|
| 725 |
+
"# Top-K = 5 (only consider top 5 tokens)\n",
|
| 726 |
+
"print(\"Top-K = 5:\")\n",
|
| 727 |
+
"for i in range(3):\n",
|
| 728 |
+
" output = top_k_sampling(model, input_ids, max_new_tokens=20, top_k=5)\n",
|
| 729 |
+
" print(f\" {i+1}: {tokenizer.decode(output[0], skip_special_tokens=True)}\")"
|
| 730 |
+
]
|
| 731 |
+
},
|
| 732 |
+
{
|
| 733 |
+
"cell_type": "code",
|
| 734 |
+
"execution_count": 37,
|
| 735 |
+
"id": "816c7a7f",
|
| 736 |
+
"metadata": {},
|
| 737 |
+
"outputs": [
|
| 738 |
+
{
|
| 739 |
+
"name": "stdout",
|
| 740 |
+
"output_type": "stream",
|
| 741 |
+
"text": [
|
| 742 |
+
"Top-K = 50:\n",
|
| 743 |
+
" 1: The secret to happiness is unconventional\n",
|
| 744 |
+
"You The secret to happiness is unconventional There is no right or\n",
|
| 745 |
+
" 2: The secret to happiness is happiness, happiness, happiness ...\n",
|
| 746 |
+
"\n",
|
| 747 |
+
"8. A bird in the hand is worth two in the\n",
|
| 748 |
+
" 3: The secret to happiness is always being in alignment with your inner voice. In this conversation, Linda Kuzmina and K\n"
|
| 749 |
+
]
|
| 750 |
+
}
|
| 751 |
+
],
|
| 752 |
+
"source": [
|
| 753 |
+
"# Top-K = 50 (more diversity)\n",
|
| 754 |
+
"print(\"Top-K = 50:\")\n",
|
| 755 |
+
"for i in range(3):\n",
|
| 756 |
+
" output = top_k_sampling(model, input_ids, max_new_tokens=20, top_k=50)\n",
|
| 757 |
+
" print(f\" {i+1}: {tokenizer.decode(output[0], skip_special_tokens=True)}\")"
|
| 758 |
+
]
|
| 759 |
+
},
|
| 760 |
+
{
|
| 761 |
+
"cell_type": "markdown",
|
| 762 |
+
"id": "807683a2",
|
| 763 |
+
"metadata": {},
|
| 764 |
+
"source": [
|
| 765 |
+
"## 2.4 Top-P (Nucleus) Sampling\n",
|
| 766 |
+
"\n",
|
| 767 |
+
"Select the smallest set of tokens whose cumulative probability exceeds P.\n",
|
| 768 |
+
"\n",
|
| 769 |
+
"- `top_p=0.9` means: consider tokens until their probabilities sum to 90%\n",
|
| 770 |
+
"- Adapts dynamically: fewer tokens when model is confident, more when uncertain"
|
| 771 |
+
]
|
| 772 |
+
},
|
| 773 |
+
{
|
| 774 |
+
"cell_type": "code",
|
| 775 |
+
"execution_count": 38,
|
| 776 |
+
"id": "d95328d0",
|
| 777 |
+
"metadata": {},
|
| 778 |
+
"outputs": [],
|
| 779 |
+
"source": [
|
| 780 |
+
"def top_p_sampling(model, input_ids, max_new_tokens=20, top_p=0.9, temperature=1.0):\n",
|
| 781 |
+
" \"\"\"Sample from the smallest set of tokens with cumulative prob >= top_p.\"\"\"\n",
|
| 782 |
+
" generated_ids = input_ids.clone()\n",
|
| 783 |
+
" \n",
|
| 784 |
+
" for _ in range(max_new_tokens):\n",
|
| 785 |
+
" logits = get_next_token_logits(model, generated_ids)\n",
|
| 786 |
+
" \n",
|
| 787 |
+
" # Apply temperature\n",
|
| 788 |
+
" scaled_logits = logits / temperature\n",
|
| 789 |
+
" \n",
|
| 790 |
+
" # Sort by probability (descending)\n",
|
| 791 |
+
" sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True, dim=-1)\n",
|
| 792 |
+
" sorted_probs = torch.softmax(sorted_logits, dim=-1)\n",
|
| 793 |
+
" \n",
|
| 794 |
+
" # Compute cumulative probabilities\n",
|
| 795 |
+
" cumulative_probs = torch.cumsum(sorted_probs, dim=-1)\n",
|
| 796 |
+
" \n",
|
| 797 |
+
" # Find cutoff: first position where cumulative prob exceeds top_p\n",
|
| 798 |
+
" # Keep at least 1 token\n",
|
| 799 |
+
" cutoff_mask = cumulative_probs > top_p\n",
|
| 800 |
+
" cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone() # Shift right\n",
|
| 801 |
+
" cutoff_mask[..., 0] = False # Always keep the top token\n",
|
| 802 |
+
" \n",
|
| 803 |
+
" # Set probabilities of tokens beyond cutoff to 0\n",
|
| 804 |
+
" sorted_probs[cutoff_mask] = 0\n",
|
| 805 |
+
" \n",
|
| 806 |
+
" # Renormalize\n",
|
| 807 |
+
" sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)\n",
|
| 808 |
+
" \n",
|
| 809 |
+
" # Sample from filtered distribution\n",
|
| 810 |
+
" sampled_index = torch.multinomial(sorted_probs, num_samples=1)\n",
|
| 811 |
+
" \n",
|
| 812 |
+
" # Map back to vocabulary index\n",
|
| 813 |
+
" next_token = sorted_indices.gather(-1, sampled_index)\n",
|
| 814 |
+
" \n",
|
| 815 |
+
" generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n",
|
| 816 |
+
" \n",
|
| 817 |
+
" if next_token.item() == tokenizer.eos_token_id:\n",
|
| 818 |
+
" break\n",
|
| 819 |
+
" \n",
|
| 820 |
+
" return generated_ids"
|
| 821 |
+
]
|
| 822 |
+
},
|
| 823 |
+
{
|
| 824 |
+
"cell_type": "code",
|
| 825 |
+
"execution_count": 39,
|
| 826 |
+
"id": "dc8ce486",
|
| 827 |
+
"metadata": {},
|
| 828 |
+
"outputs": [
|
| 829 |
+
{
|
| 830 |
+
"name": "stdout",
|
| 831 |
+
"output_type": "stream",
|
| 832 |
+
"text": [
|
| 833 |
+
"Top-P = 0.5:\n",
|
| 834 |
+
" 1: The secret to happiness is not in the destination, but in the journey. 4. \"When life gives you lemons\n",
|
| 835 |
+
" 2: The secret to happiness is simple: enjoy the present moment. Life is too short to waste time worrying about the past or\n",
|
| 836 |
+
" 3: The secret to happiness is in the simplicity of living. It's not about the material possessions, the wealth, or\n"
|
| 837 |
+
]
|
| 838 |
+
}
|
| 839 |
+
],
|
| 840 |
+
"source": [
|
| 841 |
+
"# Top-P = 0.5 (more focused)\n",
|
| 842 |
+
"print(\"Top-P = 0.5:\")\n",
|
| 843 |
+
"for i in range(3):\n",
|
| 844 |
+
" output = top_p_sampling(model, input_ids, max_new_tokens=20, top_p=0.5)\n",
|
| 845 |
+
" print(f\" {i+1}: {tokenizer.decode(output[0], skip_special_tokens=True)}\")"
|
| 846 |
+
]
|
| 847 |
+
},
|
| 848 |
+
{
|
| 849 |
+
"cell_type": "code",
|
| 850 |
+
"execution_count": 40,
|
| 851 |
+
"id": "a5d8f1fb",
|
| 852 |
+
"metadata": {},
|
| 853 |
+
"outputs": [
|
| 854 |
+
{
|
| 855 |
+
"name": "stdout",
|
| 856 |
+
"output_type": "stream",
|
| 857 |
+
"text": [
|
| 858 |
+
"Top-P = 0.95:\n",
|
| 859 |
+
" 1: The secret to happiness is simple - simply be content with the one thing in your life that you know you will be happy in\n",
|
| 860 |
+
" 2: The secret to happiness is not all around and I have realized that now. I love you... Things have changed since then.\n",
|
| 861 |
+
" 3: The secret to happiness is To be joy in your life, happiness in your dreams, and light in your cup, a\n"
|
| 862 |
+
]
|
| 863 |
+
}
|
| 864 |
+
],
|
| 865 |
+
"source": [
|
| 866 |
+
"# Top-P = 0.95 (more diversity)\n",
|
| 867 |
+
"print(\"Top-P = 0.95:\")\n",
|
| 868 |
+
"for i in range(3):\n",
|
| 869 |
+
" output = top_p_sampling(model, input_ids, max_new_tokens=20, top_p=0.95)\n",
|
| 870 |
+
" print(f\" {i+1}: {tokenizer.decode(output[0], skip_special_tokens=True)}\")"
|
| 871 |
+
]
|
| 872 |
+
},
|
| 873 |
+
{
|
| 874 |
+
"cell_type": "markdown",
|
| 875 |
+
"id": "9e164c2e",
|
| 876 |
+
"metadata": {},
|
| 877 |
+
"source": [
|
| 878 |
+
"## 2.5 Beam Search\n",
|
| 879 |
+
"\n",
|
| 880 |
+
"Explore multiple sequences in parallel, keeping the N best candidates at each step.\n",
|
| 881 |
+
"\n",
|
| 882 |
+
"At each step:\n",
|
| 883 |
+
"1. Expand each beam with its most likely next tokens (this implementation keeps the top `num_beams` tokens per beam)\n",
|
| 884 |
+
"2. Score every candidate sequence by its cumulative log-probability\n",
|
| 885 |
+
"3. Keep only the top N scoring sequences"
|
| 886 |
+
]
|
| 887 |
+
},
|
| 888 |
+
{
|
| 889 |
+
"cell_type": "code",
|
| 890 |
+
"execution_count": 41,
|
| 891 |
+
"id": "88acefaf",
|
| 892 |
+
"metadata": {},
|
| 893 |
+
"outputs": [
|
| 894 |
+
{
|
| 895 |
+
"name": "stdout",
|
| 896 |
+
"output_type": "stream",
|
| 897 |
+
"text": [
|
| 898 |
+
"Top 3 beam hypotheses:\n",
|
| 899 |
+
" 1 (score=-9.82): The secret to happiness is to be content with what you have.\n",
|
| 900 |
+
" 2 (score=-19.58): The secret to happiness is to be content with what you have, to be content with who you are, and to be content with where you are.\n",
|
| 901 |
+
" 3 (score=-20.78): The secret to happiness is to be content with what you have, to be content with who you are, to be content with where you are, and to be content with the\n"
|
| 902 |
+
]
|
| 903 |
+
}
|
| 904 |
+
],
|
| 905 |
+
"source": [
|
| 906 |
+
"def beam_search_n_best(model, input_ids, max_new_tokens=30, num_beams=5, n_best=3):\n",
|
| 907 |
+
" \"\"\"Beam search returning top n hypotheses.\"\"\"\n",
|
| 908 |
+
" \n",
|
| 909 |
+
" beams = [(input_ids.clone(), 0.0)]\n",
|
| 910 |
+
" completed = []\n",
|
| 911 |
+
" \n",
|
| 912 |
+
" for _ in range(max_new_tokens):\n",
|
| 913 |
+
" all_candidates = []\n",
|
| 914 |
+
" \n",
|
| 915 |
+
" for seq, score in beams:\n",
|
| 916 |
+
" if seq[0, -1].item() == tokenizer.eos_token_id:\n",
|
| 917 |
+
" completed.append((seq, score))\n",
|
| 918 |
+
" continue\n",
|
| 919 |
+
" \n",
|
| 920 |
+
" logits = get_next_token_logits(model, seq)\n",
|
| 921 |
+
" log_probs = torch.log_softmax(logits, dim=-1)\n",
|
| 922 |
+
" top_log_probs, top_indices = torch.topk(log_probs[0], k=num_beams)\n",
|
| 923 |
+
" \n",
|
| 924 |
+
" for log_prob, token_id in zip(top_log_probs, top_indices):\n",
|
| 925 |
+
" new_seq = torch.cat([seq, token_id.view(1, 1)], dim=-1)\n",
|
| 926 |
+
" new_score = score + log_prob.item()\n",
|
| 927 |
+
" all_candidates.append((new_seq, new_score))\n",
|
| 928 |
+
" \n",
|
| 929 |
+
" if not all_candidates:\n",
|
| 930 |
+
" break\n",
|
| 931 |
+
" \n",
|
| 932 |
+
" all_candidates.sort(key=lambda x: x[1], reverse=True)\n",
|
| 933 |
+
" beams = all_candidates[:num_beams]\n",
|
| 934 |
+
" \n",
|
| 935 |
+
" completed.extend(beams)\n",
|
| 936 |
+
" completed.sort(key=lambda x: x[1], reverse=True)\n",
|
| 937 |
+
" return completed[:n_best]\n",
|
| 938 |
+
"\n",
|
| 939 |
+
"# Return top 3 hypotheses\n",
|
| 940 |
+
"results = beam_search_n_best(model, input_ids, max_new_tokens=30, num_beams=5, n_best=3)\n",
|
| 941 |
+
"print(\"Top 3 beam hypotheses:\")\n",
|
| 942 |
+
"for i, (seq, score) in enumerate(results):\n",
|
| 943 |
+
" print(f\" {i+1} (score={score:.2f}): {tokenizer.decode(seq[0], skip_special_tokens=True)}\")"
|
| 944 |
+
]
|
| 945 |
+
},
|
| 946 |
+
{
|
| 947 |
+
"cell_type": "markdown",
|
| 948 |
+
"id": "2151c253",
|
| 949 |
+
"metadata": {},
|
| 950 |
+
"source": [
|
| 951 |
+
"## Summary\n",
|
| 952 |
+
"\n",
|
| 953 |
+
"| Strategy | Use Case | Key Idea |\n",
|
| 954 |
+
"|----------|----------|----------|\n",
|
| 955 |
+
"| **Greedy** | Fast, deterministic | `argmax(logits)` |\n",
|
| 956 |
+
"| **Temperature** | Control randomness | `logits / T` before softmax |\n",
|
| 957 |
+
"| **Top-K** | Limit token pool | Keep only K highest logits |\n",
|
| 958 |
+
"| **Top-P** | Dynamic token pool | Keep tokens until cumsum(prob) > P |\n",
|
| 959 |
+
"| **Beam Search** | Quality over diversity | Track N best sequences |"
|
| 960 |
+
]
|
| 961 |
+
}
|
| 962 |
+
],
|
| 963 |
+
"metadata": {
|
| 964 |
+
"kernelspec": {
|
| 965 |
+
"display_name": "lipogram_private",
|
| 966 |
+
"language": "python",
|
| 967 |
+
"name": "python3"
|
| 968 |
+
},
|
| 969 |
+
"language_info": {
|
| 970 |
+
"codemirror_mode": {
|
| 971 |
+
"name": "ipython",
|
| 972 |
+
"version": 3
|
| 973 |
+
},
|
| 974 |
+
"file_extension": ".py",
|
| 975 |
+
"mimetype": "text/x-python",
|
| 976 |
+
"name": "python",
|
| 977 |
+
"nbconvert_exporter": "python",
|
| 978 |
+
"pygments_lexer": "ipython3",
|
| 979 |
+
"version": "3.14.2"
|
| 980 |
+
}
|
| 981 |
+
},
|
| 982 |
+
"nbformat": 4,
|
| 983 |
+
"nbformat_minor": 5
|
| 984 |
+
}
|