File size: 6,545 Bytes
1d7752e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d8bbd9
 
 
 
 
 
 
 
1d7752e
4d8bbd9
 
1d7752e
 
 
 
 
 
4d8bbd9
 
 
1d7752e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d8bbd9
1d7752e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
from transformers import LogitsProcessor, AutoModelForCausalLM, AutoTokenizer

# --- EXERCISE 1: La disparition (No 'e' or 'E') ---

# --- Logits Processor to forbid specific tokens ---
class ForbidTokensLogitsProcessor(LogitsProcessor):
    """Masks a fixed set of token ids by forcing their logits to -inf.

    Handed to ``model.generate`` via ``logits_processor``; tokens in the
    forbidden set can never be selected, whatever the decoding strategy.
    """

    def __init__(self, forbidden_token_ids):
        # Materialize once so the same index list is reused on every step.
        self.forbidden_token_ids = list(forbidden_token_ids)

    def __call__(self, input_ids, scores):
        # Fancy-index over the vocab dimension: applies to every sequence
        # in the batch in a single assignment.
        banned = self.forbidden_token_ids
        scores[:, banned] = float('-inf')
        return scores

class LaDisparition:
    """Generate text that never contains the letter 'e' or 'E' (a lipogram,
    after Georges Perec's novel) using ``model.generate`` with a logits
    processor that bans every offending token up front.
    """

    def __init__(self, model, tokenizer, debug=False):
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug
        # Pre-compute the forbidden vocabulary once. A token is banned if its
        # decoded text contains 'e'/'E' or any non-ASCII character (non-ASCII
        # is banned so multi-byte fragments cannot smuggle an 'e' in later).
        # batch_decode decodes the whole vocab in ONE call instead of one
        # tokenizer.decode call per id, which is far faster on large vocabs
        # while producing the same per-id strings.
        # NOTE(review): assumes token ids are contiguous 0..len(vocab)-1 —
        # confirm for tokenizers with added/special tokens beyond get_vocab().
        vocab_size = len(tokenizer.get_vocab())
        decoded_tokens = tokenizer.batch_decode(
            [[token_id] for token_id in range(vocab_size)]
        )
        self.forbidden_token_ids = {
            token_id
            for token_id, text in enumerate(decoded_tokens)
            if 'e' in text.lower() or not all(ord(c) < 128 for c in text)
        }

        self.processor = ForbidTokensLogitsProcessor(self.forbidden_token_ids)

    def __call__(self, prompt, max_tokens=30, beam_width=5):
        """Return up to ``max_tokens`` new tokens of e-free text for ``prompt``.

        Uses beam search (``beam_width`` beams, ``do_sample=False``) so the
        model can route around the banned tokens better than greedy decoding.
        """
        # Tokenize the prompt through the chat template so the model sees
        # its expected conversation format.
        message = [{"role": "user", "content": prompt}]
        inputs = self.tokenizer.apply_chat_template(
            message, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        # apply_chat_template returns only input_ids here; build the matching
        # all-ones attention mask explicitly (single sequence, no padding).
        attention_mask = torch.ones_like(inputs)
        prompt_length = inputs.shape[1]

        outputs = self.model.generate(
            inputs,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            num_beams=beam_width,
            logits_processor=[self.processor],
            do_sample=False,
        )

        # Strip the prompt; return only the newly generated continuation.
        generated_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


# --- EXERCISE 2: The Toulouse Sequence ---

class ForbidToulousePrefixLogitsProcessor(LogitsProcessor):
    """Bans any next token that would extend the word currently being
    generated into a prefix of "Toulouse" of length >= 4 (case-insensitive).

    Strategy: decode the sequence so far, take the trailing run of alphabetic
    characters (the word in progress), then for every candidate token check
    whether word+token decodes to something like "toul", "toulo", ... and
    mask the candidate if so.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.forbidden_word = "toulouse"
        self.min_prefix_len = 4

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Only the first batch entry is inspected (generation here is batch=1).
        text = self.tokenizer.decode(input_ids[0])

        # The word in progress is the trailing run of alphabetic characters.
        start = len(text)
        while start > 0 and text[start - 1].isalpha():
            start -= 1
        word_so_far = text[start:]

        # No partial word yet -> nothing can become a "Toulouse" prefix.
        if not word_so_far:
            return scores

        # Re-encode just the partial word so each candidate token can be
        # appended and decoded JOINTLY — concatenating per-token decoded text
        # is not reliable for byte-level tokenizers.
        word_ids = self.tokenizer.encode(word_so_far, add_special_tokens=False)

        # Scan the whole vocabulary for candidates that would complete a
        # forbidden prefix; mask each one with -inf.
        vocab_size = scores.shape[1]
        for candidate_id in range(vocab_size):
            extended_word = self.tokenizer.decode(word_ids + [candidate_id])
            if len(extended_word) >= self.min_prefix_len and \
               self.forbidden_word.startswith(extended_word.lower()):
                scores[0, candidate_id] = float('-inf')

        return scores

class ToulouseSequence:
    """Generate text that never contains the word 'Toulouse', by masking
    forbidden-prefix tokens during ``model.generate``."""

    def __init__(self, model, tokenizer, debug=False):
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug
        # Prefix-blocking processor: forbids completing "Toul", "Toulo", etc.
        self.processor = ForbidToulousePrefixLogitsProcessor(self.tokenizer)

    def __call__(self, prompt, max_tokens=100):
        """Greedily generate up to ``max_tokens`` Toulouse-free tokens."""
        # Format the prompt through the tokenizer's chat template.
        chat = [{"role": "user", "content": prompt}]
        input_ids = self.tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        # Single unpadded sequence -> attention mask is all ones.
        mask = torch.ones_like(input_ids)
        n_prompt = input_ids.shape[1]

        result = self.model.generate(
            input_ids,
            attention_mask=mask,
            max_new_tokens=max_tokens,
            logits_processor=[self.processor],
            do_sample=False,
        )

        # Drop the prompt tokens and decode only the continuation.
        new_tokens = result[0][n_prompt:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

if __name__ == "__main__":
    # Local smoke test only — the evaluation server provides model and tokenizer.
    model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.float32, device_map="auto"
    )

    # Exercise 1: lipogram generation (no 'e'/'E').
    ex1_generator = LaDisparition(model, tokenizer)
    print("Ex 1 (No 'e'):", ex1_generator("Who are you?"))

    # Exercise 2: never emit the word "Toulouse".
    ex2_generator = ToulouseSequence(model, tokenizer)
    print("Ex 2 (No 'Toulouse'):", ex2_generator("Where is Toulouse?"))