# Flask app serving two Vietnamese seq2seq models, one for error correction
# (PyTorch) and one for diacritic restoration (TensorFlow). Responses include the
# generated text plus a word-level alignment against the input with confidences.
import torch
import tensorflow as tf
import numpy as np
import difflib
from flask import Flask, render_template, request, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM

app = Flask(__name__)

# Model registry: each mode maps to a Hugging Face checkpoint and the framework
# ("pt" = PyTorch, "tf" = TensorFlow) used to load it.
MODELS_CONFIG = {
    "correction": {
        "path": "yammdd/vietnamese-error-correction",
        "framework": "pt"
    },
    "diacritics": {
        "path": "yammdd/vietnamese-diacritic-restoration-v2",
        "framework": "tf"
    }
}

loaded_models = {}
device_pt = "cuda" if torch.cuda.is_available() else "cpu"

# Load every configured model once at startup so request handlers only run inference.
for mode, config in MODELS_CONFIG.items():
    path = config["path"]
    fw = config["framework"]
    try:
        print(f"Loading {mode}...")
        tokenizer = AutoTokenizer.from_pretrained(path)
        if fw == "pt":
            model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device_pt)
        else:
            model = TFAutoModelForSeq2SeqLM.from_pretrained(path)
        loaded_models[mode] = {
            "tokenizer": tokenizer,
            "model": model,
            "framework": fw
        }
    except Exception as e:
        print(f"Failed to load {mode}: {e}")


def get_similarity(s1, s2):
    # Case-insensitive similarity ratio (0.0 to 1.0) between two strings.
    return difflib.SequenceMatcher(None, s1.lower(), s2.lower()).ratio()


def is_start_char_match(src, tgt):
    # Heuristic: do the two words plausibly start with the "same" sound? Besides an
    # exact first-letter match, this accepts common Vietnamese shorthand spellings.
    if not src or not tgt: return False
    c1 = src[0].lower()
    c2 = tgt[0].lower()

    if c1 == c2: return True

    if c1 == 'f' and tgt.lower().startswith('ph'): return True
    if c1 == 'w' and (tgt.lower().startswith('qu') or c2 == 'ư'): return True
    if c1 == 'j' and (tgt.lower().startswith('gi') or c2 == 'd'): return True
    if c1 == 'z' and c2 in ['d', 'r', 'v']: return True
    if c1 == 'k' and c2 in ['c', 'q']: return True

    return False
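
# Illustrative checks, following directly from the rules above:
#   is_start_char_match("fong", "phong") -> True  (f ~ ph)
#   is_start_char_match("wa", "quá")     -> True  (w ~ qu)
#   is_start_char_match("zui", "vui")    -> True  (z ~ v)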

def smart_alignment(source_words, target_words, target_confidences):
    # Align source words to target words with a Needleman-Wunsch-style DP, allowing
    # one source word to match a group of up to MAX_LOOKBACK consecutive target words.
    n = len(source_words)
    m = len(target_words)

    MAX_LOOKBACK = 5

    dp = np.zeros((n + 1, m + 1))

    # Gap penalty on the borders matches the -0.5 used in the recurrence below,
    # so the traceback can reconstruct boundary deletions/insertions exactly.
    for i in range(n + 1): dp[i][0] = i * -0.5
    for j in range(m + 1): dp[0][j] = j * -0.5

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            src_word = source_words[i-1]

            # Delete the source word (no target counterpart).
            best_score = dp[i-1][j] - 0.5

            # Insert a target word (no source counterpart).
            score_insert = dp[i][j-1] - 0.5
            best_score = max(best_score, score_insert)

            # Match the source word against the last k target words (1..MAX_LOOKBACK).
            for k in range(1, min(j, MAX_LOOKBACK) + 1):
                segment_words = target_words[j-k : j]
                combined_tgt = " ".join(segment_words)

                sim = get_similarity(src_word, combined_tgt)

                # Small bonus for grouping several target words under one source word.
                group_bonus = 0.15 * k if k > 1 else 0

                start_char_bonus = 0.0
                if is_start_char_match(src_word, combined_tgt):
                    start_char_bonus = 0.5

                match_score = dp[i-1][j-k] + sim + group_bonus + start_char_bonus - 0.2

                # Exact (case-insensitive) matches get a fixed high reward.
                if src_word.lower() == combined_tgt.lower():
                    match_score = dp[i-1][j-k] + 2.0

                best_score = max(best_score, match_score)

            dp[i][j] = best_score

    # Traceback: rebuild the sequence of equal/replace/delete/insert operations.
    i, j = n, m
    aligned_results = []

    while i > 0 or j > 0:
        src_word = source_words[i-1] if i > 0 else ""
        current_score = dp[i][j]

        found_match = False

        max_k_check = min(j, MAX_LOOKBACK)
        if i > 0 and j > 0:
            for k in range(max_k_check, 0, -1):
                prev_score = dp[i-1][j-k]
                segment_words = target_words[j-k : j]
                combined_tgt = " ".join(segment_words)

                sim = get_similarity(src_word, combined_tgt)
                group_bonus = 0.15 * k if k > 1 else 0

                start_char_bonus = 0.0
                if is_start_char_match(src_word, combined_tgt):
                    start_char_bonus = 0.5

                match_score = prev_score + sim + group_bonus + start_char_bonus - 0.2

                if src_word.lower() == combined_tgt.lower():
                    match_score = prev_score + 2.0

                # This cell was produced by a match/replace move of width k.
                if abs(current_score - match_score) < 0.001:
                    confs = target_confidences[j-k : j]
                    avg_conf = sum(confs)/len(confs) if confs else 0.0

                    type_tag = 'equal' if (k == 1 and src_word.lower() == combined_tgt.lower()) else 'replace'

                    aligned_results.append({
                        "original": src_word,
                        "corrected": combined_tgt,
                        "confidence": avg_conf * 100,
                        "type": type_tag
                    })
                    i -= 1
                    j -= k
                    found_match = True
                    break

        if found_match:
            continue

        # Deletion: the source word has no counterpart in the target.
        del_score = dp[i-1][j] - 0.5 if i > 0 else -999
        if i > 0 and abs(current_score - del_score) < 0.001:
            aligned_results.append({
                "original": src_word,
                "corrected": "",
                "confidence": 0.0,
                "type": "delete"
            })
            i -= 1
            continue

        # Otherwise treat the step as an insertion of a target word.
        tgt_word = target_words[j-1] if j > 0 else ""
        conf = target_confidences[j-1] if j > 0 else 0.0
        aligned_results.append({
            "original": "",
            "corrected": tgt_word,
            "confidence": conf * 100,
            "type": "insert"
        })
        j -= 1

    # The traceback runs backwards, so restore left-to-right order.
    aligned_results.reverse()
    return aligned_results
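
# Note: each entry of the returned list has the shape
#   {"original": <source word or "">, "corrected": <target word(s) or "">,
#    "confidence": <float 0-100>, "type": "equal" | "replace" | "delete" | "insert"}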

def process_with_confidence(text, mode):
    if mode not in loaded_models:
        raise ValueError(f"Model {mode} not loaded.")

    m_info = loaded_models[mode]
    tokenizer = m_info["tokenizer"]
    model = m_info["model"]
    fw = m_info["framework"]

    # Generate with per-step scores so token-level confidences can be recovered.
    if fw == "pt":
        inputs = tokenizer(text, return_tensors="pt").to(device_pt)
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, return_dict_in_generate=True, output_scores=True
            )
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        ).cpu().numpy()
        generated_tokens = outputs.sequences[0].cpu().numpy()
    else:
        inputs = tokenizer(text, return_tensors="tf")
        outputs = model.generate(
            **inputs, max_new_tokens=256, return_dict_in_generate=True, output_scores=True
        )
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        ).numpy()
        generated_tokens = outputs.sequences[0].numpy()

    # Trim leading special tokens (BOS/EOS/PAD) and stop at the first trailing one.
    special_tokens = {tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id}
    start_index = 0
    while start_index < len(generated_tokens) and generated_tokens[start_index] in special_tokens:
        start_index += 1
    end_index = len(generated_tokens)
    for i in range(start_index, len(generated_tokens)):
        if generated_tokens[i] in special_tokens:
            end_index = i
            break

    output_ids = generated_tokens[start_index:end_index]
    full_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    target_words = full_text.split()

    if not target_words:
        return full_text, []

    # Map each generated token to the word it ends up in, keeping its probability
    # (transition scores are log-probabilities, hence the exp).
    token_to_word_map = []
    for i, token_id in enumerate(output_ids):
        if i >= len(transition_scores[0]): break
        prob = np.exp(transition_scores[0][i])
        decoded_up_to_here = tokenizer.decode(output_ids[:i+1], skip_special_tokens=True)
        words_so_far = decoded_up_to_here.split()
        word_index = len(words_so_far) - 1 if words_so_far else 0
        token_to_word_map.append({'prob': prob, 'word_index': word_index})

    # Collect token probabilities per word index.
    word_confidences_map = {}
    for item in token_to_word_map:
        idx = item['word_index']
        if idx not in word_confidences_map: word_confidences_map[idx] = []
        word_confidences_map[idx].append(item['prob'])

    # Word confidence = mean probability of the tokens that form the word.
    target_confidences = []
    for i in range(len(target_words)):
        if i in word_confidences_map:
            target_confidences.append(float(np.mean(word_confidences_map[i])))
        else:
            target_confidences.append(0.0)

    input_words = text.split()

    aligned_data = smart_alignment(input_words, target_words, target_confidences)

    return full_text, aligned_data
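
# Illustrative call (assuming the "diacritics" model loaded successfully; the input
# string is made up and the exact output depends on the model):
#   restored, alignment = process_with_confidence("toi yeu viet nam", "diacritics")
#   # restored  -> the generated text
#   # alignment -> word-level list in the shape documented after smart_alignment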

@app.route('/')
def index():
    return render_template('index.html')


@app.route('/correct', methods=['POST'])
def correct_text():
    # Main API endpoint: expects JSON {"text": ..., "mode": "correction" | "diacritics"}.
    data = request.get_json()
    input_text = data.get('text', '')
    mode = data.get('mode', 'correction')

    if not input_text.strip():
        return jsonify({"result": "", "alignment": []})

    try:
        generated_text, aligned_data = process_with_confidence(input_text, mode)
        return jsonify({
            "result": generated_text,
            "alignment": aligned_data
        })
    except Exception as e:
        print(f"Error: {e}")
        return jsonify({"error": str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=False)
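
# Example request against a locally running instance (the input text here is only
# illustrative; the actual correction depends on the loaded model):
#   curl -X POST http://localhost:7860/correct \
#        -H "Content-Type: application/json" \
#        -d '{"text": "toi di hoc", "mode": "correction"}'
# The JSON response contains "result" (the generated text) and "alignment"
# (the word-level diff described above).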