"""Evaluation utilities for ReframrModel: manifest-driven next-token accuracy
checks and open-ended generation scoring."""

import json
from pathlib import Path

from .model import ReframrModel


def load_manifest(path: str | Path) -> dict[str, object]:
    """Load an evaluation manifest from a UTF-8 JSON file."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


def _expected_next_token(model: ReframrModel, expected_text: str) -> str:
    """Return the first token the tokenizer produces for the expected text.

    A leading space is prepended so the text tokenizes as a continuation of
    the context rather than the start of a sequence.
    """
    assert model.tokenizer is not None
    encoded = model.tokenizer.encode(f" {expected_text}")
    return encoded[0] if encoded else ""


def _normalize_text(text: str) -> str:
    """Casefold and collapse whitespace runs to single spaces."""
    return " ".join(text.casefold().split())


def _word_ngrams(words: list[str], size: int) -> list[tuple[str, ...]]:
    """Return all contiguous word n-grams of the given size."""
    if size <= 0 or len(words) < size:
        return []
    return [
        tuple(words[index : index + size])
        for index in range(len(words) - size + 1)
    ]


def _distinct_ratio(words: list[str], size: int) -> float:
    """Fraction of n-grams that are unique (distinct-n diversity)."""
    grams = _word_ngrams(words, size)
    if not grams:
        return 0.0
    return len(set(grams)) / len(grams)


def _repetition_ratio(words: list[str], size: int) -> float:
    """Fraction of n-grams that duplicate another n-gram in the text."""
    grams = _word_ngrams(words, size)
    if not grams:
        return 0.0
    repeated = len(grams) - len(set(grams))
    return repeated / len(grams)


def _open_ended_score(
    model: ReframrModel,
    sample: dict[str, object],
    *,
    reasoning_mode: str | None,
) -> dict[str, object]:
    """Score one open-ended sample on term coverage, length, punctuation, and novelty."""
    generated = model.generate_text(
        str(sample["context"]),
        max_tokens=int(sample.get("max_tokens", 56)),
        reasoning_mode=reasoning_mode,
    )
    normalized = _normalize_text(generated)
    # A required group counts as satisfied if any one of its terms appears.
    required_groups = [
        [str(term).casefold() for term in group]
        for group in sample.get("required_groups", [])
    ]
    satisfied_groups = sum(
        1 for group in required_groups if any(term in normalized for term in group)
    )
    group_coverage = (
        satisfied_groups / len(required_groups) if required_groups else 0.0
    )
    punctuation_hit = any(mark in generated for mark in ".,;:?!")
    min_words = int(sample.get("min_words", 12))
    min_word_hit = len(generated.split()) >= min_words
    # Novelty check: the output must not be an exact normalized copy of a banned phrase.
    banned_phrases = [str(phrase) for phrase in sample.get("banned_phrases", [])]
    exact_copy = any(normalized == _normalize_text(phrase) for phrase in banned_phrases)
    novelty_hit = not exact_copy
    require_punctuation = bool(sample.get("require_punctuation", True))
    score_components = [
        group_coverage,
        1.0 if min_word_hit else 0.0,
        1.0 if novelty_hit else 0.0,
    ]
    if require_punctuation:
        score_components.append(1.0 if punctuation_hit else 0.0)
    return {
        "section": str(sample["section"]),
        "context": str(sample["context"]),
        "generated_text": generated,
        "group_coverage": group_coverage,
        "punctuation_hit": punctuation_hit,
        "min_word_hit": min_word_hit,
        "exact_copy": exact_copy,
        "score": (
            sum(score_components) / len(score_components) if score_components else 0.0
        ),
    }


def evaluate_manifest(
    model: ReframrModel,
    manifest: dict[str, object],
    *,
    reasoning_mode: str | None = None,
    top_k: int = 5,
) -> dict[str, object]:
    """Evaluate next-token accuracy on the memorization and generalization
    splits, plus open-ended generation quality when that split is present."""
    results: dict[str, object] = {
        "corpus_name": manifest["name"],
        "reasoning_mode": reasoning_mode or model.config.default_reasoning_profile,
        "splits": {},
    }
    splits = manifest["splits"]
    for split_name in ("memorization", "generalization"):
        samples = splits[split_name]
        top1_hits = 0
        topk_hits = 0
        expected_probabilities: list[float] = []
        for sample in samples:
            distribution = model.predict_next_token_distribution(
                str(sample["context"]),
                reasoning_mode=reasoning_mode,
            )
            # Rank candidate tokens by probability, highest first.
            ranked = sorted(distribution.items(), key=lambda item: item[1], reverse=True)
            predicted = ranked[0][0] if ranked else ""
            top_tokens = [token for token, _ in ranked[:top_k]]
            expected = _expected_next_token(model, str(sample["expected"]))
            expected_probability = distribution.get(expected, 0.0)
            if predicted == expected:
                top1_hits += 1
            if expected in top_tokens:
                topk_hits += 1
            expected_probabilities.append(expected_probability)
        sample_count = len(samples)
        mean_expected_probability = (
            sum(expected_probabilities) / sample_count if sample_count else 0.0
        )
        results["splits"][split_name] = {
            "sample_count": sample_count,
            "top1_accuracy": top1_hits / sample_count if sample_count else 0.0,
            "topk_accuracy": topk_hits / sample_count if sample_count else 0.0,
            "mean_expected_probability": mean_expected_probability,
        }
    # The open-ended split is optional; score it only when present.
    open_ended_samples = splits.get("open_ended", [])
    if open_ended_samples:
        sample_results = [
            _open_ended_score(
                model,
                sample,
                reasoning_mode=reasoning_mode,
            )
            for sample in open_ended_samples
        ]
        sample_count = len(sample_results)
        results["open_ended"] = {
            "sample_count": sample_count,
            "mean_score": (
                sum(float(sample["score"]) for sample in sample_results) / sample_count
                if sample_count
                else 0.0
            ),
            "mean_group_coverage": (
                sum(float(sample["group_coverage"]) for sample in sample_results)
                / sample_count
                if sample_count
                else 0.0
            ),
            "punctuation_rate": (
                sum(1 for sample in sample_results if bool(sample["punctuation_hit"]))
                / sample_count
                if sample_count
                else 0.0
            ),
            "min_word_rate": (
                sum(1 for sample in sample_results if bool(sample["min_word_hit"]))
                / sample_count
                if sample_count
                else 0.0
            ),
            "exact_copy_rate": (
                sum(1 for sample in sample_results if bool(sample["exact_copy"]))
                / sample_count
                if sample_count
                else 0.0
            ),
            "samples": sample_results,
        }
    return results


def benchmark_open_prompts(
    model: ReframrModel,
    prompts: list[dict[str, object]],
    *,
    reasoning_mode: str | None = None,
    max_tokens: int = 64,
    temperature: float = 0.82,
    top_k: int = 24,
    top_p: float = 0.92,
    repetition_penalty: float = 1.18,
) -> dict[str, object]:
    """Generate from each prompt under a fixed sampling policy and report
    length, punctuation, and n-gram diversity statistics."""
    samples: list[dict[str, object]] = []
    for item in prompts:
        prompt = str(item["prompt"])
        generated = model.generate_text(
            prompt,
            max_tokens=max_tokens,
            reasoning_mode=reasoning_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
        words = generated.split()
        samples.append(
            {
                "prompt": prompt,
                "tags": [str(tag) for tag in item.get("tags", [])],
                "generated_text": generated,
                "word_count": len(words),
                "char_count": len(generated),
                "punctuation_hit": any(mark in generated for mark in ".,;:?!"),
                "distinct_2": _distinct_ratio(words, 2),
                "distinct_3": _distinct_ratio(words, 3),
                "repetition_3": _repetition_ratio(words, 3),
            }
        )
    sample_count = len(samples)
    return {
        "sample_count": sample_count,
        "reasoning_mode": reasoning_mode or model.config.default_reasoning_profile,
        "generation_policy": {
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        },
        "mean_word_count": (
            sum(int(sample["word_count"]) for sample in samples) / sample_count
            if sample_count
            else 0.0
        ),
        "mean_char_count": (
            sum(int(sample["char_count"]) for sample in samples) / sample_count
            if sample_count
            else 0.0
        ),
        "punctuation_rate": (
            sum(1 for sample in samples if bool(sample["punctuation_hit"]))
            / sample_count
            if sample_count
            else 0.0
        ),
        "mean_distinct_2": (
            sum(float(sample["distinct_2"]) for sample in samples) / sample_count
            if sample_count
            else 0.0
        ),
        "mean_distinct_3": (
            sum(float(sample["distinct_3"]) for sample in samples) / sample_count
            if sample_count
            else 0.0
        ),
        "mean_repetition_3": (
            sum(float(sample["repetition_3"]) for sample in samples) / sample_count
            if sample_count
            else 0.0
        ),
        "samples": samples,
    }
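

# Usage sketch (illustrative only): one way a caller might wire these helpers
# together. The model construction line is a hypothetical placeholder -- this
# module only assumes ReframrModel exposes `tokenizer`, `config`,
# `generate_text`, and `predict_next_token_distribution`; the manifest path
# and prompt text below are invented for the example.
#
#     model = ReframrModel(...)  # construct/load however the project does it
#     manifest = load_manifest("path/to/manifest.json")
#     report = evaluate_manifest(model, manifest, top_k=5)
#     print(report["splits"]["memorization"]["top1_accuracy"])
#
#     prompts = [{"prompt": "Describe a quiet morning.", "tags": ["style"]}]
#     open_report = benchmark_open_prompts(model, prompts, max_tokens=64)
#     print(open_report["mean_distinct_2"], open_report["mean_repetition_3"])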