# NOTE: The "Spaces: Sleeping" header above this file is residue from the
# Hugging Face Spaces status page capture, not part of the application source.
"""
Khmer Legal Bridge - Translation API
=====================================

Flask application with COMETKiwi-based confidence scoring.

Features:
- Bidirectional EN↔KM translation using fine-tuned NLLB-200
- Scientific confidence scoring with COMETKiwi
- PDF text extraction
- Privacy-first design (zero retention)

Author: Khmer Legal Bridge Project
License: MIT
"""
| from flask import Flask, render_template, request, jsonify | |
| from transformers import AutoModelForSeq2SeqLM, NllbTokenizerFast | |
| import torch | |
| import fitz | |
| import re | |
| import unicodedata | |
| import time | |
| import logging | |
| import os | |
| from sacremoses import MosesPunctNormalizer | |
| # Import confidence scoring module | |
| from confidence_scoring_v2 import ( | |
| TransparencyScorer, | |
| DEFAULT_LEGAL_GLOSSARY, | |
| ConfidenceResult | |
| ) | |
# Configure logging
# One timestamped format shared by every logger in the process.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
# Flask application instance.
# NOTE(review): no @app.route decorators are visible anywhere in this file —
# confirm URL registration happens elsewhere (e.g. app.add_url_rule) or
# restore the decorators on the handler functions below.
app = Flask(__name__)
| # ============================================================================ | |
| # Text Preprocessing | |
| # ============================================================================ | |
# Moses-style punctuation normalizer (English rules also cover the
# typographic quotes/dashes we expect in legal source text).
mpn = MosesPunctNormalizer(lang="en")
# Pre-compile the substitution regexes once so normalize() does not
# recompile each pattern on every call.
# NOTE(review): assumes sacremoses accepts pre-compiled patterns in
# .substitutions — confirm against the pinned sacremoses version.
mpn.substitutions = [(re.compile(r), sub) for r, sub in mpn.substitutions]
def get_non_printing_char_replacer(replace_by: str = " "):
    """
    Build a function that replaces every non-printing character with *replace_by*.

    Scans the full Unicode range once and keeps a translation table of all
    codepoints whose category starts with "C" (control, format, surrogate,
    private-use, unassigned).

    Args:
        replace_by: Replacement string for each non-printing character.

    Returns:
        A callable mapping str -> str via str.translate.
    """
    control_like = {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    table = {}
    for codepoint in range(0x110000):
        if unicodedata.category(chr(codepoint)) in control_like:
            table[codepoint] = replace_by

    def _replace(line: str) -> str:
        return line.translate(table)

    return _replace


# Shared replacer used by preprocess_text (built once at import time).
replace_nonprint = get_non_printing_char_replacer(" ")
def preprocess_text(text: str) -> str:
    """
    Clean and normalize text for translation.

    Pipeline: Moses punctuation normalization -> non-printing character
    removal -> Unicode NFKC normalization.
    """
    normalized = mpn.normalize(text)
    normalized = replace_nonprint(normalized)
    return unicodedata.normalize("NFKC", normalized)
| # ============================================================================ | |
| # Model Loading | |
| # ============================================================================ | |
logger.info("Loading translation model...")
# Fine-tuned NLLB-200 checkpoint for EN<->KM translation (see module docstring).
MODEL_ID = "ClaudBarbara/Open_Access_Khmer"
# Loaded eagerly at module import time; first import downloads the weights.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
tokenizer = NllbTokenizerFast.from_pretrained(MODEL_ID)
logger.info("Translation model loaded!")
# Configuration
# USE_COMET: enable COMETKiwi-based scoring; any value other than "true"
# (case-insensitive) disables it and the lightweight path is used instead.
USE_COMET = os.environ.get("USE_COMET", "true").lower() == "true"
# DETAILED_SCORING: passed through as detailed= to TransparencyScorer.score.
USE_DETAILED_SCORING = os.environ.get("DETAILED_SCORING", "true").lower() == "true"
# Initialize confidence scorer (lazy loading for COMETKiwi)
# None until the first call to get_confidence_scorer().
confidence_scorer = None
def get_confidence_scorer():
    """
    Return the process-wide TransparencyScorer, creating it on first use.

    Lazy so that the (heavy) COMETKiwi model is only loaded when a
    translation actually needs scoring.
    """
    global confidence_scorer
    if confidence_scorer is not None:
        return confidence_scorer
    logger.info(f"Initializing confidence scorer (COMETKiwi: {USE_COMET})")
    confidence_scorer = TransparencyScorer(
        translator_func=translate_simple,
        glossary=DEFAULT_LEGAL_GLOSSARY,
        use_comet=USE_COMET,
        use_back_translation=True,
        use_terminology=True,
    )
    return confidence_scorer
| # ============================================================================ | |
| # Translation Functions | |
| # ============================================================================ | |
def segment_text(text: str, src_lang: str) -> list:
    """
    Split *text* into trimmed, non-empty sentences.

    Khmer text is split after ។ (khan) or ៖; everything else is treated
    as English-like and split after ., ! or ? followed by whitespace.
    """
    if src_lang == "khm_Khmr":
        boundary = r'(?<=[។៖])\s*'  # Khmer sentence terminators
    else:
        boundary = r'(?<=[.!?])\s+'  # Latin-script sentence terminators
    pieces = re.split(boundary, text)
    return [piece.strip() for piece in pieces if piece.strip()]
def translate_simple(text: str, src_lang: str, tgt_lang: str) -> str:
    """
    Simple translation without confidence scoring.

    Used for back-translation verification.

    Args:
        text: Source text to translate.
        src_lang: NLLB source language code (e.g. "eng_Latn").
        tgt_lang: NLLB target language code (e.g. "khm_Khmr").

    Returns:
        The translated string.
    """
    # The original body duplicated translate_batch verbatim; delegating keeps
    # tokenization and generation parameters (beam size, length budget,
    # truncation) defined in exactly one place. For a single input the batch
    # path is behaviorally identical: padding is a no-op and
    # batch_decode(...)[0] equals decode(outputs[0]).
    return translate_batch([text], src_lang, tgt_lang)[0]
def translate_batch(texts: list, src_lang: str, tgt_lang: str) -> list:
    """
    Translate several strings in a single generate() call.

    Args:
        texts: Source strings (may be empty).
        src_lang: NLLB source language code.
        tgt_lang: NLLB target language code.

    Returns:
        One translated string per input, in order.
    """
    if not texts:
        return []
    tokenizer.src_lang = src_lang
    encoded = tokenizer(
        texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Output-length budget scales with the (padded) input length.
    token_budget = int(32 + 3 * encoded.input_ids.shape[1])
    # Force the first generated token to be the target-language tag.
    target_bos = tokenizer.convert_tokens_to_ids(tgt_lang)
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            forced_bos_token_id=target_bos,
            max_new_tokens=token_budget,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
def translate_long(
    text: str,
    src_lang: str,
    tgt_lang: str,
    batch_size: int = 8,
    compute_confidence: bool = True
) -> tuple:
    """
    Translate long text with sentence segmentation and confidence scoring.

    Args:
        text: Input text
        src_lang: Source language code (NLLB code, e.g. "eng_Latn")
        tgt_lang: Target language code
        batch_size: Batch size for processing
        compute_confidence: Whether to compute detailed confidence

    Returns:
        Tuple of (translation, metrics_dict). On empty input the
        translation is "" and metrics_dict carries only an "error" key.
        Note: "time_seconds" measures preprocessing + translation only;
        confidence scoring happens after the timer stops.
    """
    start_time = time.time()
    # Preprocess
    clean_text = preprocess_text(text)
    sentences = segment_text(clean_text, src_lang)
    if not sentences:
        return "", {"error": "No text to translate"}
    # Translate in batches
    # translate_batch returns one output per input, so translated_parts
    # stays index-aligned with sentences (relied on for sampling below).
    translated_parts = []
    for i in range(0, len(sentences), batch_size):
        batch = sentences[i:i + batch_size]
        translations = translate_batch(batch, src_lang, tgt_lang)
        translated_parts.extend(translations)
    result = " ".join(translated_parts)
    elapsed = time.time() - start_time
    # Compute confidence score
    direction = "en2km" if src_lang == "eng_Latn" else "km2en"
    if compute_confidence and USE_COMET:
        try:
            scorer = get_confidence_scorer()
            # For long texts, sample representative sentences for scoring
            if len(sentences) > 5:
                # Score first, middle, and last sentences
                sample_indices = [0, len(sentences)//2, -1]
                sample_scores = []
                for idx in sample_indices:
                    src_sent = sentences[idx]
                    tgt_sent = translated_parts[idx]
                    conf_result = scorer.score(
                        src_sent, tgt_sent, direction,
                        detailed=USE_DETAILED_SCORING
                    )
                    sample_scores.append(conf_result.overall_score)
                avg_score = sum(sample_scores) / len(sample_scores)
                min_score = min(sample_scores)
                # Use most conservative estimate
                # (never more than 0.1 above the worst sampled sentence).
                confidence_score = min(avg_score, min_score + 0.1)
            else:
                # Score entire translation
                conf_result = scorer.score(
                    clean_text, result, direction,
                    detailed=USE_DETAILED_SCORING
                )
                confidence_score = conf_result.overall_score
            # Determine review recommendation
            # NOTE(review): the review cutoff (0.75) sits above the "good"
            # cutoff (0.70), so results in [0.70, 0.75) are labeled "good"
            # yet still flagged for review — confirm this is intended.
            needs_review = confidence_score < 0.75
            quality_level = (
                "excellent" if confidence_score >= 0.85 else
                "good" if confidence_score >= 0.70 else
                "acceptable" if confidence_score >= 0.55 else
                "low" if confidence_score >= 0.40 else
                "very_low"
            )
            metrics = {
                "confidence": round(confidence_score * 100, 1),
                "quality_level": quality_level,
                "needs_review": needs_review,
                "time_seconds": round(elapsed, 2),
                "sentences": len(sentences),
                "method": "comet_kiwi"
            }
        except Exception as e:
            logger.error(f"Confidence scoring failed: {e}")
            # Fallback to lightweight scoring
            metrics = compute_lightweight_metrics(
                clean_text, result, direction, elapsed, len(sentences)
            )
    else:
        # Use lightweight scoring
        metrics = compute_lightweight_metrics(
            clean_text, result, direction, elapsed, len(sentences)
        )
    return result, metrics
def compute_lightweight_metrics(
    source: str,
    translation: str,
    direction: str,
    elapsed: float,
    num_sentences: int,
) -> dict:
    """
    Score a translation with the fast heuristics (no COMETKiwi inference).

    Args:
        source: Preprocessed source text.
        translation: Produced translation.
        direction: "en2km" or "km2en".
        elapsed: Translation wall-clock time in seconds.
        num_sentences: Number of source sentences.

    Returns:
        Metrics dict matching the shape of the COMET-based path.
    """
    fast_result = get_confidence_scorer().score_fast(source, translation, direction)
    metrics = {
        "confidence": round(fast_result.overall_score * 100, 1),
        "quality_level": fast_result.quality_level,
        "needs_review": fast_result.human_review_recommended,
        "time_seconds": round(elapsed, 2),
        "sentences": num_sentences,
        "method": "lightweight",
    }
    return metrics
| # ============================================================================ | |
| # PDF Extraction | |
| # ============================================================================ | |
def extract_pdf_text(pdf_file):
    """
    Extract text from uploaded PDF file.

    Args:
        pdf_file: File-like object (e.g. werkzeug FileStorage) positioned
            at the start of a PDF document.

    Returns:
        Extracted text with surrounding whitespace stripped, or None if
        the PDF could not be read or parsed.
    """
    try:
        pdf_bytes = pdf_file.read()
        # Context manager guarantees the document handle is released even if
        # a page fails to render (the original only closed on the success
        # path, leaking the handle on error).
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            text = "".join(page.get_text() for page in doc)
        return text.strip()
    except Exception as e:
        logger.error(f"PDF extraction failed: {e}")
        return None
| # ============================================================================ | |
| # API Routes | |
| # ============================================================================ | |
def index():
    """Serve the main translation interface.

    Renders templates/index.html.

    NOTE(review): no @app.route("/") decorator is visible on this handler
    (or on the other handlers in this file) — confirm routes are registered
    elsewhere (e.g. app.add_url_rule) or restore the decorators.
    """
    return render_template("index.html")
def translate_endpoint():
    """
    Translation API endpoint.

    Request JSON:
        - text: str - Text to translate
        - direction: str - "en-km" or "km-en" (anything else = km-en)
    Response JSON:
        - success: bool
        - translation: str
        - metrics: dict with confidence scores

    NOTE(review): no @app.route decorator is visible on this handler —
    confirm the route is registered elsewhere or restore the decorator.
    """
    # silent=True: a missing body or wrong Content-Type yields None instead
    # of raising, so we can return the uniform JSON error envelope
    # (request.json would abort with an HTML error page).
    data = request.get_json(silent=True) or {}
    text = data.get("text", "")
    direction = data.get("direction", "en-km")
    # Reject empty input up front; previously this returned success=True
    # with the error buried inside the metrics dict.
    if not isinstance(text, str) or not text.strip():
        return jsonify({"success": False, "error": "No text to translate"})
    if direction == "en-km":
        src_lang, tgt_lang = "eng_Latn", "khm_Khmr"
    else:
        src_lang, tgt_lang = "khm_Khmr", "eng_Latn"
    try:
        result, metrics = translate_long(text, src_lang, tgt_lang)
        return jsonify({
            "success": True,
            "translation": result,
            "metrics": metrics
        })
    except Exception as e:
        logger.error(f"Translation failed: {e}")
        return jsonify({
            "success": False,
            "error": str(e)
        })
def upload_pdf():
    """
    PDF upload endpoint.

    Accepts multipart form with 'file' field.
    Returns extracted text as {"success": true, "text": ...}, or a
    {"success": false, "error": ...} envelope on any failure.
    """
    # Guard clauses: reject missing field, empty selection, wrong extension.
    uploaded = request.files.get('file')
    if uploaded is None:
        return jsonify({"success": False, "error": "No file uploaded"})
    if uploaded.filename == '':
        return jsonify({"success": False, "error": "No file selected"})
    if not uploaded.filename.lower().endswith('.pdf'):
        return jsonify({"success": False, "error": "Only PDF files supported"})
    extracted = extract_pdf_text(uploaded)
    if not extracted:
        return jsonify({"success": False, "error": "Could not extract text"})
    return jsonify({"success": True, "text": extracted})
def health_check():
    """Health probe: reports liveness plus model/scoring configuration."""
    payload = {
        "status": "healthy",
        "model": MODEL_ID,
        "comet_enabled": USE_COMET,
    }
    return jsonify(payload)
| # ============================================================================ | |
| # Main Entry Point | |
| # ============================================================================ | |
if __name__ == "__main__":
    # Container platforms (e.g. HF Spaces) inject PORT; 7860 is the default.
    port = int(os.getenv("PORT", 7860))
    app.run(host="0.0.0.0", port=port)