"""
Khmer Legal Bridge - Translation API
=====================================
Flask application with COMETKiwi-based confidence scoring.
Features:
- Bidirectional EN↔KM translation using fine-tuned NLLB-200
- Scientific confidence scoring with COMETKiwi
- PDF text extraction
- Privacy-first design (zero retention)
Author: Khmer Legal Bridge Project
License: MIT
"""
from flask import Flask, render_template, request, jsonify
from transformers import AutoModelForSeq2SeqLM, NllbTokenizerFast
import torch
import fitz  # PyMuPDF
import re
import unicodedata
import time
import logging
import os
from sacremoses import MosesPunctNormalizer
# Import confidence scoring module
from confidence_scoring_v2 import (
    TransparencyScorer,
    DEFAULT_LEGAL_GLOSSARY,
    ConfidenceResult
)
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
app = Flask(__name__)
# ============================================================================
# Text Preprocessing
# ============================================================================
mpn = MosesPunctNormalizer(lang="en")
# Pre-compile the normalizer's regex substitutions; re.sub accepts compiled patterns.
mpn.substitutions = [(re.compile(r), sub) for r, sub in mpn.substitutions]
def get_non_printing_char_replacer(replace_by: str = " "):
    """Build a str.translate() map that replaces non-printing characters."""
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(0x110000))
        # Equivalent to \p{C}: control, format, surrogate, private-use, unassigned
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }
    return lambda line: line.translate(non_printable_map)
replace_nonprint = get_non_printing_char_replacer(" ")
def preprocess_text(text: str) -> str:
    """Clean and normalize text for translation."""
    clean = mpn.normalize(text)
    clean = replace_nonprint(clean)
    clean = unicodedata.normalize("NFKC", clean)
    return clean
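# Illustrative example (assumed behavior, not a test fixture): Khmer text often
# separates words with U+200B ZERO WIDTH SPACE, a format (Cf) character, so
# replace_nonprint maps it to a plain space before NFKC normalization:
#   preprocess_text("ប្រការ\u200bទី១")  # -> "ប្រការ ទី១" (expected)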
# ============================================================================
# Model Loading
# ============================================================================
logger.info("Loading translation model...")
MODEL_ID = "ClaudBarbara/Open_Access_Khmer"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
tokenizer = NllbTokenizerFast.from_pretrained(MODEL_ID)
logger.info("Translation model loaded!")
# Configuration
USE_COMET = os.environ.get("USE_COMET", "true").lower() == "true"
USE_DETAILED_SCORING = os.environ.get("DETAILED_SCORING", "true").lower() == "true"
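# Both flags are read from the environment, so scoring behavior can be changed
# at deploy time without touching code, e.g.:
#   USE_COMET=false DETAILED_SCORING=false python app.py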
# Initialize confidence scorer (lazy loading for COMETKiwi)
confidence_scorer = None
def get_confidence_scorer():
    """Lazy initialization of confidence scorer."""
    global confidence_scorer
    if confidence_scorer is None:
        logger.info(f"Initializing confidence scorer (COMETKiwi: {USE_COMET})")
        confidence_scorer = TransparencyScorer(
            translator_func=translate_simple,
            glossary=DEFAULT_LEGAL_GLOSSARY,
            use_comet=USE_COMET,
            use_back_translation=True,
            use_terminology=True
        )
    return confidence_scorer
# ============================================================================
# Translation Functions
# ============================================================================
def segment_text(text: str, src_lang: str) -> list:
    """Segment text into sentences for batch processing."""
    if src_lang == "khm_Khmr":
        # Khmer sentence boundaries
        sentences = re.split(r'(?<=[។៖])\s*', text)
    else:
        # English sentence boundaries
        sentences = re.split(r'(?<=[.!?])\s+', text)
    return [s.strip() for s in sentences if s.strip()]
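# Illustrative example (hypothetical input): the lookbehind splits after each
# Khmer khan (។) even when no whitespace follows it:
#   segment_text("ល្អ។សូមអរគុណ។", "khm_Khmr")  # -> ["ល្អ។", "សូមអរគុណ។"]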
def translate_simple(text: str, src_lang: str, tgt_lang: str) -> str:
    """
    Simple translation without confidence scoring.
    Used for back-translation verification.
    """
    tokenizer.src_lang = src_lang
    inputs = tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512
    )
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
            # Allow output up to ~3x the input length, plus headroom
            max_new_tokens=int(32 + 3 * inputs.input_ids.shape[1]),
            num_beams=4,
            early_stopping=True
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
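# Hedged usage sketch (the exact output is model-dependent):
#   translate_simple("The defendant shall appear in court.", "eng_Latn", "khm_Khmr")
# returns the Khmer translation as a single string.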
def translate_batch(texts: list, src_lang: str, tgt_lang: str) -> list:
    """Translate a batch of texts efficiently."""
    if not texts:
        return []
    tokenizer.src_lang = src_lang
    inputs = tokenizer(
        texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512
    )
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
            # Allow output up to ~3x the (padded) input length, plus headroom
            max_new_tokens=int(32 + 3 * inputs.input_ids.shape[1]),
            num_beams=4,
            early_stopping=True
        )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
def translate_long(
    text: str,
    src_lang: str,
    tgt_lang: str,
    batch_size: int = 8,
    compute_confidence: bool = True
) -> tuple:
    """
    Translate long text with sentence segmentation and confidence scoring.
    Args:
        text: Input text
        src_lang: Source language code
        tgt_lang: Target language code
        batch_size: Batch size for processing
        compute_confidence: Whether to compute detailed confidence
    Returns:
        Tuple of (translation, metrics_dict)
    """
    start_time = time.time()
    # Preprocess
    clean_text = preprocess_text(text)
    sentences = segment_text(clean_text, src_lang)
    if not sentences:
        return "", {"error": "No text to translate"}
    # Translate in batches
    translated_parts = []
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]
        translations = translate_batch(batch, src_lang, tgt_lang)
        translated_parts.extend(translations)
    result = " ".join(translated_parts)
    elapsed = time.time() - start_time
    # Compute confidence score
    direction = "en2km" if src_lang == "eng_Latn" else "km2en"
    if compute_confidence and USE_COMET:
        try:
            scorer = get_confidence_scorer()
            # For long texts, sample representative sentences for scoring
            if len(sentences) > 5:
                # Score the first, middle, and last sentences
                sample_indices = [0, len(sentences) // 2, -1]
                sample_scores = []
                for idx in sample_indices:
                    src_sent = sentences[idx]
                    tgt_sent = translated_parts[idx]
                    conf_result = scorer.score(
                        src_sent, tgt_sent, direction,
                        detailed=USE_DETAILED_SCORING
                    )
                    sample_scores.append(conf_result.overall_score)
                avg_score = sum(sample_scores) / len(sample_scores)
                min_score = min(sample_scores)
                # Use the more conservative of the two estimates
                confidence_score = min(avg_score, min_score + 0.1)
            else:
                # Score the entire translation
                conf_result = scorer.score(
                    clean_text, result, direction,
                    detailed=USE_DETAILED_SCORING
                )
                confidence_score = conf_result.overall_score
            # Determine review recommendation
            needs_review = confidence_score < 0.75
            quality_level = (
                "excellent" if confidence_score >= 0.85 else
                "good" if confidence_score >= 0.70 else
                "acceptable" if confidence_score >= 0.55 else
                "low" if confidence_score >= 0.40 else
                "very_low"
            )
            metrics = {
                "confidence": round(confidence_score * 100, 1),
                "quality_level": quality_level,
                "needs_review": needs_review,
                "time_seconds": round(elapsed, 2),
                "sentences": len(sentences),
                "method": "comet_kiwi"
            }
        except Exception as e:
            logger.error(f"Confidence scoring failed: {e}")
            # Fall back to lightweight scoring
            metrics = compute_lightweight_metrics(
                clean_text, result, direction, elapsed, len(sentences)
            )
    else:
        # Use lightweight scoring
        metrics = compute_lightweight_metrics(
            clean_text, result, direction, elapsed, len(sentences)
        )
    return result, metrics
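# Illustrative return shape (values are made up for the example):
#   ("<Khmer translation>",
#    {"confidence": 88.5, "quality_level": "excellent", "needs_review": False,
#     "time_seconds": 3.2, "sentences": 4, "method": "comet_kiwi"})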
def compute_lightweight_metrics(
    source: str,
    translation: str,
    direction: str,
    elapsed: float,
    num_sentences: int
) -> dict:
    """Compute lightweight confidence metrics without COMETKiwi."""
    scorer = get_confidence_scorer()
    conf_result = scorer.score_fast(source, translation, direction)
    return {
        "confidence": round(conf_result.overall_score * 100, 1),
        "quality_level": conf_result.quality_level,
        "needs_review": conf_result.human_review_recommended,
        "time_seconds": round(elapsed, 2),
        "sentences": num_sentences,
        "method": "lightweight"
    }
# ============================================================================
# PDF Extraction
# ============================================================================
def extract_pdf_text(pdf_file) -> str | None:
    """Extract text from an uploaded PDF file; return None on failure."""
    try:
        pdf_bytes = pdf_file.read()
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
        text = ""
        for page in doc:
            text += page.get_text()
        doc.close()
        return text.strip()
    except Exception as e:
        logger.error(f"PDF extraction failed: {e}")
        return None
# ============================================================================
# API Routes
# ============================================================================
@app.route("/")
def index():
"""Serve the main translation interface."""
return render_template("index.html")
@app.route("/translate", methods=["POST"])
def translate_endpoint():
"""
Translation API endpoint.
Request JSON:
- text: str - Text to translate
- direction: str - "en-km" or "km-en"
Response JSON:
- success: bool
- translation: str
- metrics: dict with confidence scores
"""
data = request.json
text = data.get("text", "")
direction = data.get("direction", "en-km")
if direction == "en-km":
src_lang, tgt_lang = "eng_Latn", "khm_Khmr"
else:
src_lang, tgt_lang = "khm_Khmr", "eng_Latn"
try:
result, metrics = translate_long(text, src_lang, tgt_lang)
return jsonify({
"success": True,
"translation": result,
"metrics": metrics
})
except Exception as e:
logger.error(f"Translation failed: {e}")
return jsonify({
"success": False,
"error": str(e)
})
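# Example request (assumes the default port configured below):
#   curl -X POST http://localhost:7860/translate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "The contract is binding.", "direction": "en-km"}'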
@app.route("/upload-pdf", methods=["POST"])
def upload_pdf():
"""
PDF upload endpoint.
Accepts multipart form with 'file' field.
Returns extracted text.
"""
if 'file' not in request.files:
return jsonify({"success": False, "error": "No file uploaded"})
file = request.files['file']
if file.filename == '':
return jsonify({"success": False, "error": "No file selected"})
if not file.filename.lower().endswith('.pdf'):
return jsonify({"success": False, "error": "Only PDF files supported"})
text = extract_pdf_text(file)
if text:
return jsonify({"success": True, "text": text})
else:
return jsonify({"success": False, "error": "Could not extract text"})
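# Example request (contract.pdf is a placeholder filename):
#   curl -X POST http://localhost:7860/upload-pdf -F "file=@contract.pdf"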
@app.route("/health", methods=["GET"])
def health_check():
"""Health check endpoint for monitoring."""
return jsonify({
"status": "healthy",
"model": MODEL_ID,
"comet_enabled": USE_COMET
})
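# Example response (model and flag values mirror the configuration above):
#   {"status": "healthy", "model": "ClaudBarbara/Open_Access_Khmer", "comet_enabled": true}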
# ============================================================================
# Main Entry Point
# ============================================================================
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)