prompt-compiler-api / src /benchmarks /threshold_analysis.py
JairoDanielMT's picture
Upload folder using huggingface_hub
4ef6c2b verified
Raw
History Blame Contribute Delete
4.35 kB
import json
import os
from src.benchmarks.semantic_quality_audit import SemanticQualityAudit
from src.parser.parser import Parser
from src.ontology.matcher import ConceptMatcher
from src.embeddings.engine import EmbeddingEngine
def _grade_match(item, canonical):
    """Classify one accepted search hit against its ground-truth item.

    Args:
        item: Ground-truth entry. Reads the "query" and "expected" keys and,
            when present, "acceptable" (an iterable of strings) and
            "category".
        canonical: Canonical name of the record returned by the search.

    Returns:
        One of "correct", "acceptable" or "incorrect".
    """
    query = item["query"]
    actual_lower = canonical.lower().strip()
    expected_lower = item["expected"].lower().strip()

    # Exact match, or either string contained in the other, counts as correct.
    is_correct = (
        actual_lower == expected_lower
        or expected_lower in actual_lower
        or actual_lower in expected_lower
    )

    is_acceptable = False
    if not is_correct:
        alternatives = [acc.lower().strip() for acc in item.get("acceptable") or ()]
        is_acceptable = any(
            alt in actual_lower or actual_lower in alt for alt in alternatives
        )
        # A result that contains the query text itself is also tolerated.
        if not is_acceptable and query.lower().strip() in actual_lower:
            is_acceptable = True

    # Hard-coded per-query overrides kept from the original implementation.
    # NOTE(review): "crimsn" looks like a deliberately misspelled query that
    # should resolve to "crimson" -- confirm against the ground-truth data.
    if "royal robes" in query and "royal robes" in actual_lower:
        is_correct = True
    if "crimsn" in query and "crimson" in actual_lower:
        is_correct = True

    # Characters must match exactly; this overrides every lenient rule above.
    # `.get` is used because the search call also treats "category" as
    # optional; indexing here raised KeyError for items without the key.
    if item.get("category") == "character" and actual_lower != expected_lower:
        return "incorrect"

    if is_correct:
        return "correct"
    if is_acceptable:
        return "acceptable"
    return "incorrect"


def _summarize_threshold(th, results):
    """Turn the raw counters of one threshold into a report row.

    Args:
        th: Confidence threshold the counters were collected for.
        results: Dict with integer "correct", "acceptable", "incorrect" and
            "total" counters.

    Returns:
        Dict with "threshold", "accepted_pct", "incorrect_pct", "precision",
        "recall" and "f1". Every ratio falls back to 0 when its denominator
        is 0.
    """
    total = results["total"]
    accepted = results["correct"] + results["acceptable"] + results["incorrect"]

    acceptance_rate = accepted / total if total else 0
    incorrect_rate = results["incorrect"] / accepted if accepted else 0
    # NOTE(review): recall uses the same formula as the acceptance rate
    # (accepted / total), so it reports coverage, not the share of queries
    # answered correctly. Kept as-is so published numbers do not change.
    recall = accepted / total if total else 0
    precision = (
        (results["correct"] + results["acceptable"]) / accepted if accepted else 0
    )
    f1 = (
        2 * (precision * recall) / (precision + recall)
        if (precision + recall)
        else 0
    )

    return {
        "threshold": th,
        "accepted_pct": round(acceptance_rate * 100, 2),
        "incorrect_pct": round(incorrect_rate * 100, 2),
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1": round(f1, 4)
    }


def _write_reports(report, output_dir):
    """Write the report rows as JSON and Markdown files inside output_dir."""
    os.makedirs(output_dir, exist_ok=True)

    json_path = os.path.join(output_dir, "threshold_analysis.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    md_path = os.path.join(output_dir, "threshold_analysis.md")
    with open(md_path, "w", encoding="utf-8") as f:
        f.write("# Confidence Threshold Analysis\n\n")
        for r in report:
            f.write(f"## Threshold {r['threshold']:.2f}\n")
            f.write(f"- **Accepted**: {r['accepted_pct']}%\n")
            f.write(f"- **Incorrect**: {r['incorrect_pct']}%\n")
            f.write(f"- **Precision**: {r['precision']}\n")
            f.write(f"- **Recall**: {r['recall']}\n")
            f.write(f"- **F1**: {r['f1']}\n\n")


def run_threshold_analysis(thresholds=None, output_dir="reports", audit=None, parser=None):
    """Evaluate embedding-search quality at several confidence thresholds.

    For every ground-truth item the top-1 search hit is fetched once. For
    each threshold, hits below the threshold (and empty results) count as
    rejected; the remaining hits are graded as correct, acceptable or
    incorrect. The per-threshold metrics are written to
    ``threshold_analysis.json`` and ``threshold_analysis.md`` in
    ``output_dir``.

    Args:
        thresholds: Iterable of confidence cut-offs. Defaults to
            0.30, 0.40, 0.50, 0.60, 0.70 and 0.80.
        output_dir: Directory for the report files; created if missing.
        audit: Object exposing ``expanded_gt`` (iterable of ground-truth
            dicts). Defaults to a fresh ``SemanticQualityAudit()``.
        parser: Object exposing ``embedding_engine.search(query, category=,
            top_k=)``. Defaults to a ``Parser`` built from the ontology in
            "data/ontology" and the index in "data/faiss_indices".

    Returns:
        The list of per-threshold report rows that was written to disk.
    """
    if thresholds is None:
        thresholds = [0.30, 0.40, 0.50, 0.60, 0.70, 0.80]

    if audit is None:
        audit = SemanticQualityAudit()
    if parser is None:
        matcher = ConceptMatcher("data/ontology")
        engine = EmbeddingEngine(index_dir="data/faiss_indices")
        engine.load_index()
        parser = Parser(matcher, engine)

    # The search call takes no threshold, so run it once per item and reuse
    # the top hit for every threshold instead of repeating the query.
    search = parser.embedding_engine.search
    top_hits = []
    for item in audit.expanded_gt:
        found = search(item["query"], category=item.get("category"), top_k=1)
        top_hits.append((item, found[0] if found else None))

    report = []
    for th in thresholds:
        results = {
            "correct": 0,
            "acceptable": 0,
            "incorrect": 0,
            "rejected": 0,
            "total": 0
        }
        for item, top_hit in top_hits:
            results["total"] += 1
            if top_hit is None:
                results["rejected"] += 1
                continue
            record, conf = top_hit
            if conf < th:
                results["rejected"] += 1
                continue
            results[_grade_match(item, record.canonical)] += 1
        report.append(_summarize_threshold(th, results))

    _write_reports(report, output_dir)
    print("Threshold analysis complete.")
    return report
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_threshold_analysis()