"""Flask service exposing Vietnamese error-correction and diacritic-restoration models.

Each request is run through a seq2seq model; the generated words are then aligned
back to the input words (see ``smart_alignment``) and annotated with per-word
confidences derived from the model's token log-probabilities.
"""
import torch
import tensorflow as tf
import numpy as np
import difflib
from flask import Flask, render_template, request, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM

app = Flask(__name__)

# mode name -> Hugging Face model path and the framework ("pt" or "tf") used to load it.
MODELS_CONFIG = {
    "correction": {
        "path": "yammdd/vietnamese-error-correction",
        "framework": "pt",
    },
    "diacritics": {
        "path": "yammdd/vietnamese-diacritic-restoration-v2",
        "framework": "tf",
    },
}

# mode name -> {"tokenizer", "model", "framework"}; only modes that loaded successfully appear.
loaded_models = {}
device_pt = "cuda" if torch.cuda.is_available() else "cpu"

# Load every configured model at import time. A failure is reported and skipped
# (best effort) so the remaining modes stay usable; requests for a missing mode
# are rejected by process_with_confidence.
for mode, config in MODELS_CONFIG.items():
    path = config["path"]
    fw = config["framework"]
    try:
        print(f"Loading {mode}...")
        tokenizer = AutoTokenizer.from_pretrained(path)
        if fw == "pt":
            model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device_pt)
        else:
            model = TFAutoModelForSeq2SeqLM.from_pretrained(path)
        loaded_models[mode] = {
            "tokenizer": tokenizer,
            "model": model,
            "framework": fw,
        }
    except Exception as e:
        print(f"Failed to load {mode}: {e}")


# Alignment scoring constants (values unchanged from the original inline literals).
_GAP_PENALTY = 0.5           # deleting a source word / inserting a target word inside the table
_BORDER_GAP_PENALTY = 1.0    # per-step cost along the first row/column of the DP table
_EXACT_MATCH_SCORE = 2.0     # case-insensitive exact match of a source word and a target segment
_GROUP_BONUS_PER_WORD = 0.15  # bonus per target word when a segment spans more than one word
_START_CHAR_BONUS = 0.5      # bonus when is_start_char_match() accepts the pair
_MATCH_BASE_PENALTY = 0.2    # subtracted from every non-exact segment match
_SCORE_TOLERANCE = 0.001     # float tolerance used when retracing the DP decisions


def get_similarity(s1, s2):
    """Return the case-insensitive difflib similarity ratio (0.0-1.0) of two strings."""
    return difflib.SequenceMatcher(None, s1.lower(), s2.lower()).ratio()


def is_start_char_match(src, tgt):
    """Return True if ``src`` and ``tgt`` start with the same or an accepted substitute onset.

    Besides an identical first letter (case-insensitive), a fixed set of
    substitutions is accepted: f->ph, w->qu/ư, j->gi/d, z->d/r/v, k->c/q.
    Empty strings never match.
    """
    if not src or not tgt:
        return False
    c1 = src[0].lower()
    c2 = tgt[0].lower()
    tgt_lower = tgt.lower()
    if c1 == c2:
        return True
    if c1 == 'f' and tgt_lower.startswith('ph'):
        return True
    if c1 == 'w' and (tgt_lower.startswith('qu') or c2 == 'ư'):
        return True
    if c1 == 'j' and (tgt_lower.startswith('gi') or c2 == 'd'):
        return True
    if c1 == 'z' and c2 in ('d', 'r', 'v'):
        return True
    if c1 == 'k' and c2 in ('c', 'q'):
        return True
    return False


def _segment_score(prev_score, src_word, combined_tgt, k):
    """Score of aligning ``src_word`` to the ``k``-word segment ``combined_tgt``, on top of ``prev_score``.

    Shared by the forward DP pass and the backtrack so both always agree.
    """
    if src_word.lower() == combined_tgt.lower():
        return prev_score + _EXACT_MATCH_SCORE
    sim = get_similarity(src_word, combined_tgt)
    group_bonus = _GROUP_BONUS_PER_WORD * k if k > 1 else 0
    start_char_bonus = _START_CHAR_BONUS if is_start_char_match(src_word, combined_tgt) else 0.0
    return prev_score + sim + group_bonus + start_char_bonus - _MATCH_BASE_PENALTY


def smart_alignment(source_words, target_words, target_confidences, max_lookback=5):
    """Align input words to generated words; one source word may map to a run of target words.

    Args:
        source_words: words of the user input.
        target_words: words of the model output.
        target_confidences: one probability per entry of ``target_words``.
        max_lookback: longest run of target words a single source word may map to.

    Returns:
        A list of dicts in left-to-right order with keys "original", "corrected",
        "confidence" (the target probability scaled by 100) and "type"
        ("equal", "replace", "delete" or "insert").
    """
    n = len(source_words)
    m = len(target_words)

    # dp[i][j] = best score aligning source_words[:i] with target_words[:j].
    dp = np.zeros((n + 1, m + 1))
    for i in range(n + 1):
        dp[i][0] = i * -_BORDER_GAP_PENALTY
    for j in range(m + 1):
        dp[0][j] = j * -_BORDER_GAP_PENALTY

    for i in range(1, n + 1):
        src_word = source_words[i - 1]
        for j in range(1, m + 1):
            # Delete source word, then insert target word, then every segment match.
            best_score = max(dp[i - 1][j] - _GAP_PENALTY, dp[i][j - 1] - _GAP_PENALTY)
            for k in range(1, min(j, max_lookback) + 1):
                combined_tgt = " ".join(target_words[j - k:j])
                best_score = max(best_score, _segment_score(dp[i - 1][j - k], src_word, combined_tgt, k))
            dp[i][j] = best_score

    aligned_results = []
    i, j = n, m
    while i > 0 or j > 0:
        if j == 0:
            # Target exhausted: the remaining source words were dropped. This must be
            # handled explicitly: the border cost (_BORDER_GAP_PENALTY) differs from
            # _GAP_PENALTY, so the generic delete check below never matches on column 0,
            # and falling through to "insert" would drive j negative and index garbage.
            aligned_results.append({
                "original": source_words[i - 1],
                "corrected": "",
                "confidence": 0.0,
                "type": "delete",
            })
            i -= 1
            continue

        if i == 0:
            # Source exhausted: the remaining target words were inserted by the model.
            aligned_results.append({
                "original": "",
                "corrected": target_words[j - 1],
                "confidence": target_confidences[j - 1] * 100,
                "type": "insert",
            })
            j -= 1
            continue

        src_word = source_words[i - 1]
        current_score = dp[i][j]

        # Retrace preference: segment match (longest first), then delete, then insert.
        matched = False
        for k in range(min(j, max_lookback), 0, -1):
            combined_tgt = " ".join(target_words[j - k:j])
            match_score = _segment_score(dp[i - 1][j - k], src_word, combined_tgt, k)
            if abs(current_score - match_score) < _SCORE_TOLERANCE:
                confs = target_confidences[j - k:j]
                avg_conf = sum(confs) / len(confs) if confs else 0.0
                is_equal = k == 1 and src_word.lower() == combined_tgt.lower()
                aligned_results.append({
                    "original": src_word,
                    "corrected": combined_tgt,
                    "confidence": avg_conf * 100,
                    "type": 'equal' if is_equal else 'replace',
                })
                i -= 1
                j -= k
                matched = True
                break
        if matched:
            continue

        if abs(current_score - (dp[i - 1][j] - _GAP_PENALTY)) < _SCORE_TOLERANCE:
            aligned_results.append({
                "original": src_word,
                "corrected": "",
                "confidence": 0.0,
                "type": "delete",
            })
            i -= 1
            continue

        aligned_results.append({
            "original": "",
            "corrected": target_words[j - 1],
            "confidence": target_confidences[j - 1] * 100,
            "type": "insert",
        })
        j -= 1

    aligned_results.reverse()
    return aligned_results


def _generate_with_scores(tokenizer, model, fw, text):
    """Generate for ``text`` and return (sequence token ids, per-step log-probabilities) for batch item 0."""
    generate_kwargs = {
        "max_new_tokens": 256,
        "return_dict_in_generate": True,
        "output_scores": True,
    }
    if fw == "pt":
        inputs = tokenizer(text, return_tensors="pt").to(device_pt)
        with torch.no_grad():
            outputs = model.generate(**inputs, **generate_kwargs)
            transition_scores = model.compute_transition_scores(
                outputs.sequences, outputs.scores, normalize_logits=True
            ).cpu().numpy()
        generated_tokens = outputs.sequences[0].cpu().numpy()
    else:
        inputs = tokenizer(text, return_tensors="tf")
        outputs = model.generate(**inputs, **generate_kwargs)
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        ).numpy()
        generated_tokens = outputs.sequences[0].numpy()
    return generated_tokens, transition_scores[0]


def _word_confidences(tokenizer, output_ids, token_log_probs, first_score_index, num_words):
    """Average token probabilities per output word.

    ``token_log_probs[first_score_index + i]`` is the log-probability of
    ``output_ids[i]``; tokens without a score are skipped. A token is credited to
    the last word of the text decoded up to and including it.
    """
    probs_by_word = {}
    for i in range(len(output_ids)):
        score_index = first_score_index + i
        if score_index < 0:
            continue
        if score_index >= len(token_log_probs):
            break
        prob = np.exp(token_log_probs[score_index])
        words_so_far = tokenizer.decode(output_ids[:i + 1], skip_special_tokens=True).split()
        word_index = len(words_so_far) - 1 if words_so_far else 0
        probs_by_word.setdefault(word_index, []).append(prob)
    return [
        float(np.mean(probs_by_word[w])) if w in probs_by_word else 0.0
        for w in range(num_words)
    ]


def process_with_confidence(text, mode):
    """Run the model for ``mode`` on ``text`` and align its output to the input words.

    Returns:
        (full_text, aligned_data): the decoded model output and the list produced by
        ``smart_alignment`` (empty when the output has no words).

    Raises:
        ValueError: if no model was loaded for ``mode``.
    """
    if mode not in loaded_models:
        raise ValueError(f"Model {mode} not loaded.")
    m_info = loaded_models[mode]
    tokenizer = m_info["tokenizer"]
    model = m_info["model"]
    fw = m_info["framework"]

    generated_tokens, token_log_probs = _generate_with_scores(tokenizer, model, fw, text)

    special_tokens = {
        tok for tok in (tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id)
        if tok is not None
    }
    # Strip leading special tokens, then cut at the first special token after them.
    start_index = 0
    while start_index < len(generated_tokens) and generated_tokens[start_index] in special_tokens:
        start_index += 1
    end_index = len(generated_tokens)
    for pos in range(start_index, len(generated_tokens)):
        if generated_tokens[pos] in special_tokens:
            end_index = pos
            break

    output_ids = generated_tokens[start_index:end_index]
    full_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    target_words = full_text.split()
    if not target_words:
        return full_text, []

    # The returned sequence is longer than the score array by its unscored prefix,
    # so generated_tokens[p] is scored by token_log_probs[p - score_offset]. Using the
    # stripped position directly is only right when exactly one token was stripped.
    score_offset = len(generated_tokens) - len(token_log_probs)
    target_confidences = _word_confidences(
        tokenizer, output_ids, token_log_probs, start_index - score_offset, len(target_words)
    )

    input_words = text.split()
    aligned_data = smart_alignment(input_words, target_words, target_confidences)
    return full_text, aligned_data


@app.route('/')
def index():
    """Serve the UI page."""
    return render_template('index.html')


@app.route('/correct', methods=['POST'])
def correct_text():
    """Correct ``text`` from a JSON body ``{"text": ..., "mode": ...}``.

    Responds with ``{"result", "alignment"}``; 400 for a malformed body and 500
    (with ``{"error"}``) when processing fails.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising
    # outside the try block below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    input_text = data.get('text', '')
    mode = data.get('mode', 'correction')
    if not isinstance(input_text, str):
        return jsonify({"error": "'text' must be a string."}), 400
    if not input_text.strip():
        return jsonify({"result": "", "alignment": []})
    try:
        generated_text, aligned_data = process_with_confidence(input_text, mode)
        return jsonify({
            "result": generated_text,
            "alignment": aligned_data
        })
    except Exception as e:
        print(f"Error: {e}")
        return jsonify({"error": str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=False)