afridialeval / src /eval_structural.py
millicentochieng's picture
Upload folder using huggingface_hub
e2b8b61 verified
"""Structural fidelity evaluation.

Checks whether the pipeline preserves structural properties across stages:
Encoded DAS → Localized DAS → Decoded dialogue

Metrics:
1. Turn count preservation (encoded → localized → decoded)
2. Speaker role preservation (encoded → localized)
3. Function preservation (are the communicative intents, i.e. DAS function
   names, carried through from encoded → localized?)
"""
import json
import re
import sys
from pathlib import Path
from typing import Any, Dict, List
def load_json(path: str) -> List[Dict[str, Any]]:
    """Parse the UTF-8 encoded JSON file at *path* and return its content.

    Callers in this module treat the result as a list of per-dialogue records.
    """
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)
def normalize_speaker(speaker: Any) -> str:
    """Normalize a raw speaker identifier to a canonical upper-case label.

    A leading "speaker" prefix is removed together with any separator that
    follows it (underscore, hyphen or whitespace). "1"/"a" then map to "A"
    and "2"/"b" map to "B"; anything else is returned upper-cased.

    Examples: "Speaker 1" -> "A", "speaker_b" -> "B", "Amina" -> "AMINA".
    """
    s = str(speaker).strip().lower()
    # Accept "speaker_1", "speaker1", "speaker 1" and "speaker-1" alike.
    # The old pattern `^speaker_?` left " 1" / "-1" behind, so the same
    # speaker could be split into two distinct ids.
    s = re.sub(r"^speaker[\s_-]*", "", s)
    if s in ("1", "a"):
        return "A"
    if s in ("2", "b"):
        return "B"
    return s.upper()
def normalize_speaker_role(speaker: Any) -> str:
    """Normalize a raw speaker identifier; alias for ``normalize_speaker``.

    This used to be a verbatim copy of ``normalize_speaker``. Delegating
    keeps the two from drifting apart.

    No positional role assignment happens here: "1"/"a" map to "A" and
    "2"/"b" map to "B", but other names are returned upper-cased
    (e.g. "Amina" -> "AMINA"). Relabelling speakers by order of appearance
    is done separately when dialogues are compared.
    """
    return normalize_speaker(speaker)
def extract_function_names(functions_field: Any) -> List[str]:
    """Extract top-level function names from a functions field.

    *functions_field* is either a list/tuple of call strings or a single
    string. Only names called at parenthesis depth 0 are returned, so
    ``"inform(topic=price(low)); request(time)"`` yields
    ``["inform", "request"]``. The nested ``price`` is an argument, not an
    intent. Order and duplicates are preserved.

    Each list element is scanned independently. An unbalanced parenthesis
    in one element therefore cannot hide the names in the next. Surplus
    closing parentheses are ignored (the depth never goes below 0).
    """
    if isinstance(functions_field, (list, tuple)):
        chunks = [str(f) for f in functions_field]
    else:
        chunks = [str(functions_field)]

    # Either "name(" (group 1 set) or a bare parenthesis.
    token = re.compile(r"(\w+)\(|[()]")
    names: List[str] = []
    for chunk in chunks:
        depth = 0
        for match in token.finditer(chunk):
            name = match.group(1)
            if name is not None:
                if depth == 0:
                    names.append(name)
                depth += 1
            elif match.group(0) == "(":
                depth += 1
            else:
                depth = max(depth - 1, 0)
    return names
def evaluate_dialogue(
    dialogue_id: Any,
    encoded_das: List[Dict],
    localized_das: List[Dict],
    decoded_turns: List[str],
) -> Dict[str, Any]:
    """Evaluate structural fidelity for one dialogue.

    Args:
        dialogue_id: Identifier copied verbatim into the result.
        encoded_das: Encoded DAS turns; each dict is read for ``speaker_id``
            and ``functions``.
        localized_das: Localized DAS turns, read for the same keys.
        decoded_turns: Decoded dialogue turns; only their count is used.

    Returns:
        A dict with:

        - the three per-stage turn counts;
        - ``issues``: human-readable strings;
        - ``intent_mismatches``: one dict per turn whose function-name set
          differs, including turns present in only one of the two DAS lists;
        - the flags ``turn_count_preserved``, ``speaker_roles_preserved``,
          ``intents_preserved`` and ``fully_faithful``;
        - ``intent_preservation_rate``: the share of encoded turns whose
          intents survived.
    """
    n_encoded = len(encoded_das)
    n_localized = len(localized_das)
    n_decoded = len(decoded_turns)
    result: Dict[str, Any] = {
        "dialogue_id": dialogue_id,
        "encoded_turns": n_encoded,
        "localized_turns": n_localized,
        "decoded_turns": n_decoded,
        "issues": [],
    }

    # 1. Turn count
    if n_encoded != n_localized:
        result["issues"].append(
            f"Turn count mismatch: encoded={n_encoded}, localized={n_localized}"
        )
    if n_encoded != n_decoded:
        result["issues"].append(
            f"Turn count mismatch: encoded={n_encoded}, decoded={n_decoded}"
        )
    result["turn_count_preserved"] = n_encoded == n_localized == n_decoded

    # 2. Speaker roles — check the alternation pattern is preserved, not the
    # exact names, so localization may rename speakers.
    encoded_speakers = [normalize_speaker(t.get("speaker_id", "")) for t in encoded_das]
    localized_speakers = [normalize_speaker(t.get("speaker_id", "")) for t in localized_das]

    def to_role_sequence(speakers: List[str]) -> List[str]:
        """Relabel speakers by first appearance: first distinct = A, next = B, ..."""
        mapping: Dict[str, str] = {}
        for s in speakers:
            mapping.setdefault(s, chr(ord("A") + len(mapping)))
        return [mapping[s] for s in speakers]

    encoded_roles = to_role_sequence(encoded_speakers)
    localized_roles = to_role_sequence(localized_speakers)
    # Only overlapping turns are compared here; dropped or added turns are
    # already reported by the turn-count check above.
    speaker_mismatches = [
        f"Turn {i+1}: encoded_role={er} ({encoded_speakers[i]}), "
        f"localized_role={lr} ({localized_speakers[i]})"
        for i, (er, lr) in enumerate(zip(encoded_roles, localized_roles))
        if er != lr
    ]
    if speaker_mismatches:
        result["issues"].append(
            f"Speaker role mismatches: {speaker_mismatches}"
        )
    result["speaker_roles_preserved"] = not speaker_mismatches

    # 3. Communicative intent (function names)
    encoded_funcs = [extract_function_names(t.get("functions", "")) for t in encoded_das]
    localized_funcs = [extract_function_names(t.get("functions", "")) for t in localized_das]
    intent_mismatches: List[Dict[str, Any]] = []
    # Walk the LONGER sequence. A turn dropped during localization loses its
    # intents, and a turn added introduces new ones. zip() would silently skip
    # both and report a truncated localization as 100% intent-preserving.
    for i in range(max(n_encoded, n_localized)):
        ef = encoded_funcs[i] if i < n_encoded else []
        lf = localized_funcs[i] if i < n_localized else []
        if set(ef) != set(lf):
            intent_mismatches.append({
                "turn": i + 1,
                "encoded": ef,
                "localized": lf,
                # sorted() keeps the report deterministic across runs.
                "missing": sorted(set(ef) - set(lf)),
                "added": sorted(set(lf) - set(ef)),
            })
    result["intent_mismatches"] = intent_mismatches
    result["intents_preserved"] = not intent_mismatches
    # The rate is measured over encoded turns only. Mismatches on surplus
    # localized turns clear intents_preserved but do not push the rate below 0.
    encoded_turn_mismatches = sum(1 for m in intent_mismatches if m["turn"] <= n_encoded)
    result["intent_preservation_rate"] = (
        1.0 - encoded_turn_mismatches / max(n_encoded, 1)
    )

    result["fully_faithful"] = (
        result["turn_count_preserved"]
        and result["speaker_roles_preserved"]
        and result["intents_preserved"]
    )
    return result
def run_evaluation(
    encoded_path: str,
    localized_path: str,
    decoded_path: str,
    label: str = "",
    decoded_key: str = "decoded_swahili",
) -> List[Dict[str, Any]]:
    """Evaluate every dialogue of one configuration and print a report.

    The three JSON files are paired by list position.

    - Encoded records are read for ``das_encoding`` (and optionally ``id``).
    - Localized records are read for ``localized_das`` (and optionally
      ``dialogue_id``).
    - Decoded records are read for *decoded_key*.

    Args:
        encoded_path: JSON file with the encoded dialogues.
        localized_path: JSON file with the localized dialogues.
        decoded_path: JSON file with the decoded dialogues.
        label: Configuration name shown in the printed headings.
        decoded_key: Key holding the list of decoded turns in each decoded
            record. A missing key counts as zero turns.

    Returns:
        One result dict per evaluated dialogue, as built by evaluate_dialogue.

    Raises:
        KeyError: if an encoded record lacks ``das_encoding`` or a localized
            record lacks ``localized_das``.
    """
    encoded = load_json(encoded_path)
    localized = load_json(localized_path)
    decoded = load_json(decoded_path)

    # Dialogues are aligned by position and zip() stops at the shortest file,
    # so make a count disagreement visible instead of dropping the tail silently.
    counts = (len(encoded), len(localized), len(decoded))
    if len(set(counts)) > 1:
        print(
            f"\n  WARNING {label}: dialogue counts differ "
            f"(encoded={counts[0]}, localized={counts[1]}, decoded={counts[2]}); "
            f"only the first {min(counts)} are evaluated"
        )

    results = []
    for i, (enc, loc, dec) in enumerate(zip(encoded, localized, decoded)):
        results.append(
            evaluate_dialogue(
                dialogue_id=enc.get("id", loc.get("dialogue_id", i + 1)),
                encoded_das=enc["das_encoding"],
                localized_das=loc["localized_das"],
                decoded_turns=dec.get(decoded_key, []),
            )
        )

    # Summary
    n = len(results)
    turn_ok = sum(1 for r in results if r["turn_count_preserved"])
    speaker_ok = sum(1 for r in results if r["speaker_roles_preserved"])
    intent_ok = sum(1 for r in results if r["intents_preserved"])
    fully_ok = sum(1 for r in results if r["fully_faithful"])
    avg_intent_rate = sum(r["intent_preservation_rate"] for r in results) / max(n, 1)
    print(f"\n{'=' * 60}")
    print(f"Structural Fidelity: {label}")
    print(f"{'=' * 60}")
    print(f" Dialogues evaluated: {n}")
    print(f" Turn count preserved: {turn_ok}/{n}")
    print(f" Speaker roles preserved: {speaker_ok}/{n}")
    print(f" All intents preserved: {intent_ok}/{n}")
    print(f" Avg intent preservation: {avg_intent_rate:.1%}")
    print(f" Fully faithful: {fully_ok}/{n}")

    # Show issues
    for r in results:
        if r["fully_faithful"]:
            continue
        print(f"\n ⚠ Dialogue {r['dialogue_id']}:")
        for issue in r["issues"]:
            print(f" - {issue}")
        for im in r["intent_mismatches"]:
            print(
                f" - Turn {im['turn']}: "
                f"encoded={im['encoded']} → localized={im['localized']}"
            )
    return results
def main() -> None:
    """Run the structural-fidelity evaluation for every region × model pair.

    Input paths are hard-coded relative to the working directory. A pair
    whose localized or decoded file is missing is skipped. A summary table
    of "preserved/total" counts is printed at the end.

    Raises:
        SystemExit: if the shared encoded file does not exist.
    """
    encoded_path = "data/encoded/dailydialog_encoded.json"
    # Every configuration reads this file. Fail once with a clear message
    # instead of a FileNotFoundError traceback from the first evaluation.
    if not Path(encoded_path).exists():
        raise SystemExit(f"Encoded file not found: {encoded_path}")

    regions = {
        "Kenya - Nairobi": "swahili_kenya___nairobi",
        "Tanzania - Zanzibar": "swahili_tanzania___zanzibar",
    }
    models = {
        "gpt-5.1": "gpt_5_1",
        "qwen-3.5-122b": "qwen_3_5_122b",
        "gemma-3-27b-it": "gemma_3_27b_it",
    }
    all_results: Dict[str, List[Dict[str, Any]]] = {}
    for region_name, region_tag in regions.items():
        for model_name, model_tag in models.items():
            label = f"{region_name} × {model_name}"
            loc_path = f"data/localized/dailydialog_{region_tag}_{model_tag}_localized.json"
            dec_path = f"data/decoded/dailydialog_{region_tag}_{model_tag}_decoded.json"
            if not (Path(loc_path).exists() and Path(dec_path).exists()):
                print(f"\n SKIP {label}: files not found")
                continue
            all_results[label] = run_evaluation(encoded_path, loc_path, dec_path, label)

    # Overall summary table
    print(f"\n{'=' * 60}")
    print("SUMMARY TABLE")
    print(f"{'=' * 60}")
    header = f"{'Config':<40} {'Turns':>6} {'Speak':>6} {'Intent':>6} {'Full':>6}"
    print(header)
    print("-" * len(header))
    flags = (
        "turn_count_preserved",
        "speaker_roles_preserved",
        "intents_preserved",
        "fully_faithful",
    )
    for label, results in all_results.items():
        n = len(results)
        # Build each "ok/total" cell as one string, so the whole fraction is
        # right-aligned under its width-6 header rather than padding only n.
        cells = (f"{sum(1 for r in results if r[flag])}/{n}" for flag in flags)
        print(f"{label:<40} " + " ".join(f"{cell:>6}" for cell in cells))
# Run the evaluation grid only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()