# lipogram_private / solution.py
# Last commit 6538c21 by nathanael-fijalkow: "Fix for Transformers v5"
from typing import List, Tuple
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- EXERCISE 1: La disparition (No 'e' or 'E') ---
class LaDisparition:
    """
    Generate text without ever using the letter 'e' or 'E'.

    For this, you must use model() directly: model(input_ids) yields logits.
    You need to manually adjust the logits to forbid tokens containing 'e' or 'E'.
    REQUIREMENT: Do NOT use model.generate().

    Decoding is a beam search. At every step each unfinished hypothesis is
    extended with its `beam_width` best allowed tokens. Hypotheses are scored by
    the sum of log-probabilities: log-softmax over the full vocabulary, with
    forbidden tokens then pinned to -inf. The `beam_width` best hypotheses
    (finished or not) are kept.
    """

    def __init__(self, model: "AutoModelForCausalLM", tokenizer: "AutoTokenizer", debug: bool = False):
        """
        Args:
            model: causal LM; `model(input_ids).logits` is indexed as
                (batch, seq, vocab) by `__call__`.
            tokenizer: tokenizer matching `model`.
            debug: if True, print every kept beam after generation.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug

        eos_id = tokenizer.eos_token_id
        special_ids = set(tokenizer.all_special_ids)
        # Ids that must never be generated, and ids that may be generated.
        self.forbidden_token_ids = set()
        self.allowed_token_ids = set()
        # Walk the ids actually present in the vocabulary (they need not be
        # contiguous) and judge each one by its decoded output, not its vocab string.
        for token_id in sorted(set(tokenizer.get_vocab().values())):
            if token_id == eos_id:
                # EOS is only a stop signal and is stripped from the returned text,
                # so it must stay reachable even if its text (e.g. "<|im_end|>")
                # contains an 'e'. Otherwise generation could never stop early.
                self.allowed_token_ids.add(token_id)
                continue
            decoded = tokenizer.decode([token_id])
            # Forbid other special/control tokens, anything containing 'e'/'E',
            # and any non-ASCII text, which might hide an 'e' (e.g. 'é').
            if token_id in special_ids or "e" in decoded.lower() or not decoded.isascii():
                self.forbidden_token_ids.add(token_id)
            else:
                self.allowed_token_ids.add(token_id)

    def _forbidden_mask(self, vocab_size: int, device) -> torch.Tensor:
        """Return a bool mask of length `vocab_size`; True means "must not be generated".

        Everything not explicitly allowed is forbidden. This includes logit rows
        with no tokenizer entry (the logits can be wider than the tokenizer
        vocabulary), because their text cannot be checked.
        """
        mask = torch.ones(vocab_size, dtype=torch.bool, device=device)
        allowed = [i for i in self.allowed_token_ids if 0 <= i < vocab_size]
        if allowed:
            mask[allowed] = False
        return mask

    def __call__(self, prompt, max_tokens=20, beam_width=5):
        """
        Answer `prompt` with at most `max_tokens` new tokens, none containing an 'e'.

        Returns the decoded continuation of the best hypothesis, with special
        tokens removed and surrounding whitespace stripped. Returns "" when
        nothing can be generated (e.g. `max_tokens` or `beam_width` <= 0).
        """
        # Option 2: we use self.tokenizer.apply_chat_template to tokenize the prompt
        message = [{"role": "user", "content": prompt}]
        encoded = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt")
        # Depending on the transformers version this is a tensor or a dict-like encoding.
        input_ids = (encoded if isinstance(encoded, torch.Tensor) else encoded["input_ids"]).to(self.model.device)
        prompt_len = input_ids.shape[1]
        eos_id = self.tokenizer.eos_token_id

        # Each hypothesis: (token ids incl. prompt, cumulative log-prob, finished?).
        beams: List[Tuple[List[int], float, bool]] = [(input_ids[0].tolist(), 0.0, False)]
        forbidden_mask = None  # built on the first step, once the logits width is known

        for _ in range(max_tokens):
            active = [beam for beam in beams if not beam[2]]
            if not active:
                break  # every kept hypothesis has emitted EOS

            # Unfinished hypotheses all grow by exactly one token per step, so they
            # have equal length and can share one batched forward pass.
            batch = torch.tensor([seq for seq, _, _ in active], device=self.model.device)
            with torch.no_grad():
                logits = self.model(batch).logits[:, -1, :]
            if forbidden_mask is None:
                forbidden_mask = self._forbidden_mask(logits.shape[-1], logits.device)

            log_probs = F.log_softmax(logits.float(), dim=-1)
            # Forbidden tokens get -inf so they can never be selected.
            log_probs[:, forbidden_mask] = -float("inf")
            top_k = min(beam_width, int((~forbidden_mask).sum().item()))
            if top_k <= 0:
                break  # no allowed token at all: keep the current beams
            top_log_probs, top_indices = torch.topk(log_probs, top_k, dim=-1)

            # Finished hypotheses compete, unchanged, with the new extensions.
            candidates = [beam for beam in beams if beam[2]]
            for (seq, score, _), ids, lps in zip(active, top_indices.tolist(), top_log_probs.tolist()):
                for token_id, token_lp in zip(ids, lps):
                    candidates.append((seq + [token_id], score + token_lp, token_id == eos_id))
            candidates.sort(key=lambda cand: cand[1], reverse=True)
            beams = candidates[:beam_width]

        if self.debug:
            print(f"\n[DEBUG Ex1] Total beams: {len(beams)}")
            for i, (seq, log_prob, _) in enumerate(beams):
                decoded = self.tokenizer.decode(seq, skip_special_tokens=True)
                print(f" Beam {i}: log_prob={log_prob:.4f} | {decoded}")

        # Best hypothesis, generated part only. EOS is a special token and is skipped.
        best_seq = beams[0][0]
        return self.tokenizer.decode(best_seq[prompt_len:], skip_special_tokens=True).strip()
# --- EXERCISE 2: The Toulouse Sequence ---
class ToulouseSequence:
    """
    Generate text without ever using the word 'Toulouse'.
    For this, you must use model() directly: model(input_ids) yields logits.
    We mask out all tokens that if added would lead to a prefix of "Toulouse" of length at least 4.
    REQUIREMENT: Do NOT use model.generate().

    Words are maximal runs of alphabetic characters (str.isalpha), compared
    case-insensitively. A token is masked when any word it writes to would be
    a prefix of the forbidden word at least `min_prefix_len` characters long.
    The words a token writes to are:
      - the word currently being written, extended by the token's leading letters;
      - any word that starts inside the token (e.g. " Toulouse", " Toul").
    Decoding is greedy.
    """

    def __init__(self, model, tokenizer, debug=False):
        self.model = model
        self.tokenizer = tokenizer
        self.debug = debug
        self.forbidden_word = "Toulouse"
        self.min_prefix_len = 4
        # Per-token lookup tables, (re)built lazily by _get_forbidden_mask.
        self._token_index = None

    def _get_current_word_prefix(self, decoded_sequence: str) -> str:
        """Find the suffix since the last non-alphabetical character."""
        start = len(decoded_sequence)
        while start > 0 and decoded_sequence[start - 1].isalpha():
            start -= 1
        return decoded_sequence[start:]

    def _is_forbidden_prefix(self, word: str) -> bool:
        """True if `word` is a case-insensitive prefix of the forbidden word of length >= min_prefix_len."""
        return len(word) >= self.min_prefix_len and self.forbidden_word.lower().startswith(word.lower())

    def _build_token_index(self):
        """Decode every vocabulary token once and index what _get_forbidden_mask needs.

        Returns (key, always_forbidden, by_lead):
          - key: (lowercased forbidden word, min_prefix_len) the index was built for;
          - always_forbidden: ids of tokens containing a word that is by itself a
            forbidden prefix. Such pieces are blocked even when they would glue
            onto a preceding word. That over-blocking is harmless, and it also
            covers tokenizers whose single-token decode may drop a leading space;
          - by_lead: lowercased leading alphabetic run -> ids of tokens starting
            with it. Only runs occurring inside the forbidden word are kept,
            since only those can extend the current word into a forbidden prefix.
        """
        target = self.forbidden_word.lower()
        always_forbidden = set()
        by_lead = {}
        for token_id in sorted(set(self.tokenizer.get_vocab().values())):
            text = self.tokenizer.decode([token_id])
            # Alphabetic runs of the token (non-letters act as word separators).
            words = "".join(c if c.isalpha() else " " for c in text).split()
            if any(self._is_forbidden_prefix(w) for w in words):
                always_forbidden.add(token_id)
            if words and text[:1].isalpha():
                # The token starts with letters, so they glue onto the current word.
                lead = words[0].lower()
                if lead in target:
                    by_lead.setdefault(lead, []).append(token_id)
        return (target, self.min_prefix_len), always_forbidden, by_lead

    def _get_forbidden_mask(self, seq: List[int], vocab_size=None, device=None) -> torch.Tensor:
        """
        Create a mask for tokens that would create a forbidden prefix of 'Toulouse'.
        Returns a boolean tensor where True means the token should be forbidden.

        Args:
            seq: token ids so far (prompt included); decoded to find the current word.
            vocab_size: mask length; defaults to the tokenizer vocabulary size.
                Pass the logits width so the mask always lines up with the logits.
            device: device of the mask; defaults to the model's device.
        """
        key = (self.forbidden_word.lower(), self.min_prefix_len)
        index = getattr(self, "_token_index", None)
        if index is None or index[0] != key:
            # Built once; rebuilt only if forbidden_word / min_prefix_len change.
            index = self._token_index = self._build_token_index()
        _, always_forbidden, by_lead = index

        blocked = set(always_forbidden)
        target = key[0]
        current = self._get_current_word_prefix(self.tokenizer.decode(seq)).lower()
        # Leading letters can only extend the current word into a forbidden prefix
        # if the current word (possibly empty) is itself still a prefix of the target.
        if target.startswith(current):
            rest = target[len(current):]
            # Block every lead L with current + L a prefix of length >= min_prefix_len.
            for k in range(max(1, self.min_prefix_len - len(current)), len(rest) + 1):
                blocked.update(by_lead.get(rest[:k], ()))

        if vocab_size is None:
            vocab_size = len(self.tokenizer.get_vocab())
        if device is None:
            device = self.model.device
        forbidden_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
        ids = [i for i in blocked if 0 <= i < vocab_size]
        if ids:
            forbidden_mask[ids] = True
        return forbidden_mask

    def __call__(self, prompt, max_tokens=20):
        """Greedily answer `prompt` with at most `max_tokens` new tokens, never writing 'Toulouse'.

        Returns the decoded generated text (special tokens skipped, whitespace stripped).
        """
        # Option 2: we use self.tokenizer.apply_chat_template to tokenize the prompt
        message = [{"role": "user", "content": prompt}]
        encoded = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt")
        # Depending on the transformers version this is a tensor or a dict-like encoding.
        inputs = (encoded if isinstance(encoded, torch.Tensor) else encoded["input_ids"]).to(self.model.device)
        prompt_length = inputs.shape[1]

        seq = inputs[0].tolist()
        for step in range(max_tokens):
            input_tensor = torch.tensor([seq], device=self.model.device)
            with torch.no_grad():
                logits = self.model(input_tensor).logits[0, -1, :].clone()
            # Size and place the mask from the logits so the two always line up.
            forbidden_mask = self._get_forbidden_mask(seq, vocab_size=logits.shape[-1], device=logits.device)
            if self.debug:
                preferred = int(torch.argmax(logits).item())
                if forbidden_mask[preferred]:
                    print(f"[DEBUG Ex2] step {step}: blocked {self.tokenizer.decode([preferred])!r}")
            logits[forbidden_mask] = float("-inf")
            # Greedy decoding
            next_token = int(torch.argmax(logits).item())
            if next_token == self.tokenizer.eos_token_id:
                break
            seq.append(next_token)

        # Only the generated tokens (skip the prompt).
        generated_tokens = seq[prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
if __name__ == "__main__":
    # Local smoke test only; the evaluation server provides its own model and tokenizer.
    MODEL_NAME = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")

    # Exercise 1: lipogram without the letter 'e'.
    lipogram = LaDisparition(model, tokenizer)
    lipogram_answer = lipogram("Who are you?")
    print("Ex 1 (No 'e'):", lipogram_answer)

    # Exercise 2: never write the word 'Toulouse' (debug output enabled).
    no_toulouse = ToulouseSequence(model, tokenizer, debug=True)
    toulouse_answer = no_toulouse("Where is Toulouse?")
    print("Ex 2 (No 'Toulouse'):", toulouse_answer)