# lipogram_private / calibrate_logprobs.py
# Author: nathanael-fijalkow — "Improved logprob-based scoring" (commit 4d8bbd9)
"""
Calibration script: compute logprobs for reference solution outputs
vs unconstrained model outputs to design a scoring function.
"""
import torch
import torch.nn.functional as F
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
EVAL_MODEL = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(EVAL_MODEL)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(EVAL_MODEL, dtype=torch.float16, device_map="auto")
with open("test_cases.json", "r") as f:
TEST_CASES = json.load(f)
def compute_chat_logprobs(model, tokenizer, prompt, generated_text):
    """
    Compute logprobs using chat template (works for both exercises).
    The prompt is formatted as a chat message, generated_text is the response.

    Args:
        model: causal LM providing `.device` and returning `.logits`.
        tokenizer: matching tokenizer with a chat template.
        prompt: user message text.
        generated_text: candidate response to score.

    Returns:
        mean_logprob: mean log-prob per generated token (-inf if no tokens)
        total_logprob: sum of log-probs
        n_tokens: number of generated tokens
        per_token: list of (token_str, logprob) pairs
    """
    # Empty or whitespace-only responses get the worst possible score.
    if not generated_text or not generated_text.strip():
        return -float('inf'), 0.0, 0, []

    message = [{"role": "user", "content": prompt}]
    prompt_ids = tokenizer.apply_chat_template(
        message, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    prompt_len = prompt_ids.shape[1]

    # NOTE(review): re-tokenizing the response separately may differ slightly
    # from how the model would have tokenized it mid-stream; acceptable for
    # calibration purposes.
    gen_ids = tokenizer.encode(
        generated_text, add_special_tokens=False, return_tensors="pt"
    ).to(model.device)
    full_ids = torch.cat([prompt_ids, gen_ids], dim=1)
    if full_ids.shape[1] <= prompt_len:
        return -float('inf'), 0.0, 0, []

    with torch.no_grad():
        logits = model(full_ids).logits

    # Logits at position i-1 predict token i: keep only positions predicting
    # the generated tokens. Upcast to float32 before log_softmax — the model
    # runs in fp16 and half-precision log_softmax loses accuracy for
    # low-probability tokens.
    pred_logits = logits[:, prompt_len - 1:full_ids.shape[1] - 1, :].float()
    log_probs = F.log_softmax(pred_logits, dim=-1)

    # Vectorized lookup of each generated token's logprob (replaces the
    # per-token Python loop; identical values).
    gen_tokens = full_ids[:, prompt_len:]
    token_logprobs = log_probs.gather(-1, gen_tokens.unsqueeze(-1)).squeeze(-1)[0]

    per_token = [
        (tokenizer.decode([tid.item()]), lp.item())
        for tid, lp in zip(gen_tokens[0], token_logprobs)
    ]
    n_tokens = gen_tokens.shape[1]
    total_logprob = token_logprobs.sum().item()
    mean_logprob = total_logprob / n_tokens  # n_tokens > 0 past the guard above
    return mean_logprob, total_logprob, n_tokens, per_token
def generate_unconstrained_chat(model, tokenizer, prompt, max_tokens=20):
    """Greedy-decode an unconstrained chat response; return only the new text."""
    chat = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    n_prompt = input_ids.shape[1]
    with torch.no_grad():
        out_ids = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_tokens,
            do_sample=False,  # greedy decoding: deterministic baseline
            pad_token_id=tokenizer.pad_token_id,
        )
    # Drop the prompt portion; decode only the freshly generated tokens.
    new_tokens = out_ids[0][n_prompt:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# ---- Load and run the reference solution ----
import importlib.util
import sys
import time

# Timestamped module name so repeated runs in one interpreter session don't
# collide with a stale entry in sys.modules.
module_name = f"solution_module_{int(time.time())}"
spec = importlib.util.spec_from_file_location(module_name, "solution.py")
solution = importlib.util.module_from_spec(spec)
# Register before exec so solution.py can import itself / use dataclasses etc.
sys.modules[module_name] = solution
spec.loader.exec_module(solution)
print("\n" + "="*80)
print("EXERCISE 1: La Disparition (no 'e')")
print("="*80)
ex1_instance = solution.LaDisparition(model, tokenizer)
ex1_results = []
for i, prompt in enumerate(TEST_CASES["exercise_1"]):
# Generate constrained output
constrained_output = ex1_instance(prompt, max_tokens=20)
# Strip prompt from output
if constrained_output.startswith(prompt):
constrained_gen = constrained_output[len(prompt):].strip()
else:
constrained_gen = constrained_output.strip()
# Generate unconstrained output (chat template for instruct model)
unconstrained_gen = generate_unconstrained_chat(model, tokenizer, prompt, max_tokens=20)
# Compute logprobs using chat template (matches how the model should be used)
c_mean, c_total, c_ntok, c_per = compute_chat_logprobs(model, tokenizer, prompt, constrained_gen)
# Compute logprobs for unconstrained output
u_mean, u_total, u_ntok, u_per = compute_chat_logprobs(model, tokenizer, prompt, unconstrained_gen)
delta = c_mean - u_mean # will be negative (constrained is worse)
print(f"\nTest {i+1}: {prompt}")
print(f" Unconstrained: {unconstrained_gen}")
print(f" mean_logprob={u_mean:.4f}, n_tokens={u_ntok}")
print(f" Constrained: {constrained_gen}")
print(f" mean_logprob={c_mean:.4f}, n_tokens={c_ntok}")
print(f" Delta (constrained - unconstrained): {delta:.4f}")
ex1_results.append({
"prompt": prompt,
"constrained_gen": constrained_gen,
"unconstrained_gen": unconstrained_gen,
"c_mean_logprob": c_mean,
"u_mean_logprob": u_mean,
"delta_mean_logprob": delta,
})
print(f"\n--- Exercise 1 Summary ---")
deltas_1 = [r["delta_mean_logprob"] for r in ex1_results]
c_means_1 = [r["c_mean_logprob"] for r in ex1_results]
u_means_1 = [r["u_mean_logprob"] for r in ex1_results]
print(f" Unconstrained mean logprobs: {[f'{x:.3f}' for x in u_means_1]}")
print(f" Constrained mean logprobs: {[f'{x:.3f}' for x in c_means_1]}")
print(f" Deltas: {[f'{x:.3f}' for x in deltas_1]}")
print(f" Mean delta: {sum(deltas_1)/len(deltas_1):.4f}")
print(f" Worst delta: {min(deltas_1):.4f}")
print("\n" + "="*80)
print("EXERCISE 2: Toulouse Sequence (no 'Toulouse')")
print("="*80)
ex2_instance = solution.ToulouseSequence(model, tokenizer)
ex2_results = []
for i, prompt in enumerate(TEST_CASES["exercise_2"]):
# Generate constrained output
constrained_gen = ex2_instance(prompt, max_tokens=20)
# Generate unconstrained output (chat format)
unconstrained_gen = generate_unconstrained_chat(model, tokenizer, prompt, max_tokens=20)
# Compute logprobs (chat format)
c_mean, c_total, c_ntok, c_per = compute_chat_logprobs(model, tokenizer, prompt, constrained_gen)
u_mean, u_total, u_ntok, u_per = compute_chat_logprobs(model, tokenizer, prompt, unconstrained_gen)
delta = c_mean - u_mean
print(f"\nTest {i+1}: {prompt}")
print(f" Unconstrained: {unconstrained_gen}")
print(f" mean_logprob={u_mean:.4f}, n_tokens={u_ntok}")
print(f" Constrained: {constrained_gen}")
print(f" mean_logprob={c_mean:.4f}, n_tokens={c_ntok}")
print(f" Delta (constrained - unconstrained): {delta:.4f}")
ex2_results.append({
"prompt": prompt,
"constrained_gen": constrained_gen,
"unconstrained_gen": unconstrained_gen,
"c_mean_logprob": c_mean,
"u_mean_logprob": u_mean,
"delta_mean_logprob": delta,
})
print(f"\n--- Exercise 2 Summary ---")
deltas_2 = [r["delta_mean_logprob"] for r in ex2_results]
c_means_2 = [r["c_mean_logprob"] for r in ex2_results]
u_means_2 = [r["u_mean_logprob"] for r in ex2_results]
print(f" Unconstrained mean logprobs: {[f'{x:.3f}' for x in u_means_2]}")
print(f" Constrained mean logprobs: {[f'{x:.3f}' for x in c_means_2]}")
print(f" Deltas: {[f'{x:.3f}' for x in deltas_2]}")
print(f" Mean delta: {sum(deltas_2)/len(deltas_2):.4f}")
print(f" Worst delta: {min(deltas_2):.4f}")
print("\n" + "="*80)
print("OVERALL RECOMMENDATION")
print("="*80)
all_deltas = deltas_1 + deltas_2
print(f"All deltas: {[f'{x:.3f}' for x in all_deltas]}")
print(f"Global mean delta: {sum(all_deltas)/len(all_deltas):.4f}")
print(f"Global worst delta: {min(all_deltas):.4f}")
# ---- Save reference scores to CSV ----
import csv

csv_path = "reference_scores.csv"
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow([
        "exercise", "prompt_index", "prompt",
        "unconstrained_logprob", "reference_logprob", "reference_delta"
    ])
    # Both exercises share one row schema: a single parameterized loop
    # replaces the two copy-pasted loops of the original.
    for exercise, results in (("exercise_1", ex1_results), ("exercise_2", ex2_results)):
        for i, r in enumerate(results):
            writer.writerow([
                exercise, i, r["prompt"],
                f"{r['u_mean_logprob']:.6f}",
                f"{r['c_mean_logprob']:.6f}",
                f"{r['delta_mean_logprob']:.6f}",
            ])
print(f"\nReference scores saved to {csv_path}")