| """JuaKazi Correction API — HTTP routing only.""" |

import sys
import time
from pathlib import Path

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from .audit import log as log_audit
from .ml_rewriter import ml_rewrite
from .rules_engine import apply_rules_on_spans, build_reason
from .schemas import RewriteRequest, RewriteResponse

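# eval/ lives one directory above this package; extend sys.path so the
# semantic-preservation metrics import resolves without installing eval/ as a package.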
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from eval.correction_evaluator import SemanticPreservationMetrics

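# Minimum composite semantic-preservation score a rewrite must reach; anything
# lower is discarded and the original text is returned unchanged.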
SEMANTIC_THRESHOLD = 0.70

app = FastAPI(title="JuaKazi Correction Engine (hybrid)", version="0.3")
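# CORS is left wide open (any origin, method, header); assumed to be a development-only default.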
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
semantic_metrics = SemanticPreservationMetrics()


| @app.post("/rewrite", response_model=RewriteResponse) |
| def rewrite(req: RewriteRequest): |
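    """Rewrite one text: rule pack first, ML fallback when no rule matches, both gated on semantic preservation."""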
    t0 = time.time()
    rewritten, edits, matched_rules, skipped = apply_rules_on_spans(
        req.text, req.lang, flags=req.flags or None
    )
    source = "rules"
    ml_info = None
    semantic_score = None

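    # If the rule pack changed the text, check that the rewrite preserves meaning;
    # below-threshold rewrites are dropped and the original text is kept.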
    if rewritten != req.text:
        score = semantic_metrics.calculate_composite_preservation_score(req.text, rewritten)
        semantic_score = score["composite_score"]
        if semantic_score < SEMANTIC_THRESHOLD:
            rewritten, edits, source, semantic_score = req.text, [], "preserved", 1.0

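    # No rule matched: fall back to the ML rewriter and apply the same semantic gate
    # to its best candidate before accepting it.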
    if matched_rules == 0 and source != "preserved":
        ml_out = ml_rewrite(req.text, lang=req.lang, num_return_sequences=3)
        ml_score = semantic_metrics.calculate_composite_preservation_score(req.text, ml_out["best"])
        if ml_score["composite_score"] < SEMANTIC_THRESHOLD:
            rewritten, source, semantic_score = req.text, "preserved", 1.0
        else:
            rewritten = ml_out["best"]
            source = "ml"
            semantic_score = ml_score["composite_score"]
            ml_info = ml_out
            edits.append({"from": req.text, "to": rewritten, "severity": "ml_fallback", "tags": "", "reason": "ML rewrite"})

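    # Heuristic confidence per source; ML rewrites and no-op responses are flagged for human review.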
    latency_ms = int((time.time() - t0) * 1000)
    confidence = {"rules": 0.85, "ml": 0.60, "preserved": 0.95}.get(source, 0.85)
    needs_review = source == "ml" or len(edits) == 0
    reason = build_reason(source, edits, skipped)
    has_bias_detected = any(e.get("severity") == "replace" for e in edits)

    response = {
        "id": req.id,
        "original_text": req.text,
        "rewrite": rewritten,
        "edits": edits,
        "confidence": confidence,
        "needs_review": needs_review,
        "source": source,
        "reason": reason,
        "semantic_score": semantic_score,
        "skipped_context": skipped or None,
        "has_bias_detected": has_bias_detected,
    }

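    # Persist the full request/response pair plus model and latency metadata for the audit trail.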
    log_audit({
        "request": req.dict(),
        "response": response,
        "model_info": ml_info or {"model": "rulepack-v0.3"},
        "latency_ms": latency_ms,
        "region_dialect": req.region_dialect or "unknown",
    })
    return response


| @app.post("/rewrite/batch") |
| def rewrite_batch(items: list): |
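    """Rewrite a batch by applying the single-text endpoint logic to each item in order."""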
    return [rewrite(it) for it in items]