AcharO's picture
deploy: FastAPI + mount_gradio_app pattern for /rewrite + Gradio UI
9c0aba1
"""JuaKazi Correction API — HTTP routing only."""
import sys
import time
from pathlib import Path
from typing import List

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from .audit import log as log_audit
from .ml_rewriter import ml_rewrite
from .rules_engine import apply_rules_on_spans, build_reason
from .schemas import RewriteRequest, RewriteResponse

# Put the directory two levels above this file on sys.path so the top-level
# `eval` package can be imported; this must run before the import below.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from eval.correction_evaluator import SemanticPreservationMetrics  # noqa: E402
# Minimum composite preservation score a rewrite must reach to be kept.
SEMANTIC_THRESHOLD = 0.70

app = FastAPI(
    title="JuaKazi Correction Engine (hybrid)",
    version="0.3",
)

# CORS is left fully open: every origin, method and header is allowed.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# Single scorer instance, created at import time and shared by all requests.
semantic_metrics = SemanticPreservationMetrics()
@app.post("/rewrite", response_model=RewriteResponse)
def rewrite(req: RewriteRequest):
    """Rewrite ``req.text`` with the rule pack, falling back to the ML rewriter.

    Steps:
      1. Apply the rules. If they changed the text, score the change with
         ``semantic_metrics``; a composite score below ``SEMANTIC_THRESHOLD``
         drops the rule edits and keeps the original text (source
         ``"preserved"``).
      2. If no rule matched and step 1 did not already preserve the text, ask
         the ML rewriter for candidates and gate its ``"best"`` candidate with
         the same threshold.
      3. Build the response payload, write one audit record, and return the
         payload.

    Args:
        req: Incoming request. This handler reads ``id``, ``text``, ``lang``,
            ``flags`` and ``region_dialect`` from it.

    Returns:
        A dict with the keys ``id``, ``original_text``, ``rewrite``, ``edits``,
        ``confidence``, ``needs_review``, ``source``, ``reason``,
        ``semantic_score``, ``skipped_context`` and ``has_bias_detected``.
        ``source`` is one of ``"rules"``, ``"ml"`` or ``"preserved"``.
    """
    # perf_counter is monotonic, so the measured latency cannot go negative or
    # jump when the system clock is adjusted (time.time() can).
    started = time.perf_counter()

    rewritten, edits, matched_rules, skipped = apply_rules_on_spans(
        req.text, req.lang, flags=req.flags or None
    )
    source = "rules"
    ml_info = None
    semantic_score = None

    # Gate rule-based changes on semantic preservation.
    if rewritten != req.text:
        score = semantic_metrics.calculate_composite_preservation_score(req.text, rewritten)
        semantic_score = score["composite_score"]
        if semantic_score < SEMANTIC_THRESHOLD:
            # Too much meaning drift: discard the rule edits and keep the input.
            rewritten, edits, source, semantic_score = req.text, [], "preserved", 1.0

    # ML fallback runs only when no rule matched at all.
    if matched_rules == 0 and source != "preserved":
        ml_out = ml_rewrite(req.text, lang=req.lang, num_return_sequences=3)
        ml_score = semantic_metrics.calculate_composite_preservation_score(req.text, ml_out["best"])
        if ml_score["composite_score"] < SEMANTIC_THRESHOLD:
            rewritten, source, semantic_score = req.text, "preserved", 1.0
        else:
            rewritten = ml_out["best"]
            source = "ml"
            semantic_score = ml_score["composite_score"]
            ml_info = ml_out
            # Record the whole-text replacement as a single edit.
            edits.append({"from": req.text, "to": rewritten, "severity": "ml_fallback", "tags": "", "reason": "ML rewrite"})

    latency_ms = int((time.perf_counter() - started) * 1000)

    # Fixed confidence per source; unknown sources fall back to the rules value.
    confidence = {"rules": 0.85, "ml": 0.60, "preserved": 0.95}.get(source, 0.85)
    # ML output always needs a human look, as does a response with no edits.
    needs_review = source == "ml" or not edits
    reason = build_reason(source, edits, skipped)
    has_bias_detected = any(e.get("severity") == "replace" for e in edits)

    response = {
        "id": req.id,
        "original_text": req.text,
        "rewrite": rewritten,
        "edits": edits,
        "confidence": confidence,
        "needs_review": needs_review,
        "source": source,
        "reason": reason,
        "semantic_score": semantic_score,
        "skipped_context": skipped or None,
        "has_bias_detected": has_bias_detected,
    }

    # Pydantic v2 deprecates .dict() in favour of .model_dump(); use whichever
    # the installed version provides so the audit record is built either way.
    request_payload = req.model_dump() if hasattr(req, "model_dump") else req.dict()
    log_audit({
        "request": request_payload,
        "response": response,
        "model_info": ml_info or {"model": "rulepack-v0.3"},
        "latency_ms": latency_ms,
        "region_dialect": req.region_dialect or "unknown",
    })
    return response
@app.post("/rewrite/batch")
def rewrite_batch(items: List[RewriteRequest]):
    """Run :func:`rewrite` on every item of a batch and return the results in order.

    Declaring the element type lets FastAPI validate each item of the JSON
    array before the handler runs, so a malformed item is rejected as a
    request-validation error rather than raising inside the handler.

    Args:
        items: The requests to process. Plain dicts are still accepted so
            direct Python callers that pass raw payloads keep working.

    Returns:
        A list with one ``rewrite`` response dict per input item, in input order.
    """
    return [
        rewrite(item if isinstance(item, RewriteRequest) else RewriteRequest(**item))
        for item in items
    ]