"""Evaluation utilities for Reframr-RFM-v1-Base (reframr/evaluation.py)."""

import json
from pathlib import Path

from .model import ReframrModel


def load_manifest(path: str | Path) -> dict[str, object]:
    """Load an evaluation manifest from a UTF-8 encoded JSON file."""
    return json.loads(Path(path).read_text(encoding="utf-8"))
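

# A minimal manifest sketch, inferred from how `evaluate_manifest` below reads
# the structure; the manifest shipped with the checkpoint may carry extra
# keys, and the sample values here are purely illustrative:
#
#     {
#         "name": "reframr-eval",
#         "splits": {
#             "memorization": [
#                 {"context": "The sky is", "expected": "blue"}
#             ],
#             "generalization": [
#                 {"context": "Water freezes at", "expected": "zero"}
#             ],
#             "open_ended": []
#         }
#     }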


def _expected_next_token(model: ReframrModel, expected_text: str) -> str:
    """First token of ``expected_text``, encoded with a leading space so it
    lines up with how a continuation token follows a context."""
    assert model.tokenizer is not None
    encoded = model.tokenizer.encode(f" {expected_text}")
    return encoded[0] if encoded else ""


def _normalize_text(text: str) -> str:
    """Casefold and collapse whitespace runs to single spaces."""
    return " ".join(text.casefold().split())


def _word_ngrams(words: list[str], size: int) -> list[tuple[str, ...]]:
    """Return all contiguous word n-grams of ``size`` (empty if not possible)."""
    if size <= 0 or len(words) < size:
        return []
    return [tuple(words[index : index + size]) for index in range(len(words) - size + 1)]


def _distinct_ratio(words: list[str], size: int) -> float:
    """Fraction of word n-grams that are unique (distinct-n diversity)."""
    grams = _word_ngrams(words, size)
    if not grams:
        return 0.0
    return len(set(grams)) / len(grams)


def _repetition_ratio(words: list[str], size: int) -> float:
    """Fraction of word n-grams that repeat an earlier occurrence."""
    grams = _word_ngrams(words, size)
    if not grams:
        return 0.0
    repeated = len(grams) - len(set(grams))
    return repeated / len(grams)
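

# Worked example for the two diversity helpers above: with
# words = ["the", "cat", "sat", "the", "cat"], the bigrams are
# ("the", "cat"), ("cat", "sat"), ("sat", "the"), ("the", "cat"): four grams,
# three distinct, so _distinct_ratio(words, 2) == 0.75 and
# _repetition_ratio(words, 2) == 0.25.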


def _open_ended_score(
    model: ReframrModel,
    sample: dict[str, object],
    *,
    reasoning_mode: str | None,
) -> dict[str, object]:
    """Score one open-ended sample against its rubric.

    The score averages group coverage, a minimum-length flag, a novelty flag
    (not an exact copy of a banned phrase), and, when required, a punctuation
    flag.
    """
    generated = model.generate_text(
        str(sample["context"]),
        max_tokens=int(sample.get("max_tokens", 56)),
        reasoning_mode=reasoning_mode,
    )
normalized = _normalize_text(generated)
    # A required group counts as satisfied when any one of its (casefolded)
    # terms appears in the normalized generation.
    required_groups = [
        [str(term).casefold() for term in group]
        for group in sample.get("required_groups", [])
    ]
    satisfied_groups = sum(
        1
        for group in required_groups
        if any(term in normalized for term in group)
    )
    group_coverage = (
        satisfied_groups / len(required_groups) if required_groups else 0.0
    )
punctuation_hit = any(mark in generated for mark in ".,;:?!")
min_words = int(sample.get("min_words", 12))
min_word_hit = len(generated.split()) >= min_words
    # Novelty fails only when the whole generation is an exact (normalized)
    # copy of a banned phrase; partial overlap is still allowed.
    banned_phrases = [str(phrase) for phrase in sample.get("banned_phrases", [])]
    exact_copy = any(normalized == _normalize_text(phrase) for phrase in banned_phrases)
    novelty_hit = not exact_copy
    require_punctuation = bool(sample.get("require_punctuation", True))
score_components = [
group_coverage,
1.0 if min_word_hit else 0.0,
1.0 if novelty_hit else 0.0,
]
if require_punctuation:
score_components.append(1.0 if punctuation_hit else 0.0)
return {
"section": str(sample["section"]),
"context": str(sample["context"]),
"generated_text": generated,
"group_coverage": group_coverage,
"punctuation_hit": punctuation_hit,
"min_word_hit": min_word_hit,
"exact_copy": exact_copy,
"score": sum(score_components) / len(score_components) if score_components else 0.0,
}
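

# A hypothetical open-ended sample, showing the fields `_open_ended_score`
# reads (defaults noted inline), followed by the rubric arithmetic:
#
#     {
#         "section": "style",
#         "context": "Write one sentence about rain.",
#         "max_tokens": 56,                 # optional, default 56
#         "min_words": 12,                  # optional, default 12
#         "required_groups": [["rain", "drizzle"], ["street"], ["night"]],
#         "banned_phrases": ["it was a dark and stormy night"],
#         "require_punctuation": true       # optional, default true
#     }
#
# A generation satisfying 2 of the 3 groups, the word minimum, novelty, and
# punctuation scores (2/3 + 1.0 + 1.0 + 1.0) / 4 = 0.9166...  Note that an
# empty "required_groups" contributes 0.0 coverage rather than 1.0, which
# caps the maximum score for such samples.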


def evaluate_manifest(
    model: ReframrModel,
    manifest: dict[str, object],
    *,
    reasoning_mode: str | None = None,
    top_k: int = 5,
) -> dict[str, object]:
    """Score a model against every split in an evaluation manifest.

    The "memorization" and "generalization" splits are scored by next-token
    prediction (top-1/top-k accuracy plus the mean probability assigned to
    the expected token); an optional "open_ended" split is scored by rubric
    via ``_open_ended_score``.
    """
    results: dict[str, object] = {
        "corpus_name": manifest["name"],
        "reasoning_mode": reasoning_mode or model.config.default_reasoning_profile,
        "splits": {},
    }
    splits = manifest["splits"]
for split_name in ("memorization", "generalization"):
samples = splits[split_name]
top1_hits = 0
topk_hits = 0
expected_probabilities = []
for sample in samples:
            distribution = model.predict_next_token_distribution(
                str(sample["context"]),
                reasoning_mode=reasoning_mode,
            )
ranked = sorted(distribution.items(), key=lambda item: item[1], reverse=True)
predicted = ranked[0][0] if ranked else ""
top_tokens = [token for token, _ in ranked[:top_k]]
            expected = _expected_next_token(model, str(sample["expected"]))
expected_probability = distribution.get(expected, 0.0)
if predicted == expected:
top1_hits += 1
if expected in top_tokens:
topk_hits += 1
expected_probabilities.append(expected_probability)
sample_count = len(samples)
mean_expected_probability = (
sum(expected_probabilities) / sample_count if sample_count else 0.0
)
results["splits"][split_name] = {
"sample_count": sample_count,
"top1_accuracy": top1_hits / sample_count if sample_count else 0.0,
"topk_accuracy": topk_hits / sample_count if sample_count else 0.0,
"mean_expected_probability": mean_expected_probability,
}
open_ended_samples = splits.get("open_ended", [])
if open_ended_samples:
sample_results = [
_open_ended_score(
model,
sample,
reasoning_mode=reasoning_mode,
)
for sample in open_ended_samples
]
sample_count = len(sample_results)
results["open_ended"] = {
"sample_count": sample_count,
"mean_score": (
sum(float(sample["score"]) for sample in sample_results) / sample_count
if sample_count
else 0.0
),
"mean_group_coverage": (
sum(float(sample["group_coverage"]) for sample in sample_results) / sample_count
if sample_count
else 0.0
),
"punctuation_rate": (
sum(1 for sample in sample_results if bool(sample["punctuation_hit"])) / sample_count
if sample_count
else 0.0
),
"min_word_rate": (
sum(1 for sample in sample_results if bool(sample["min_word_hit"])) / sample_count
if sample_count
else 0.0
),
"exact_copy_rate": (
sum(1 for sample in sample_results if bool(sample["exact_copy"])) / sample_count
if sample_count
else 0.0
),
"samples": sample_results,
}
return results
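

# Usage sketch: how a `ReframrModel` is constructed or loaded is outside this
# module, so that step is left abstract; the manifest path is hypothetical.
#
#     manifest = load_manifest("eval/manifest.json")
#     model = ...  # a loaded ReframrModel
#     report = evaluate_manifest(model, manifest, top_k=5)
#     print(report["splits"]["generalization"]["top1_accuracy"])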


def benchmark_open_prompts(
    model: ReframrModel,
    prompts: list[dict[str, object]],
    *,
    reasoning_mode: str | None = None,
    max_tokens: int = 64,
    temperature: float = 0.82,
    top_k: int = 24,
    top_p: float = 0.92,
    repetition_penalty: float = 1.18,
) -> dict[str, object]:
    """Generate from open prompts and report surface-quality statistics.

    For each prompt this records length, punctuation presence, and
    distinct-2/3 and repetition-3 diversity ratios, then aggregates means
    over all samples.
    """
samples: list[dict[str, object]] = []
for item in prompts:
prompt = str(item["prompt"])
generated = model.generate_text(
prompt,
max_tokens=max_tokens,
reasoning_mode=reasoning_mode,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
words = generated.split()
samples.append(
{
"prompt": prompt,
"tags": [str(tag) for tag in item.get("tags", [])],
"generated_text": generated,
"word_count": len(words),
"char_count": len(generated),
"punctuation_hit": any(mark in generated for mark in ".,;:?!"),
"distinct_2": _distinct_ratio(words, 2),
"distinct_3": _distinct_ratio(words, 3),
"repetition_3": _repetition_ratio(words, 3),
}
)
sample_count = len(samples)
return {
"sample_count": sample_count,
"reasoning_mode": reasoning_mode or model.config.default_reasoning_profile,
"generation_policy": {
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
},
"mean_word_count": (
sum(int(sample["word_count"]) for sample in samples) / sample_count
if sample_count
else 0.0
),
"mean_char_count": (
sum(int(sample["char_count"]) for sample in samples) / sample_count
if sample_count
else 0.0
),
"punctuation_rate": (
sum(1 for sample in samples if bool(sample["punctuation_hit"])) / sample_count
if sample_count
else 0.0
),
"mean_distinct_2": (
sum(float(sample["distinct_2"]) for sample in samples) / sample_count
if sample_count
else 0.0
),
"mean_distinct_3": (
sum(float(sample["distinct_3"]) for sample in samples) / sample_count
if sample_count
else 0.0
),
"mean_repetition_3": (
sum(float(sample["repetition_3"]) for sample in samples) / sample_count
if sample_count
else 0.0
),
"samples": samples,
}
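

# Usage sketch for the open-prompt benchmark; prompt text and tags are
# illustrative only, and `model` is a loaded ReframrModel as above.
#
#     prompts = [{"prompt": "Describe a quiet morning.", "tags": ["scene"]}]
#     report = benchmark_open_prompts(model, prompts, max_tokens=48)
#     summary = {key: value for key, value in report.items() if key != "samples"}
#     print(json.dumps(summary, indent=2))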