# nathanael-fijalkow's picture
# Fix for Transformers v5
# commit 6538c21
import gradio as gr
import importlib.util
import json
import torch
import torch.nn.functional as F
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer
import threading
import sys
import argparse
import time
# 1. SETUP
EVAL_MODEL = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
TIMEOUT_SECONDS = 30  # wall-clock limit (seconds) per student prompt evaluation

# Shared tokenizer/model: used both for logprob scoring and handed to the
# students' generator classes during evaluation.
tokenizer = AutoTokenizer.from_pretrained(EVAL_MODEL)
# Set pad token to prevent warnings and ensure proper attention masking
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    EVAL_MODEL,
    dtype=torch.float16,
    device_map="auto"
)
# Load secret test cases
# (expects keys "exercise_1" and "exercise_2", each a list of prompt strings)
with open("test_cases.json", "r") as f:
    TEST_CASES = json.load(f)
# --- PER-PROMPT REFERENCE SCORING ---
# Load reference scores from CSV (generated by calibrate_logprobs.py from solution.py).
# Each prompt has: unconstrained_logprob (baseline) and reference_delta (solution.py delta).
# Quality = 1 if student is as good or better than solution.py, decreasing for worse.
import csv
import traceback
REFERENCE_SCORES = {}  # key: (exercise, prompt_index) -> per-prompt reference dict
with open("reference_scores.csv", "r") as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        key = (row["exercise"], int(row["prompt_index"]))
        REFERENCE_SCORES[key] = {
            "prompt": row["prompt"],
            "unconstrained_logprob": float(row["unconstrained_logprob"]),
            "reference_logprob": float(row["reference_logprob"]),
            "reference_delta": float(row["reference_delta"]),
        }
def compute_mean_logprob(prompt_text, generated_text):
    """
    Compute the mean log-probability per token of `generated_text`
    conditioned on `prompt_text`, under the unconstrained model.

    Uses the chat template since the evaluation model is an instruct model.
    This measures how "natural" the generated text is: a well-constrained
    generator still produces coherent text (high logprob), while a bad one
    produces gibberish (low logprob).

    Returns:
        (mean_logprob, n_tokens); (-inf, 0) for empty/blank generations or
        when tokenization yields no generated tokens.
    """
    if not generated_text or not generated_text.strip():
        return -float('inf'), 0
    # Always use chat template: the model is an instruct model, so
    # logprobs are meaningful only in the chat context.
    message = [{"role": "user", "content": prompt_text}]
    encoded = tokenizer.apply_chat_template(
        message, add_generation_prompt=True, return_tensors="pt"
    )
    # apply_chat_template may return a plain tensor or a dict-like encoding,
    # depending on the transformers version.
    prompt_ids = (encoded if isinstance(encoded, torch.Tensor) else encoded["input_ids"]).to(model.device)
    gen_ids = tokenizer.encode(
        generated_text, add_special_tokens=False, return_tensors="pt"
    ).to(model.device)
    full_ids = torch.cat([prompt_ids, gen_ids], dim=1)
    prompt_len = prompt_ids.shape[1]
    if full_ids.shape[1] <= prompt_len:
        return -float('inf'), 0
    with torch.no_grad():
        outputs = model(full_ids)
    # Logits at position i-1 predict token i. Slice only the generated span
    # BEFORE the softmax: normalizing the full (seq, vocab) tensor wastes
    # memory/compute on prompt positions we never read. float32 softmax avoids
    # fp16 precision loss in the normalization.
    pred_logits = outputs.logits[0, prompt_len - 1:full_ids.shape[1] - 1]
    log_probs = F.log_softmax(pred_logits.float(), dim=-1)
    # Vectorized per-token logprob lookup (replaces the per-token Python loop).
    gen_tokens = full_ids[0, prompt_len:]
    token_logprobs = log_probs.gather(1, gen_tokens.unsqueeze(1)).squeeze(1)
    n_tokens = int(token_logprobs.shape[0])
    mean_logprob = token_logprobs.sum().item() / n_tokens
    return mean_logprob, n_tokens
def compute_quality_score(mean_logprob, exercise_key, prompt_index):
    """
    Per-prompt quality score in [0, 1] using reference deltas from solution.py.

    Logic:
    - student_delta = mean_logprob - unconstrained_logprob (the baseline);
      both student_delta and reference_delta are normally negative, since
      constrained generation is less likely than unconstrained generation.
    - Quality = 1.0 if student_delta >= reference_delta (student is as good
      as or better than the reference solution).
    - Otherwise quality decays LINEARLY from 1.0 at reference_delta down to
      0.0 at 3x reference_delta (a generous "worst acceptable" margin), and
      is clamped to [0, 1].
    - Falls back to a coarse logprob threshold when no calibration data
      exists for this (exercise_key, prompt_index) pair.
    """
    key = (exercise_key, prompt_index)
    if key not in REFERENCE_SCORES:
        # Fallback: if no reference data, return 1 for any non-terrible logprob
        return 1.0 if mean_logprob > -5.0 else 0.0
    ref = REFERENCE_SCORES[key]
    unconstrained_lp = ref["unconstrained_logprob"]
    ref_delta = ref["reference_delta"]  # negative value
    student_delta = mean_logprob - unconstrained_lp  # negative value
    if student_delta >= ref_delta:
        # Student is as good or better than reference -> quality = 1
        return 1.0
    if ref_delta == 0:
        # Degenerate calibration (zero reference delta): anything worse scores 0.
        return 0.0
    # Student is worse than reference.
    # Linear interpolation between ref_delta (quality=1) and worst_delta (quality=0).
    # Cap at 3x reference delta for quality = 0
    worst_delta = 3.0 * ref_delta  # e.g., ref=-0.9 -> worst=-2.7
    if student_delta <= worst_delta:
        return 0.0
    quality = (student_delta - worst_delta) / (ref_delta - worst_delta)
    return max(0.0, min(1.0, quality))
class TimeoutException(Exception):
    """Raised when a per-prompt evaluation exceeds the TIMEOUT_SECONDS limit."""
    pass
def timeout_handler(signum, frame):
    """signal-style handler that raises TimeoutException.

    NOTE(review): appears to be a leftover from a `signal.alarm()`-based
    timeout implementation — the `signal` module is not imported here and
    timeouts are enforced by the thread-based `run_with_timeout` instead.
    Confirm it is unused before removing.
    """
    raise TimeoutException(f"Prompt evaluation timed out ({TIMEOUT_SECONDS}s limit exceeded)")
def run_with_timeout(func, args=(), kwargs=None, timeout_sec=TIMEOUT_SECONDS):
    """Run func(*args, **kwargs) in a daemon thread under a wall-clock limit.

    Returns the function's result, re-raises any exception it threw, and
    raises TimeoutException when the call has not finished within
    timeout_sec seconds. The worker thread is not killed on timeout; being
    a daemon, it will not block interpreter shutdown.
    """
    kwargs = {} if kwargs is None else kwargs
    outcome = {"value": None, "error": None}

    def _worker():
        try:
            outcome["value"] = func(*args, **kwargs)
        except BaseException as exc:  # capture everything; re-raised in caller
            outcome["error"] = exc

    worker = threading.Thread(target=_worker, daemon=True)
    worker.start()
    worker.join(timeout=timeout_sec)
    if worker.is_alive():
        raise TimeoutException(f"Prompt evaluation timed out ({TIMEOUT_SECONDS}s limit exceeded)")
    if outcome["error"] is not None:
        raise outcome["error"]
    return outcome["value"]
def strip_prompt_from_output(output, prompt):
    """Return the generated text with a leading echo of the prompt removed.

    Comparison happens on whitespace-trimmed copies of both strings. When
    the trimmed output does not begin with the trimmed prompt, the original
    `output` is returned untouched (not trimmed).
    """
    trimmed_output = output.strip()
    trimmed_prompt = prompt.strip()
    # Guard clause: nothing to strip when the prompt was not echoed back.
    if not trimmed_output.startswith(trimmed_prompt):
        return output
    return trimmed_output[len(trimmed_prompt):].strip()
def extract_assistant_response(text):
    """Pull the assistant's turn out of a plain-text chat transcript.

    Lines after a bare "assistant" marker line are collected (blank lines
    dropped, original indentation kept) until a bare "user" or "system"
    marker is reached; the joined result is stripped of surrounding
    whitespace. A "user"/"system" marker ends the scan even if no
    "assistant" marker was seen yet, yielding an empty string.
    """
    collected = []
    capturing = False
    for raw_line in text.split('\n'):
        marker = raw_line.strip()
        if marker == "assistant":
            # The marker line itself is not part of the response.
            capturing = True
        elif marker in ("user", "system"):
            # Next turn begins: stop scanning.
            break
        elif capturing and marker:
            collected.append(raw_line)
    return '\n'.join(collected).strip()
def test_raw_outputs(debug=False):
    """Print unconstrained (mask-free) model samples for every test prompt.

    Debugging aid: shows what the base model produces with plain sampling so
    it can be compared against students' constrained generators.

    Args:
        debug: accepted for CLI symmetry with evaluate_submission; unused here.
    """
    print(f"\n{'='*60}")
    print("RAW MODEL OUTPUTS")
    print(f"{'='*60}\n")
    # --- EXERCISE 1 RAW ---
    print("### Exercise 1 - Raw Outputs:")
    _generate_raw_outputs(TEST_CASES["exercise_1"])
    # --- EXERCISE 2 RAW ---
    print("\n### Exercise 2 - Raw Outputs:")
    _generate_raw_outputs(TEST_CASES["exercise_2"])


def _generate_raw_outputs(prompts):
    """Sample one unconstrained completion per prompt and print it, numbered.

    Extracted helper: the original exercise-1 and exercise-2 loops were
    byte-for-byte duplicates.
    """
    for i, prompt in enumerate(prompts):
        try:
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            # Plain sampling, no logit mask; eos disabled so we always get
            # max_new_tokens of text to inspect.
            output = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=20,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                eos_token_id=None,
                pad_token_id=tokenizer.pad_token_id
            )
            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            cleaned = strip_prompt_from_output(decoded, prompt)
            assistant_response = extract_assistant_response(cleaned)
            print(f"{i+1}. {assistant_response}")
        except Exception as e:
            print(f"{i+1}. ERROR: {str(e)}")
def evaluate_submission(file_obj, debug=False):
    """Evaluate one student submission against both lipogram exercises.

    Dynamically imports the submitted Python file, instantiates its
    LaDisparition and ToulouseSequence classes with the shared model and
    tokenizer, runs each over the secret test prompts under a per-prompt
    timeout, and scores both constraint compliance and fluency (logprob-based
    quality against the reference solution).

    Args:
        file_obj: path string (local CLI mode) or Gradio file object (web mode).
        debug: when True, per-test details are printed and appended to the report.

    Returns:
        A markdown report string (also the Gradio interface output).
    """
    if file_obj is None:
        return "No file provided."
    try:
        # 2. ISOLATED LOADING
        # We use a unique name for each import to avoid namespace collisions
        file_path = file_obj if isinstance(file_obj, str) else file_obj.name
        # Always print who is being evaluated
        print(f"\n{'='*60}")
        print(f"EVALUATING: {file_path}")
        print(f"{'='*60}\n")
        # Clear bytecode cache to prevent "unmarshallable object" errors
        from pathlib import Path
        import shutil
        pycache = Path(file_path).parent / "__pycache__"
        if pycache.exists():
            shutil.rmtree(pycache, ignore_errors=True)
            print("### Cleared bytecode cache.")
        # Import with a unique module name each time
        module_name = f"student_module_{int(time.time() * 1000000)}"
        # Disable bytecode writing to prevent permission issues on temp directories
        old_dont_write_bytecode = sys.dont_write_bytecode
        sys.dont_write_bytecode = True
        try:
            spec = importlib.util.spec_from_file_location(module_name, file_path)
            student_module = importlib.util.module_from_spec(spec)
            # Registering in sys.modules lets the student file use relative
            # machinery (dataclasses, pickling, etc.) that looks itself up.
            sys.modules[module_name] = student_module
            spec.loader.exec_module(student_module)
        except Exception as e:
            print(f"ERROR during module exec: {type(e).__name__}: {str(e)}")
            traceback.print_exc()
            raise
        finally:
            # Always restore the interpreter-wide flag.
            sys.dont_write_bytecode = old_dont_write_bytecode
        report = [f"## Results:\n"]
        print("### Loaded student module successfully.")
        # --- EXERCISE 1 ---
        ex1_passed = 0
        ex1_timeout = False
        ex1_outputs = []
        ex1_quality_scores = []
        try:
            print("### EXERCISE 1 - La Disparition (No 'e')")
            ex1_instance = student_module.LaDisparition(model, tokenizer)
            for i, prompt in enumerate(TEST_CASES["exercise_1"]):
                try:
                    print(f"\nTest {i+1}/{len(TEST_CASES['exercise_1'])}")
                    print(f"Prompt: {prompt}")
                    # We limit tokens to keep evaluation fast
                    output = run_with_timeout(
                        ex1_instance,
                        args=(prompt,),
                        kwargs={"max_tokens": 20},
                        timeout_sec=TIMEOUT_SECONDS
                    )
                    # Remove prompt from output to only validate generated text
                    cleaned_output = strip_prompt_from_output(output, prompt)
                    print(f"Response: {cleaned_output}")
                    # Constraint: no letter 'e' anywhere, and non-trivial length.
                    passed = 'e' not in cleaned_output.lower() and len(cleaned_output.strip()) > 3
                    # Compute logprob quality score (0 when the constraint failed).
                    mean_lp, n_tok = compute_mean_logprob(prompt, cleaned_output)
                    quality = compute_quality_score(mean_lp, "exercise_1", i) if passed else 0.0
                    ex1_quality_scores.append(quality)
                    print(f" Constraint passed: {passed} | mean_logprob: {mean_lp:.3f} | quality: {quality:.2f}")
                    if passed:
                        ex1_passed += 1
                    ex1_outputs.append({
                        "prompt": prompt, "output": cleaned_output, "passed": passed,
                        "mean_logprob": mean_lp, "quality": quality
                    })
                    if debug:
                        print(f"Ex1 Test {i+1}: {'βœ“' if passed else 'βœ—'}")
                        print(f" Prompt: {prompt}")
                        print(f" Output: {output}")
                        print(f" mean_logprob={mean_lp:.4f}, quality={quality:.2f}")
                        print()
                except TimeoutException:
                    # One timeout aborts the whole exercise (remaining prompts skipped).
                    ex1_timeout = True
                    ex1_outputs.append({"prompt": prompt, "output": "TIMEOUT", "passed": False, "mean_logprob": float('-inf'), "quality": 0.0})
                    ex1_quality_scores.append(0.0)
                    print(f"Result: βœ— TIMEOUT")
                    break
            ex1_avg_quality = sum(ex1_quality_scores) / len(ex1_quality_scores) if ex1_quality_scores else 0.0
            print(f"\nExercise 1 Score: {ex1_passed}/5 | Avg quality: {ex1_avg_quality:.2f}")
            if ex1_timeout:
                report.append(f" **Ex 1 (No 'e'):** TIMEOUT - evaluation exceeded {TIMEOUT_SECONDS}s limit")
            else:
                report.append(f" **Ex 1 (No 'e'):** {ex1_passed}/5 correct | Quality: {ex1_avg_quality:.0%}")
            if debug:
                report.append("\n### Ex 1 Outputs:")
                for i, out in enumerate(ex1_outputs):
                    lp_str = f"logprob={out['mean_logprob']:.2f}" if out['mean_logprob'] != float('-inf') else "logprob=N/A"
                    report.append(f"{i+1}. {'βœ“' if out['passed'] else 'βœ—'} [{lp_str}, q={out['quality']:.2f}] `{out['output']}`")
        except Exception as e:
            # Instantiation or other non-per-prompt failure: recorded in the report.
            tb = traceback.format_exc()
            print(f"Ex 1 outer exception:\n{tb}")
            report.append(f" **Ex 1 Error:** {str(e) or type(e).__name__}\n```\n{tb}\n```")
        # --- EXERCISE 2 --- (same flow as exercise 1, different constraint)
        ex2_passed = 0
        ex2_timeout = False
        ex2_outputs = []
        ex2_quality_scores = []
        try:
            print("\n### EXERCISE 2 - Toulouse Sequence (No 'Toulouse')")
            ex2_instance = student_module.ToulouseSequence(model, tokenizer)
            for i, prompt in enumerate(TEST_CASES["exercise_2"]):
                try:
                    print(f"\nTest {i+1}/{len(TEST_CASES['exercise_2'])}")
                    print(f"Prompt: {prompt}")
                    output = run_with_timeout(
                        ex2_instance,
                        args=(prompt,),
                        kwargs={"max_tokens": 20},
                        timeout_sec=TIMEOUT_SECONDS
                    )
                    # Remove prompt from output to only validate generated text
                    cleaned_output = strip_prompt_from_output(output, prompt)
                    print(f"Response: {cleaned_output}")
                    # Constraint: the word "Toulouse" must not appear (case-insensitive).
                    passed = "toulouse" not in cleaned_output.lower() and len(cleaned_output.strip()) > 3
                    # Compute logprob quality score
                    mean_lp, n_tok = compute_mean_logprob(prompt, cleaned_output)
                    quality = compute_quality_score(mean_lp, "exercise_2", i) if passed else 0.0
                    ex2_quality_scores.append(quality)
                    print(f" Constraint passed: {passed} | mean_logprob: {mean_lp:.3f} | quality: {quality:.2f}")
                    if passed:
                        ex2_passed += 1
                    ex2_outputs.append({
                        "prompt": prompt, "output": cleaned_output, "passed": passed,
                        "mean_logprob": mean_lp, "quality": quality
                    })
                    if debug:
                        print(f"Ex2 Test {i+1}: {'βœ“' if passed else 'βœ—'}")
                        print(f" Prompt: {prompt}")
                        print(f" Output: {output}")
                        print(f" mean_logprob={mean_lp:.4f}, quality={quality:.2f}")
                        print()
                except TimeoutException:
                    ex2_timeout = True
                    ex2_outputs.append({"prompt": prompt, "output": "TIMEOUT", "passed": False, "mean_logprob": float('-inf'), "quality": 0.0})
                    ex2_quality_scores.append(0.0)
                    print(f"Result: βœ— TIMEOUT")
                    break
            ex2_avg_quality = sum(ex2_quality_scores) / len(ex2_quality_scores) if ex2_quality_scores else 0.0
            print(f"\nExercise 2 Score: {ex2_passed}/5 | Avg quality: {ex2_avg_quality:.2f}")
            if ex2_timeout:
                report.append(f" **Ex 2 (No Toulouse):** TIMEOUT - evaluation exceeded {TIMEOUT_SECONDS}s limit")
            else:
                report.append(f" **Ex 2 (No Toulouse):** {ex2_passed}/5 correct | Quality: {ex2_avg_quality:.0%}")
            if debug:
                report.append("\n### Ex 2 Outputs:")
                for i, out in enumerate(ex2_outputs):
                    lp_str = f"logprob={out['mean_logprob']:.2f}" if out['mean_logprob'] != float('-inf') else "logprob=N/A"
                    report.append(f"{i+1}. {'βœ“' if out['passed'] else 'βœ—'} [{lp_str}, q={out['quality']:.2f}] `{out['output']}`")
        except Exception as e:
            tb = traceback.format_exc()
            print(f"Ex 2 outer exception:\n{tb}")
            report.append(f" **Ex 2 Error:** {str(e) or type(e).__name__}\n```\n{tb}\n```")
        # 3. CLEANUP (Crucial for 200 students!)
        # Drop the student's module and reclaim GPU memory before the next submission.
        del student_module
        gc.collect()
        torch.cuda.empty_cache()
        return "\n".join(report)
    except Exception as e:
        return f"### System Error during import:\n{str(e)}"
# 4. LAUNCH WITH CONCURRENCY CONTROL
if __name__ == "__main__":
    # CLI: --local <file> for offline grading, --raw to sample the bare model,
    # no flags to launch the Gradio web interface.
    parser = argparse.ArgumentParser(description="Evaluate lipogram solutions")
    parser.add_argument("--local", type=str, help="Path to solution file for local testing")
    parser.add_argument("--debug", action="store_true", help="Enable debug output")
    parser.add_argument("--raw", action="store_true", help="Test raw model outputs without mask")
    args = parser.parse_args()
    if args.raw:
        # Raw output testing mode
        test_raw_outputs()
    elif args.local:
        # Local testing mode
        print(f"\n{'='*60}")
        print(f"Testing solution: {args.local}")
        print(f"{'='*60}\n")
        result = evaluate_submission(args.local, debug=args.debug)
        print(f"\n{'='*60}")
        print("FINAL REPORT:")
        print(f"{'='*60}")
        print(result)
    else:
        # Gradio web interface mode. Concurrency limit of 1 ensures only one
        # submission holds the GPU at a time.
        demo = gr.Interface(
            fn=evaluate_submission,
            inputs=gr.File(label="Submission File"),
            outputs="markdown",
            api_name="predict"
        )
        demo.queue(default_concurrency_limit=1).launch()