# lipogram_private/greedy.py
"""
Greedy/naive solution for comparison.
- Exercise 1: greedy decoding with token-level 'e' masking (same idea, simpler than beam search)
- Exercise 2: naive approach: forbid the first token of "Toulouse" and " Toulouse"
(tokens 'T' and ' T'), which is very aggressive and blocks ALL T-starting words.
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- EXERCISE 1: La disparition (No 'e' or 'E') ---
class LaDisparition:
"""Greedy constrained generation: forbid tokens containing 'e', pick argmax."""
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
self.model = model
self.tokenizer = tokenizer
        # Build the forbidden-token set once, up front.
        self.forbidden_token_ids = set()
        vocab = self.tokenizer.get_vocab()
        special_ids = set(self.tokenizer.all_special_ids)
        for token_id in range(len(vocab)):
            if token_id in special_ids:
                continue  # keep EOS usable: "<|im_end|>" contains an 'e'
            decoded = self.tokenizer.decode([token_id])
            # Forbid tokens containing 'e'/'E', and non-ASCII tokens
            # (which could smuggle in accented variants such as 'é').
            if 'e' in decoded.lower() or not all(ord(c) < 128 for c in decoded):
                self.forbidden_token_ids.add(token_id)
def __call__(self, prompt, max_tokens=20):
message = [{"role": "user", "content": prompt}]
input_ids = self.tokenizer.apply_chat_template(
message, add_generation_prompt=True, return_tensors="pt"
).to(self.model.device)
prompt_len = input_ids.shape[1]
seq = input_ids[0].tolist()
forbidden_list = list(self.forbidden_token_ids)
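        # Naive decoding loop: re-run the full forward pass at every step (no
        # KV cache) and mask forbidden tokens to -inf before taking the argmax.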
for step in range(max_tokens):
input_tensor = torch.tensor([seq], device=self.model.device)
with torch.no_grad():
outputs = self.model(input_tensor)
logits = outputs.logits[0, -1, :].clone()
# Mask forbidden tokens
logits[forbidden_list] = -float('inf')
next_token = torch.argmax(logits).item()
if next_token == self.tokenizer.eos_token_id:
break
seq.append(next_token)
generated_tokens = seq[prompt_len:]
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
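# --- Hedged aside: the same masking via generate() ---
# Sketch only, assuming a recent transformers version: GenerationConfig has a
# `suppress_tokens` field that installs a SuppressTokensLogitsProcessor, i.e.
# the listed ids are set to -inf at every step, so greedy search should match
# the manual argmax loop above. Not part of the exercise, which asks for
# explicit token-level masking.
def la_disparition_via_generate(model, tokenizer, prompt, max_tokens=20):
    forbidden = LaDisparition(model, tokenizer).forbidden_token_ids  # reuse the mask
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True, return_tensors="pt",
    ).to(model.device)
    output = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        do_sample=False,  # greedy, matching the argmax loop above
        suppress_tokens=list(forbidden),
    )
    return tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True).strip()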
# --- EXERCISE 2: The Toulouse Sequence (naive approach) ---
class ToulouseSequence:
"""
Naive approach: forbid the first token of "Toulouse" and " Toulouse".
"Toulouse" tokenizes as [T(68)][oul(9226)][ouse(1368)]
" Toulouse" tokenizes as [ T(312)][oul(9226)][ouse(1368)]
By forbidding tokens 68 ('T') and 312 (' T'), we block the model from
ever starting the word "Toulouse". This is very aggressive: it also blocks
ALL words starting with 'T' (e.g., "The", "This", "That", "They", ...).
"""
def __init__(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer):
self.model = model
self.tokenizer = tokenizer
# Find the first token of "Toulouse" and " Toulouse"
toulouse_ids = self.tokenizer.encode("Toulouse", add_special_tokens=False)
space_toulouse_ids = self.tokenizer.encode(" Toulouse", add_special_tokens=False)
self.forbidden_token_ids = {toulouse_ids[0], space_toulouse_ids[0]}
print(f"[ToulouseSequence naive] Forbidden first tokens: {self.forbidden_token_ids}")
def __call__(self, prompt, max_tokens=20):
message = [{"role": "user", "content": prompt}]
inputs = self.tokenizer.apply_chat_template(
message, add_generation_prompt=True, return_tensors="pt"
).to(self.model.device)
prompt_length = inputs.shape[1]
seq = inputs[0].tolist()
forbidden_list = list(self.forbidden_token_ids)
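        # Same naive masked-argmax loop as LaDisparition.__call__, with a much
        # smaller forbidden set (two token ids).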
for step in range(max_tokens):
input_tensor = torch.tensor([seq], device=self.model.device)
with torch.no_grad():
outputs = self.model(input_tensor)
logits = outputs.logits[0, -1, :].clone()
# Mask forbidden tokens
logits[forbidden_list] = -float('inf')
next_token = torch.argmax(logits).item()
if next_token == self.tokenizer.eos_token_id:
break
seq.append(next_token)
generated_tokens = seq[prompt_length:]
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
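# --- Hedged aside: a sequence-aware alternative ---
# Sketch only, assuming transformers' NoBadWordsLogitsProcessor behaves as
# documented: passing `bad_words_ids` to generate() bans the multi-token
# sequences for "Toulouse" / " Toulouse" instead of their first tokens. A
# token is only masked once the generated suffix already matches a prefix of
# a banned sequence, so "The", "This", "That", ... remain available. Caveat:
# this operates at the token level, so alternative tokenizations spelling the
# same string (or case variants like "TOULOUSE") are not covered.
def toulouse_sequence_aware(model, tokenizer, prompt, max_tokens=20):
    bad_words_ids = [
        tokenizer.encode("Toulouse", add_special_tokens=False),
        tokenizer.encode(" Toulouse", add_special_tokens=False),
    ]
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True, return_tensors="pt",
    ).to(model.device)
    output = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        do_sample=False,  # greedy, matching the loops above
        bad_words_ids=bad_words_ids,
    )
    return tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True).strip()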
if __name__ == "__main__":
MODEL_NAME = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
print("=== Exercise 1: La Disparition (no 'e') ===")
ex1 = LaDisparition(model, tokenizer)
for prompt in ["Who is the king of the jungle?", "Name a fruit that is red."]:
result = ex1(prompt)
has_e = 'e' in result.lower()
print(f" Q: {prompt}")
print(f" A: {result}")
print(f" {'βœ— FAIL' if has_e else 'βœ“ PASS'}\n")
print("=== Exercise 2: No Toulouse (naive) ===")
ex2 = ToulouseSequence(model, tokenizer)
for prompt in [
"Where is the headquarters of Airbus located?",
"In which French city can you find the Place du Capitole?",
]:
result = ex2(prompt)
has_toulouse = 'toulouse' in result.lower()
print(f" Q: {prompt}")
print(f" A: {result}")
print(f" {'βœ— FAIL' if has_toulouse else 'βœ“ PASS'}\n")