File size: 9,511 Bytes
1d7752e
116756e
 
 
 
 
 
 
 
 
 
 
 
1d7752e
116756e
 
 
 
 
 
 
 
 
 
1d7752e
116756e
1d7752e
116756e
1d7752e
4d8bbd9
 
6538c21
 
116756e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d7752e
116756e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d8bbd9
116756e
4d8bbd9
 
116756e
 
 
 
 
 
 
1d7752e
116756e
 
 
 
 
 
1d7752e
 
 
 
 
 
 
 
 
 
116756e
1d7752e
 
 
 
 
 
 
 
 
 
 
 
116756e
1d7752e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116756e
 
1d7752e
 
6538c21
 
1d7752e
116756e
1d7752e
 
116756e
 
 
 
 
 
 
 
 
1d7752e
 
 
 
 
116756e
1d7752e
 
 
 
 
 
 
116756e
 
 
 
 
 
 
1d7752e
b63c610
1d7752e
4d8bbd9
1d7752e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from typing import List, Tuple
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- EXERCISE 1: La disparition (No 'e' or 'E') ---
class LaDisparition:
    """
    Generate text that never contains the letter 'e'/'E', in the spirit of
    Georges Perec's novel "La Disparition".

    Uses model(input_ids) directly (never model.generate()) and masks out, at
    every decoding step, every token whose decoded text contains 'e'/'E' or
    any non-ASCII character (which could hide an accented 'e').  Decoding is
    a simple beam search over the remaining vocabulary.
    """
    def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, debug: bool = False):
        """
        Args:
            model: causal LM whose forward pass model(input_ids) yields logits.
            tokenizer: matching tokenizer (also supplies the chat template).
            debug: when True, print every surviving hypothesis after generation.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug
        # Pre-compute the forbidden-token set once.  We check the *decoded*
        # text of each token, not the raw vocab string, because BPE pieces use
        # marker characters that don't reflect the real output.
        # NOTE(review): assumes token ids 0..len(vocab)-1 cover the tokenizer's
        # id space — true for the usual HF fast tokenizers; confirm for others.
        self.forbidden_token_ids = set()
        for token_id in range(len(self.tokenizer.get_vocab())):
            decoded = self.tokenizer.decode([token_id])
            # Forbid 'e'/'E' and anything non-ASCII (accents could hide an 'e').
            if 'e' in decoded.lower() or not decoded.isascii():
                self.forbidden_token_ids.add(token_id)

    def __call__(self, prompt, max_tokens=20, beam_width=5):
        """
        Generate up to `max_tokens` 'e'-free tokens answering `prompt`.

        Returns the decoded continuation (prompt excluded), stripped.
        """
        # Tokenize the prompt through the chat template.
        message = [{"role": "user", "content": prompt}]
        encoded = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt")
        input_ids = (encoded if isinstance(encoded, torch.Tensor) else encoded["input_ids"]).to(self.model.device)
        prompt_len = input_ids.shape[1]

        # Active hypotheses: (token-id sequence, cumulative log-probability).
        beams: List[Tuple[List[int], float]] = [(input_ids[0].tolist(), 0.0)]
        # Hypotheses that produced EOS.  BUGFIX: the previous version kept
        # EOS-terminated beams in the active list, so they were re-expanded on
        # later steps and the "all beams ended" stop check could never fire
        # (EOS was never actually appended to any sequence).
        finished: List[Tuple[List[int], float]] = []

        for _ in range(max_tokens):
            candidates: List[Tuple[List[int], float]] = []

            for seq, log_prob in beams:
                input_tensor = torch.tensor([seq], device=self.model.device)

                with torch.no_grad():
                    logits = self.model(input_tensor).logits[0, -1, :]

                # Mask forbidden tokens out in log space.
                forbidden_mask = torch.zeros_like(logits, dtype=torch.bool)
                forbidden_mask[list(self.forbidden_token_ids)] = True
                log_probs = F.log_softmax(logits, dim=-1)
                log_probs = log_probs.masked_fill(forbidden_mask, float('-inf'))

                # Expand this beam with its top-k allowed continuations.
                top_k = min(beam_width, int((~forbidden_mask).sum().item()))
                if top_k == 0:
                    continue  # no legal token left for this beam
                top_log_probs, top_indices = torch.topk(log_probs, top_k)

                for token_id, token_log_prob in zip(top_indices.tolist(), top_log_probs.tolist()):
                    if token_id == self.tokenizer.eos_token_id:
                        # Freeze this hypothesis: do not expand it again.
                        finished.append((seq, log_prob + token_log_prob))
                    else:
                        candidates.append((seq + [token_id], log_prob + token_log_prob))

            if not candidates:
                break  # every beam either finished or died
            # Keep the top beam_width candidates by cumulative log-probability.
            candidates.sort(key=lambda c: c[1], reverse=True)
            beams = candidates[:beam_width]

        # Rank finished and still-open hypotheses together.
        hypotheses = finished + beams
        if not hypotheses:
            return ""  # BUGFIX: avoid IndexError when everything was pruned
        hypotheses.sort(key=lambda c: c[1], reverse=True)

        if self.debug:
            print(f"\n[DEBUG Ex1] Total beams: {len(hypotheses)}")
            for i, (seq, log_prob) in enumerate(hypotheses):
                decoded = self.tokenizer.decode(seq, skip_special_tokens=True)
                print(f"  Beam {i}: log_prob={log_prob:.4f} | {decoded}")

        # Return only the generated part of the best hypothesis.
        best_seq = hypotheses[0][0]
        return self.tokenizer.decode(best_seq[prompt_len:], skip_special_tokens=True).strip()


# --- EXERCISE 2: The Toulouse Sequence ---
class ToulouseSequence:
    """
    Generate text that never spells out the word 'Toulouse'.

    Uses model(input_ids) directly (never model.generate()) with greedy
    decoding.  Before each step, tokens are masked out if appending them to
    the word currently being typed would form a case-insensitive prefix of
    "Toulouse" of length >= 4 (e.g. "Toul", "TOULO", ...).
    """
    def __init__(self, model, tokenizer, debug=False):
        """
        Args:
            model: causal LM whose forward pass model(input_ids) yields logits.
            tokenizer: matching tokenizer (also supplies the chat template).
            debug: reserved verbosity flag (not used by this class's methods).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug
        self.forbidden_word = "Toulouse"
        # Only prefixes of at least this length are forbidden, so short
        # prefixes like "T", "To", "Tou" remain legal and ordinary words
        # ("Touch", "Tour") stay reachable.
        self.min_prefix_len = 4
        # Cache: current-word-prefix string -> forbidden mask.  PERF FIX: the
        # O(vocab) decode loop in _get_forbidden_mask previously ran on every
        # generation step; with the cache each distinct prefix is scanned once.
        self._mask_cache: dict = {}

    def _get_current_word_prefix(self, decoded_sequence: str) -> str:
        """Return the alphabetic tail of the text (the word being typed)."""
        last_separator_idx = -1
        for i in range(len(decoded_sequence) - 1, -1, -1):
            if not decoded_sequence[i].isalpha():
                last_separator_idx = i
                break
        if last_separator_idx != -1:
            return decoded_sequence[last_separator_idx + 1:]
        return decoded_sequence

    def _get_forbidden_mask(self, seq: List[int]) -> torch.Tensor:
        """
        Boolean mask over the tokenizer vocabulary; True marks tokens that,
        appended to the current partial word, would form a forbidden prefix
        of 'Toulouse' (case-insensitive, length >= self.min_prefix_len).
        """
        vocab_size = len(self.tokenizer.get_vocab())

        # The forbidden set depends only on the word currently in progress.
        decoded_sequence = self.tokenizer.decode(seq)
        current_word_prefix = self._get_current_word_prefix(decoded_sequence)

        cached = self._mask_cache.get(current_word_prefix)
        if cached is not None:
            return cached

        forbidden_mask = torch.zeros(vocab_size, dtype=torch.bool, device=self.model.device)
        # BUGFIX: the previous version returned an all-False mask whenever the
        # prefix was empty, so a word-initial token such as "Toul" could slip
        # through.  We now scan with an empty prefix too.
        if current_word_prefix:
            current_word_ids = self.tokenizer.encode(current_word_prefix, add_special_tokens=False)
        else:
            current_word_ids = []

        forbidden_lower = self.forbidden_word.lower()
        for token_id in range(vocab_size):
            hypothetical_word = self.tokenizer.decode(current_word_ids + [token_id])
            # Judge only the word actually being typed: a candidate token may
            # contain a separator (e.g. a leading space), in which case only
            # its alphabetic tail can continue/open a forbidden prefix.
            tail = self._get_current_word_prefix(hypothetical_word)
            if len(tail) >= self.min_prefix_len and forbidden_lower.startswith(tail.lower()):
                forbidden_mask[token_id] = True

        self._mask_cache[current_word_prefix] = forbidden_mask
        return forbidden_mask

    def __call__(self, prompt, max_tokens=20):
        """Greedy-decode up to `max_tokens` tokens, never spelling 'Toulouse'.

        Returns the decoded continuation (prompt excluded), stripped.
        """
        # Tokenize the prompt through the chat template.
        message = [{"role": "user", "content": prompt}]
        encoded = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt")
        inputs = (encoded if isinstance(encoded, torch.Tensor) else encoded["input_ids"]).to(self.model.device)
        prompt_length = inputs.shape[1]

        seq = inputs[0].tolist()

        for _ in range(max_tokens):
            input_tensor = torch.tensor([seq], device=self.model.device)

            with torch.no_grad():
                logits = self.model(input_tensor).logits[0, -1, :].clone()

            forbidden_mask = self._get_forbidden_mask(seq)
            # ROBUSTNESS: the LM head can be larger than the tokenizer vocab
            # (padded embeddings); mask only the overlapping range instead of
            # assuming an exact size match.
            n = min(logits.shape[0], forbidden_mask.shape[0])
            logits[:n][forbidden_mask[:n]] = float('-inf')

            # Greedy decoding.
            next_token = int(torch.argmax(logits).item())
            if next_token == self.tokenizer.eos_token_id:
                break
            seq.append(next_token)

        # Return only the generated tokens, skipping the prompt.
        generated_tokens = seq[prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    
if __name__ == "__main__":
    # Local smoke test only — the evaluation server supplies model/tokenizer.
    MODEL_NAME = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")

    # Exercise 1: constrained beam search with all 'e'-tokens masked out.
    no_e_generator = LaDisparition(model, tokenizer)
    print("Ex 1 (No 'e'):", no_e_generator("Who are you?"))

    # Exercise 2: greedy decoding that blocks prefixes of "Toulouse".
    no_toulouse_generator = ToulouseSequence(model, tokenizer, debug=True)
    print("Ex 2 (No 'Toulouse'):", no_toulouse_generator("Where is Toulouse?"))