"""Evaluation helpers: manifest scoring and open-prompt benchmarking."""

import json
import unicodedata
from pathlib import Path
from typing import Sequence

from .model import ReframrModel

# Phrases that indicate the model is describing an answer instead of giving one.
# Matched as substrings of the normalized (casefolded, whitespace-collapsed) text.
META_VOICE_PHRASES = (
    "the answer should",
    "the response should",
    "a strong answer",
    "a safe answer",
    "the safe answer",
    "the safe move",
    "the passage",
)

# NOTE(review): every entry here is the empty string, and several in-line
# protocol markers further down are empty as well.  ``"" in s`` and
# ``s.startswith("")`` are always true, so checks built on these markers are
# currently degenerate.  This looks like markup that was stripped from the
# literals -- restore the intended tokens from the upstream source.  They are
# left untouched here because the intended values are not visible in this file.
PROTOCOL_STARTS = (
    "",
    "",
    "",
    "",
    "",
    "",
)


def _safe_ratio(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is falsy."""
    return numerator / denominator if denominator else 0.0


def load_manifest(path: str | Path) -> dict[str, object]:
    """Read the UTF-8 JSON file at *path* and return the decoded payload."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


def _expected_next_token(model: ReframrModel, expected_text: str) -> str:
    """Return the first token of *expected_text* encoded with a leading space.

    Returns ``""`` when the tokenizer produces no tokens.
    """
    assert model.tokenizer is not None
    encoded = model.tokenizer.encode(f" {expected_text}")
    return encoded[0] if encoded else ""


def _normalize_text(text: str) -> str:
    """Casefold *text* and collapse every whitespace run to a single space."""
    return " ".join(text.casefold().split())


def _word_ngrams(words: list[str], size: int) -> list[tuple[str, ...]]:
    """Return all contiguous *size*-grams of *words* (empty if none fit)."""
    if size <= 0 or len(words) < size:
        return []
    return [
        tuple(words[index : index + size])
        for index in range(len(words) - size + 1)
    ]


def _distinct_ratio(words: list[str], size: int) -> float:
    """Return the share of *size*-grams that are unique (0.0 without n-grams)."""
    grams = _word_ngrams(words, size)
    if not grams:
        return 0.0
    return len(set(grams)) / len(grams)


def _repetition_ratio(words: list[str], size: int) -> float:
    """Return the share of *size*-grams that repeat an earlier one."""
    grams = _word_ngrams(words, size)
    if not grams:
        return 0.0
    repeated = len(grams) - len(set(grams))
    return repeated / len(grams)


def _source_replay_index(
    sources: Sequence[str] | None,
    *,
    ngram_size: int,
) -> list[tuple[str, set[tuple[str, ...]]]]:
    """Build ``(normalized_source, ngram_set)`` pairs for replay detection.

    Sources too short to produce a single n-gram are left out.
    """
    if not sources:
        return []
    index: list[tuple[str, set[tuple[str, ...]]]] = []
    for source in sources:
        normalized = _normalize_text(str(source))
        grams = set(_word_ngrams(normalized.split(), ngram_size))
        if grams:
            index.append((normalized, grams))
    return index


def _source_replay_overlap(
    generated: str,
    replay_index: list[tuple[str, set[tuple[str, ...]]]],
    *,
    ngram_size: int,
) -> tuple[float, str]:
    """Return the highest n-gram overlap of *generated* with any indexed source.

    The overlap is the fraction of the generated text's distinct n-grams that
    also occur in the source.  Returns ``(0.0, "")`` when there is nothing to
    compare; otherwise ``(best_overlap, best_normalized_source)``.
    """
    generated_grams = set(
        _word_ngrams(_normalize_text(generated).split(), ngram_size)
    )
    if not generated_grams or not replay_index:
        return 0.0, ""
    best_overlap = 0.0
    best_source = ""
    for normalized_source, source_grams in replay_index:
        overlap = len(generated_grams & source_grams) / len(generated_grams)
        # Strict ">" keeps the first source on ties.
        if overlap > best_overlap:
            best_overlap = overlap
            best_source = normalized_source
    return best_overlap, best_source


def _text_from_replay_row(row: object) -> str:
    """Extract the replay text from one record.

    Strings are returned stripped.  For dicts the first non-blank string among
    ``answer``, ``response``, ``chosen``, ``text``, ``content`` and
    ``completion`` wins; otherwise ``messages`` is consulted.  Anything else
    yields ``""``.
    """
    if isinstance(row, str):
        return row.strip()
    if not isinstance(row, dict):
        return ""
    for field in ("answer", "response", "chosen", "text", "content", "completion"):
        value = row.get(field)
        if isinstance(value, str) and value.strip():
            return value.strip()
    if "messages" in row:
        messages = row["messages"]
        text = _content_to_text(messages)
        if text:
            return text
        # _content_to_text only understands content-part dicts ("text", or
        # type == "text" with "content"), so plain {"role", "content"} chat
        # messages used to produce "" and the row was silently dropped.
        # NOTE(review): falling back to assistant turns is an assumption based
        # on the answer-like fields read above -- confirm the intended role.
        if isinstance(messages, list):
            replies = [
                _content_to_text(message.get("content"))
                for message in messages
                if isinstance(message, dict)
                and str(message.get("role", "")).casefold() == "assistant"
            ]
            return " ".join(reply for reply in replies if reply)
    return ""


def load_replay_sources(
    paths: Sequence[str | Path],
    *,
    limit: int = 10_000,
) -> list[str]:
    """Collect replay texts from ``.jsonl``, ``.json`` and plain-text files.

    Args:
        paths: Files to read; paths that do not exist are skipped.
        limit: Maximum number of texts to return.  Values ``<= 0`` disable
            the cap.

    Returns:
        The extracted, non-empty texts in file order.

    Raises:
        json.JSONDecodeError: If a JSON/JSONL record cannot be decoded.
    """
    sources: list[str] = []
    for source_path in paths:
        path = Path(source_path)
        if not path.exists():
            continue
        suffix = path.suffix.lower()
        if suffix == ".jsonl":
            # Iterate the file handle rather than str.splitlines(): the latter
            # also breaks on U+2028/U+2029/U+0085 and similar characters, which
            # may legally appear unescaped inside a JSON string and would cut a
            # record in half.  Streaming also stops reading once the cap is hit.
            with path.open(encoding="utf-8") as handle:
                for line in handle:
                    if limit > 0 and len(sources) >= limit:
                        return sources
                    if not line.strip():
                        continue
                    text = _text_from_replay_row(json.loads(line))
                    if text:
                        sources.append(text)
            continue
        if suffix == ".json":
            payload = json.loads(path.read_text(encoding="utf-8"))
            # A dict payload may wrap its rows under "records" or "texts";
            # otherwise the payload itself is treated as the row(s).
            rows = (
                payload.get("records", payload.get("texts", payload))
                if isinstance(payload, dict)
                else payload
            )
            if isinstance(rows, list):
                for row in rows:
                    if limit > 0 and len(sources) >= limit:
                        return sources
                    text = _text_from_replay_row(row)
                    if text:
                        sources.append(text)
            else:
                text = _text_from_replay_row(rows)
                if text:
                    sources.append(text)
            continue
        # Any other suffix: the whole file is a single source text.
        text = path.read_text(encoding="utf-8").strip()
        if text:
            sources.append(text)
        if limit > 0 and len(sources) >= limit:
            return sources[:limit]
    return sources[:limit] if limit > 0 else sources


def _normalize_phrase_list(value: object) -> list[str]:
    """Return the stripped, non-empty string items of *value* (``[]`` if not a list)."""
    if not isinstance(value, list):
        return []
    phrases: list[str] = []
    for item in value:
        if isinstance(item, str):
            phrase = item.strip()
            if phrase:
                phrases.append(phrase)
    return phrases


def _normalize_required_groups(value: object) -> list[list[str]]:
    """Normalize required-term groups into lists of matchable terms.

    Each entry of *value* is either a list of alternative terms or a single
    term.  Terms are normalized with ``_normalize_text`` so they are comparable
    with the whitespace-collapsed text they are searched in; blank terms and
    empty groups are dropped.  A non-list *value* yields ``[]``.
    """
    if not isinstance(value, list):
        return []
    groups: list[list[str]] = []
    for raw_group in value:
        candidates = raw_group if isinstance(raw_group, list) else [raw_group]
        group = [
            normalized
            for normalized in (_normalize_text(str(term)) for term in candidates)
            if normalized
        ]
        if group:
            groups.append(group)
    return groups


def _required_group_summary(
    normalized_text: str,
    required_groups: object,
) -> tuple[int, int, float]:
    """Return ``(hit_count, group_count, coverage)`` for *required_groups*.

    A group is hit when any of its terms is a substring of *normalized_text*.
    Coverage is ``0.0`` when there are no groups.
    """
    groups = _normalize_required_groups(required_groups)
    hit_count = sum(
        1 for group in groups if any(term in normalized_text for term in group)
    )
    group_count = len(groups)
    return hit_count, group_count, _safe_ratio(hit_count, group_count)


def _banned_phrase_hit(normalized_text: str, banned_phrases: object) -> bool:
    """Return True when any normalized banned phrase occurs in *normalized_text*."""
    return any(
        _normalize_text(phrase) in normalized_text
        for phrase in _normalize_phrase_list(banned_phrases)
        if _normalize_text(phrase)
    )


def _meta_voice_hit(normalized_text: str) -> bool:
    """Return True when *normalized_text* contains any ``META_VOICE_PHRASES`` entry."""
    return any(phrase in normalized_text for phrase in META_VOICE_PHRASES)


def _has_malformed_sentence_start(text: str) -> bool:
    """Return True when *text* is blank or opens with an unquoted lowercase letter.

    Leading whitespace, punctuation and symbols are skipped.  A lowercase
    first letter is tolerated after a leading quote character.  Text whose
    first significant character is not a letter is considered well-formed.
    """
    stripped = text.strip()
    if not stripped:
        return True
    # NOTE(review): PROTOCOL_STARTS currently holds only empty strings, so this
    # guard matches every non-empty text and the scan below never runs.
    if any(stripped.startswith(protocol) for protocol in PROTOCOL_STARTS):
        return False
    leading_quote = False
    for character in stripped:
        if character.isspace():
            continue
        category = unicodedata.category(character)
        if category.startswith(("P", "S")):
            if character in {"'", '"', "‘", "’", "“", "”"}:
                leading_quote = True
            continue
        if character.isalpha():
            if leading_quote:
                return False
            return character.islower()
        return False
    return False


def _quality_gate_passed(
    *,
    word_count: int,
    punctuation_hit: bool,
    required_group_coverage: float,
    exact_copy: bool,
    banned_phrase_hit: bool,
    meta_voice_hit: bool,
    malformed_start: bool,
    repetition_3: float,
    tool_call_hit: bool,
    fabricated_tool_result_hit: bool,
    fabricated_source_hit: bool,
    source_replay_hit: bool,
    item: dict[str, object],
) -> bool:
    """Decide whether one generated sample passes the quality gate.

    Any blocking flag (exact copy, banned phrase, meta voice, malformed start,
    fabricated tool result or source, source replay) fails the sample.  Beyond
    that, the per-item thresholds ``min_words``,
    ``min_required_group_coverage``, ``require_punctuation`` and
    ``max_repetition_3`` must hold, with two shortcuts: an allowed tool call
    skips every threshold, and items carrying source evidence skip the
    minimum word count.
    """
    blocking_failure = any(
        (
            exact_copy,
            banned_phrase_hit,
            meta_voice_hit,
            malformed_start,
            fabricated_tool_result_hit,
            fabricated_source_hit,
            source_replay_hit,
        )
    )
    if bool(item.get("allow_tool_call", False)) and tool_call_hit:
        return not blocking_failure
    min_words = int(item.get("min_words", 1))
    # Full coverage is demanded by default only when groups were declared.
    required_min_coverage = float(
        item.get(
            "min_required_group_coverage",
            1.0 if item.get("required_groups") else 0.0,
        )
    )
    require_punctuation = bool(item.get("require_punctuation", False))
    max_repetition_3 = float(item.get("max_repetition_3", 0.35))
    if (
        _item_contains_source_evidence(item)
        and required_group_coverage >= required_min_coverage
        and (punctuation_hit or not require_punctuation)
        and repetition_3 <= max_repetition_3
    ):
        return not blocking_failure
    if word_count < min_words:
        return False
    if required_group_coverage < required_min_coverage:
        return False
    if require_punctuation and not punctuation_hit:
        return False
    if repetition_3 > max_repetition_3:
        return False
    return not blocking_failure


def _item_contains_source_evidence(value: object) -> bool:
    """Recursively look for source records inside *value*.

    A dict counts when its ``sources`` list contains a dict, or when it has a
    ``title``/``url``/``snippet`` key together with a truthy ``title`` or
    ``snippet``.  Nested dict values and list items are searched as well.
    """
    if isinstance(value, dict):
        sources = value.get("sources")
        if isinstance(sources, list) and any(
            isinstance(source, dict) for source in sources
        ):
            return True
        if {"title", "url", "snippet"}.intersection(value.keys()) and (
            value.get("title") or value.get("snippet")
        ):
            return True
        return any(_item_contains_source_evidence(child) for child in value.values())
    if isinstance(value, list):
        return any(_item_contains_source_evidence(child) for child in value)
    return False


def _variation_group_summary(
    samples: list[dict[str, object]],
) -> dict[str, dict[str, object]]:
    """Summarize response diversity per non-empty ``variation_key``.

    Responses are compared after ``_normalize_text``.
    """
    grouped: dict[str, list[str]] = {}
    for sample in samples:
        key = str(sample.get("variation_key", "")).strip()
        if not key:
            continue
        grouped.setdefault(key, []).append(
            _normalize_text(str(sample.get("generated_text", "")))
        )
    summaries: dict[str, dict[str, object]] = {}
    for key, responses in grouped.items():
        sample_count = len(responses)
        unique_count = len(set(responses))
        summaries[key] = {
            "sample_count": sample_count,
            "unique_response_count": unique_count,
            "unique_response_rate": _safe_ratio(unique_count, sample_count),
            "duplicate_response_rate": _safe_ratio(
                sample_count - unique_count, sample_count
            ),
        }
    return summaries


def _content_to_text(content: object) -> str:
    """Flatten message content into one stripped string.

    Lists are joined with single spaces.  Dict items contribute their ``text``
    value, or their ``content`` value when ``type == "text"``; other dict
    items are ignored and ``None`` items are skipped.  ``None`` content
    yields ``""``; any other value is stringified.
    """
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if isinstance(item, dict):
                if "text" in item:
                    parts.append(str(item["text"]))
                elif item.get("type") == "text" and "content" in item:
                    parts.append(str(item["content"]))
            elif item is not None:
                parts.append(str(item))
        return " ".join(part.strip() for part in parts if part and part.strip()).strip()
    if content is None:
        return ""
    return str(content).strip()


def _render_tool_call(call: object) -> str:
    """Render a tool call as a single prompt line.

    The name and arguments are read from the call itself or from its nested
    ``function`` dict; non-string arguments are serialized as compact JSON.
    """
    # NOTE(review): the literals below begin with a bare space; a protocol
    # marker may be missing in front of it -- confirm against upstream.
    if not isinstance(call, dict):
        return f" {str(call).strip()}"
    function_payload = call.get("function", {})
    function = function_payload if isinstance(function_payload, dict) else {}
    name = str(call.get("name", function.get("name", "tool"))).strip() or "tool"
    arguments = call.get("arguments", function.get("arguments", {}))
    if not isinstance(arguments, str):
        arguments = json.dumps(arguments, ensure_ascii=False, separators=(",", ":"))
    return f" {name} {arguments}".strip()


def _render_tool_result(tool_name: str, result: object) -> list[str]:
    """Render a tool result as prompt lines.

    Dict results produce a status line (``ok``, or ``failed: <error>`` for any
    other status) followed by one line per dict entry in ``sources``.  Other
    results produce a single line containing their text, or ``empty``.
    """
    # NOTE(review): as in _render_tool_call, these literals begin with a bare
    # space where a protocol marker may have been lost.
    if isinstance(result, dict):
        status = str(result.get("status", "ok")).strip() or "ok"
        if status != "ok":
            error = str(result.get("error", status)).strip() or status
            return [f" {tool_name} failed: {error}"]
        lines = [f" {tool_name} ok"]
        sources = result.get("sources", [])
        if isinstance(sources, list):
            for source in sources:
                if not isinstance(source, dict):
                    continue
                title = str(source.get("title", "Source")).strip() or "Source"
                url = str(source.get("url", "")).strip()
                snippet = str(source.get("snippet", source.get("text", ""))).strip()
                lines.append(f" {title} | {url} | {snippet}".strip())
        return lines
    content = _content_to_text(result)
    return [f" {tool_name} {content or 'empty'}"]


def _compose_prompt_context(item: dict[str, object]) -> str:
    """Build the generation context for one benchmark item.

    Lines are emitted in this order: the ``system`` text, the rendered
    ``messages`` (role-prefixed, with tool calls and tool results), the
    ``prompt`` (prefixed with ``User:`` only when ``messages`` is a list), and
    finally any top-level ``tool_results``.  Empty lines are dropped and the
    rest are joined with newlines.
    """
    prompt = str(item.get("prompt", "")).strip()
    system = str(item.get("system", "")).strip()
    lines: list[str] = []
    tool_protocol_seen = False
    if system:
        lines.append(system)
    messages = item.get("messages")
    if isinstance(messages, list):
        for message in messages:
            if not isinstance(message, dict):
                continue
            role = str(message.get("role", "")).casefold()
            content = _content_to_text(message.get("content", ""))
            if role == "system":
                if content:
                    lines.append(f"System instruction: {content}")
            elif role == "user":
                if content:
                    lines.append(f"User: {content}")
            elif role == "assistant":
                if content:
                    lines.append(f"Assistant: {content}")
                # NOTE(review): the marker is an empty string, so this is true
                # for every assistant message.
                if "" in content:
                    tool_protocol_seen = True
                tool_calls = message.get("tool_calls", [])
                if isinstance(tool_calls, list):
                    for call in tool_calls:
                        lines.append(_render_tool_call(call))
                        tool_protocol_seen = True
            elif role == "tool":
                tool_name = str(
                    message.get("name", message.get("tool_call_id", "tool"))
                )
                lines.extend(
                    _render_tool_result(tool_name, message.get("content", ""))
                )
                tool_protocol_seen = True
            elif content:
                lines.append(f"{role.capitalize()}: {content}")
    if prompt:
        lines.append(f"User: {prompt}" if isinstance(messages, list) else prompt)
    tool_results = item.get("tool_results")
    if isinstance(tool_results, list):
        for result in tool_results:
            tool_name = "tool"
            if isinstance(result, dict):
                tool_name = str(result.get("name", result.get("tool", "tool")))
            lines.extend(_render_tool_result(tool_name, result))
            tool_protocol_seen = True
    elif tool_results:
        lines.extend(_render_tool_result("tool", tool_results))
        tool_protocol_seen = True
    if tool_protocol_seen:
        # NOTE(review): this appends an empty string, which the join below
        # filters out again -- presumably a trailing protocol marker was lost.
        lines.append("")
    return "\n".join(line for line in lines if line).strip()


def _open_ended_score(
    model: ReframrModel,
    sample: dict[str, object],
    *,
    reasoning_mode: str | None,
) -> dict[str, object]:
    """Generate text for one open-ended manifest sample and score it.

    The score is the mean of: required-group coverage, the minimum-word
    check, the not-an-exact-copy check and, unless ``require_punctuation`` is
    false, the punctuation check.
    """
    generated = model.generate_text(
        str(sample["context"]),
        max_tokens=int(sample.get("max_tokens", 56)),
        reasoning_mode=reasoning_mode,
    )
    normalized = _normalize_text(generated)
    required_groups = [
        [str(term).casefold() for term in group]
        for group in sample.get("required_groups", [])
    ]
    satisfied_groups = sum(
        1 for group in required_groups if any(term in normalized for term in group)
    )
    group_coverage = _safe_ratio(satisfied_groups, len(required_groups))
    punctuation_hit = any(mark in generated for mark in ".,;:?!")
    min_words = int(sample.get("min_words", 12))
    min_word_hit = len(generated.split()) >= min_words
    banned_phrases = [str(phrase) for phrase in sample.get("banned_phrases", [])]
    # Here a banned phrase only counts when the whole output equals it.
    exact_copy = any(normalized == _normalize_text(phrase) for phrase in banned_phrases)
    novelty_hit = not exact_copy
    require_punctuation = bool(sample.get("require_punctuation", True))
    score_components = [
        group_coverage,
        1.0 if min_word_hit else 0.0,
        1.0 if novelty_hit else 0.0,
    ]
    if require_punctuation:
        score_components.append(1.0 if punctuation_hit else 0.0)
    return {
        "section": str(sample["section"]),
        "context": str(sample["context"]),
        "generated_text": generated,
        "group_coverage": group_coverage,
        "punctuation_hit": punctuation_hit,
        "min_word_hit": min_word_hit,
        "exact_copy": exact_copy,
        "score": _safe_ratio(sum(score_components), len(score_components)),
    }


def evaluate_manifest(
    model: ReframrModel,
    manifest: dict[str, object],
    *,
    reasoning_mode: str | None = None,
    top_k: int = 5,
) -> dict[str, object]:
    """Score *model* on the next-token and open-ended splits of *manifest*.

    Args:
        model: Model under evaluation.
        manifest: Must provide ``name`` and ``splits``; ``splits`` must contain
            ``memorization`` and ``generalization`` and may contain
            ``open_ended``.
        reasoning_mode: Forwarded to the model; the model's default profile is
            reported when it is not given.
        top_k: Number of top-ranked tokens counted for ``topk_accuracy``.

    Returns:
        A report with per-split accuracy figures and, when open-ended samples
        exist, an ``open_ended`` section with aggregate rates and the
        per-sample results.

    Raises:
        KeyError: If a required manifest or sample key is missing.
    """
    results: dict[str, object] = {
        "corpus_name": manifest["name"],
        "reasoning_mode": reasoning_mode or model.config.default_reasoning_profile,
        "splits": {},
    }
    splits = manifest["splits"]
    for split_name in ("memorization", "generalization"):
        samples = splits[split_name]
        top1_hits = 0
        topk_hits = 0
        expected_probabilities = []
        for sample in samples:
            distribution = model.predict_next_token_distribution(
                sample["context"],
                reasoning_mode=reasoning_mode,
            )
            ranked = sorted(
                distribution.items(), key=lambda item: item[1], reverse=True
            )
            predicted = ranked[0][0] if ranked else ""
            top_tokens = [token for token, _ in ranked[:top_k]]
            expected = _expected_next_token(model, sample["expected"])
            expected_probability = distribution.get(expected, 0.0)
            if predicted == expected:
                top1_hits += 1
            if expected in top_tokens:
                topk_hits += 1
            expected_probabilities.append(expected_probability)
        sample_count = len(samples)
        mean_expected_probability = _safe_ratio(
            sum(expected_probabilities), sample_count
        )
        results["splits"][split_name] = {
            "sample_count": sample_count,
            "top1_accuracy": _safe_ratio(top1_hits, sample_count),
            "topk_accuracy": _safe_ratio(topk_hits, sample_count),
            "mean_expected_probability": mean_expected_probability,
        }
    open_ended_samples = splits.get("open_ended", [])
    if open_ended_samples:
        sample_results = [
            _open_ended_score(
                model,
                sample,
                reasoning_mode=reasoning_mode,
            )
            for sample in open_ended_samples
        ]
        sample_count = len(sample_results)
        results["open_ended"] = {
            "sample_count": sample_count,
            "mean_score": _safe_ratio(
                sum(float(sample["score"]) for sample in sample_results),
                sample_count,
            ),
            "mean_group_coverage": _safe_ratio(
                sum(float(sample["group_coverage"]) for sample in sample_results),
                sample_count,
            ),
            "punctuation_rate": _safe_ratio(
                sum(1 for sample in sample_results if bool(sample["punctuation_hit"])),
                sample_count,
            ),
            "min_word_rate": _safe_ratio(
                sum(1 for sample in sample_results if bool(sample["min_word_hit"])),
                sample_count,
            ),
            "exact_copy_rate": _safe_ratio(
                sum(1 for sample in sample_results if bool(sample["exact_copy"])),
                sample_count,
            ),
            "samples": sample_results,
        }
    return results


def benchmark_open_prompts(
    model: ReframrModel,
    prompts: list[dict[str, object]],
    *,
    reasoning_mode: str | None = None,
    max_tokens: int = 64,
    temperature: float = 0.82,
    top_k: int = 24,
    top_p: float = 0.92,
    repetition_penalty: float = 1.18,
    replay_sources: Sequence[str] | None = None,
    replay_ngram_size: int = 8,
    replay_overlap_threshold: float = 0.70,
) -> dict[str, object]:
    """Generate a response per prompt and aggregate quality statistics.

    Args:
        model: Model under evaluation.
        prompts: Benchmark items.  Each needs a ``prompt`` and may carry
            ``system``, ``messages``, ``tool_results``, ``tags``,
            ``variation_key``, ``required_groups``, ``banned_phrases`` and the
            per-item thresholds read by ``_quality_gate_passed``.
        reasoning_mode: Forwarded to the model; the model's default profile is
            reported when it is not given.
        max_tokens: Forwarded to ``model.generate_text``.
        temperature: Forwarded to ``model.generate_text``.
        top_k: Forwarded to ``model.generate_text``.
        top_p: Forwarded to ``model.generate_text``.
        repetition_penalty: Forwarded to ``model.generate_text``.
        replay_sources: Texts the output must not replay.  They are passed to
            the model as ``avoid_texts`` and used for overlap detection.
        replay_ngram_size: Word n-gram size for replay detection; values
            below 3 are raised to 3.
        replay_overlap_threshold: Overlap at or above which a sample is
            flagged as source replay.

    Returns:
        A report (schema ``reframr.open_benchmark.v2``) with aggregate counts
        and rates, per-variation-group summaries, ``v2_readiness_score`` (the
        mean of the component quality scores) and the per-sample records.

    Raises:
        KeyError: If an item has no ``prompt``.
    """
    samples: list[dict[str, object]] = []
    normalized_replay_ngram_size = max(3, int(replay_ngram_size))
    replay_index = _source_replay_index(
        replay_sources,
        ngram_size=normalized_replay_ngram_size,
    )
    avoid_texts = list(replay_sources or [])
    for item in prompts:
        prompt = str(item["prompt"])
        context = _compose_prompt_context(item)
        generated = model.generate_text(
            context,
            max_tokens=max_tokens,
            reasoning_mode=reasoning_mode,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            avoid_texts=avoid_texts,
        )
        normalized = _normalize_text(generated)
        banned_phrases = [str(phrase) for phrase in item.get("banned_phrases", [])]
        exact_copy = any(
            normalized == _normalize_text(phrase) for phrase in banned_phrases
        )
        words = generated.split()
        punctuation_hit = any(mark in generated for mark in ".,;:?!")
        # NOTE(review): the five markers below are empty strings.  ``"" in s``
        # is always true, so the three "*_hit" flags are always True and both
        # "fabricated_*" flags are always False.  Restore the intended tokens.
        tool_call_hit = "" in generated
        generated_tool_result_hit = "" in generated
        generated_source_hit = "" in generated
        fabricated_tool_result_hit = generated_tool_result_hit and "" not in context
        fabricated_source_hit = generated_source_hit and "" not in context
        required_group_hits, required_group_count, required_group_coverage = (
            _required_group_summary(normalized, item.get("required_groups", []))
        )
        source_replay_overlap, source_replay_source = _source_replay_overlap(
            generated,
            replay_index,
            ngram_size=normalized_replay_ngram_size,
        )
        source_replay_hit = (
            bool(replay_index)
            and source_replay_overlap >= float(replay_overlap_threshold)
        )
        banned_hit = _banned_phrase_hit(normalized, item.get("banned_phrases", []))
        meta_hit = _meta_voice_hit(normalized)
        malformed_start = _has_malformed_sentence_start(generated)
        distinct_2 = _distinct_ratio(words, 2)
        distinct_3 = _distinct_ratio(words, 3)
        repetition_3 = _repetition_ratio(words, 3)
        passed_quality_gate = _quality_gate_passed(
            word_count=len(words),
            punctuation_hit=punctuation_hit,
            required_group_coverage=required_group_coverage,
            exact_copy=exact_copy,
            banned_phrase_hit=banned_hit,
            meta_voice_hit=meta_hit,
            malformed_start=malformed_start,
            repetition_3=repetition_3,
            tool_call_hit=tool_call_hit,
            fabricated_tool_result_hit=fabricated_tool_result_hit,
            fabricated_source_hit=fabricated_source_hit,
            source_replay_hit=source_replay_hit,
            item=item,
        )
        samples.append(
            {
                "prompt": prompt,
                "context": context,
                "tags": [str(tag) for tag in item.get("tags", [])],
                "variation_key": str(item.get("variation_key", "")).strip(),
                "generated_text": generated,
                "word_count": len(words),
                "char_count": len(generated),
                "punctuation_hit": punctuation_hit,
                "distinct_2": distinct_2,
                "distinct_3": distinct_3,
                "repetition_3": repetition_3,
                "exact_copy": exact_copy,
                "banned_phrase_hit": banned_hit,
                "tool_call_hit": tool_call_hit,
                "generated_tool_result_hit": generated_tool_result_hit,
                "generated_source_hit": generated_source_hit,
                "fabricated_tool_result_hit": fabricated_tool_result_hit,
                "fabricated_source_hit": fabricated_source_hit,
                "source_replay_overlap": source_replay_overlap,
                "source_replay_hit": source_replay_hit,
                "source_replay_source": source_replay_source,
                "required_group_hits": required_group_hits,
                "required_group_count": required_group_count,
                "required_group_coverage": required_group_coverage,
                "malformed_start": malformed_start,
                "meta_voice_hit": meta_hit,
                "passed_quality_gate": passed_quality_gate,
            }
        )
    sample_count = len(samples)
    normalized_responses = [
        _normalize_text(str(sample["generated_text"])) for sample in samples
    ]
    unique_response_count = len(set(normalized_responses))
    exact_copy_count = sum(1 for sample in samples if bool(sample["exact_copy"]))
    banned_phrase_count = sum(
        1 for sample in samples if bool(sample["banned_phrase_hit"])
    )
    malformed_start_count = sum(
        1 for sample in samples if bool(sample["malformed_start"])
    )
    meta_voice_count = sum(1 for sample in samples if bool(sample["meta_voice_hit"]))
    tool_call_count = sum(1 for sample in samples if bool(sample["tool_call_hit"]))
    fabricated_tool_result_count = sum(
        1 for sample in samples if bool(sample["fabricated_tool_result_hit"])
    )
    fabricated_source_count = sum(
        1 for sample in samples if bool(sample["fabricated_source_hit"])
    )
    source_replay_count = sum(
        1 for sample in samples if bool(sample["source_replay_hit"])
    )
    quality_pass_count = sum(
        1 for sample in samples if bool(sample["passed_quality_gate"])
    )
    variation_groups = _variation_group_summary(samples)
    # Without variation groups this component is neutral (1.0).
    worst_variation_group_unique_rate = (
        min(
            float(summary["unique_response_rate"])
            for summary in variation_groups.values()
        )
        if variation_groups
        else 1.0
    )
    # Coverage is averaged only over samples that declared required groups.
    required_group_samples = [
        sample for sample in samples if int(sample.get("required_group_count", 0)) > 0
    ]
    required_group_sample_count = len(required_group_samples)
    mean_required_group_coverage = _safe_ratio(
        sum(float(sample["required_group_coverage"]) for sample in required_group_samples),
        required_group_sample_count,
    )
    quality_scores = [
        _safe_ratio(quality_pass_count, sample_count),
        _safe_ratio(unique_response_count, sample_count),
        mean_required_group_coverage,
        1.0 - _safe_ratio(exact_copy_count, sample_count),
        1.0 - _safe_ratio(banned_phrase_count, sample_count),
        1.0 - _safe_ratio(fabricated_tool_result_count, sample_count),
        1.0 - _safe_ratio(fabricated_source_count, sample_count),
        1.0 - _safe_ratio(source_replay_count, sample_count),
        1.0 - _safe_ratio(malformed_start_count, sample_count),
        1.0 - _safe_ratio(meta_voice_count, sample_count),
        worst_variation_group_unique_rate,
    ]
    return {
        "schema_version": "reframr.open_benchmark.v2",
        "sample_count": sample_count,
        "reasoning_mode": reasoning_mode or model.config.default_reasoning_profile,
        "generation_policy": {
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        },
        "mean_word_count": _safe_ratio(
            sum(int(sample["word_count"]) for sample in samples), sample_count
        ),
        "mean_char_count": _safe_ratio(
            sum(int(sample["char_count"]) for sample in samples), sample_count
        ),
        "punctuation_rate": _safe_ratio(
            sum(1 for sample in samples if bool(sample["punctuation_hit"])),
            sample_count,
        ),
        "required_group_sample_count": required_group_sample_count,
        "mean_required_group_coverage": mean_required_group_coverage,
        "mean_distinct_2": _safe_ratio(
            sum(float(sample["distinct_2"]) for sample in samples), sample_count
        ),
        "mean_distinct_3": _safe_ratio(
            sum(float(sample["distinct_3"]) for sample in samples), sample_count
        ),
        "mean_repetition_3": _safe_ratio(
            sum(float(sample["repetition_3"]) for sample in samples), sample_count
        ),
        "exact_copy_count": exact_copy_count,
        "exact_copy_rate": _safe_ratio(exact_copy_count, sample_count),
        "banned_phrase_count": banned_phrase_count,
        "banned_phrase_rate": _safe_ratio(banned_phrase_count, sample_count),
        "malformed_start_count": malformed_start_count,
        "malformed_start_rate": _safe_ratio(malformed_start_count, sample_count),
        "meta_voice_count": meta_voice_count,
        "meta_voice_rate": _safe_ratio(meta_voice_count, sample_count),
        "tool_call_count": tool_call_count,
        "tool_call_rate": _safe_ratio(tool_call_count, sample_count),
        "fabricated_tool_result_count": fabricated_tool_result_count,
        "fabricated_tool_result_rate": _safe_ratio(
            fabricated_tool_result_count, sample_count
        ),
        "fabricated_source_count": fabricated_source_count,
        "fabricated_source_rate": _safe_ratio(fabricated_source_count, sample_count),
        "source_replay_count": source_replay_count,
        "source_replay_rate": _safe_ratio(source_replay_count, sample_count),
        "replay_ngram_size": normalized_replay_ngram_size,
        "replay_overlap_threshold": float(replay_overlap_threshold),
        "quality_pass_count": quality_pass_count,
        "quality_pass_rate": _safe_ratio(quality_pass_count, sample_count),
        "unique_response_count": unique_response_count,
        "unique_response_rate": _safe_ratio(unique_response_count, sample_count),
        "duplicate_response_rate": _safe_ratio(
            sample_count - unique_response_count, sample_count
        ),
        "variation_groups": variation_groups,
        "worst_variation_group_unique_rate": worst_variation_group_unique_rate,
        "v2_readiness_score": sum(quality_scores) / len(quality_scores),
        "samples": samples,
    }