import torch
import tensorflow as tf
import numpy as np
import difflib
from flask import Flask, render_template, request, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM
app = Flask(__name__)
MODELS_CONFIG = {
    "correction": {
        "path": "yammdd/vietnamese-error-correction",
        "framework": "pt"
    },
    "diacritics": {
        "path": "yammdd/vietnamese-diacritic-restoration-v2",
        "framework": "tf"
    }
}
# Cache of loaded tokenizer/model pairs, keyed by mode.
loaded_models = {}
device_pt = "cuda" if torch.cuda.is_available() else "cpu"
# Load every configured model at startup; a failure in one mode must not
# block the other from being served.
for mode, config in MODELS_CONFIG.items():
    path = config["path"]
    fw = config["framework"]
    try:
        print(f"Loading {mode}...")
        tokenizer = AutoTokenizer.from_pretrained(path)
        if fw == "pt":
            model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device_pt)
        else:
            model = TFAutoModelForSeq2SeqLM.from_pretrained(path)
        loaded_models[mode] = {
            "tokenizer": tokenizer,
            "model": model,
            "framework": fw
        }
    except Exception as e:
        print(f"Failed to load {mode}: {e}")
def get_similarity(s1, s2):
    """Case-insensitive character-level similarity ratio in [0, 1]."""
    return difflib.SequenceMatcher(None, s1.lower(), s2.lower()).ratio()
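# For example, get_similarity("bit", "biết") ≈ 0.86, while completely
# dissimilar words score 0.0.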

def is_start_char_match(src, tgt):
    """Heuristic first-letter match, covering common Vietnamese texting
    shorthand for initials (f~ph, w~qu/ư, j~gi/d, z~d/r/v, k~c/q)."""
    if not src or not tgt: return False
    c1 = src[0].lower()
    c2 = tgt[0].lower()
    if c1 == c2: return True
    if c1 == 'f' and tgt.lower().startswith('ph'): return True
    if c1 == 'w' and (tgt.lower().startswith('qu') or c2 == 'ư'): return True
    if c1 == 'j' and (tgt.lower().startswith('gi') or c2 == 'd'): return True
    if c1 == 'z' and c2 in ['d', 'r', 'v']: return True
    if c1 == 'k' and c2 in ['c', 'q']: return True
    return False
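
# Illustrative cases for the shorthand initials handled above:
#   is_start_char_match("fai", "phải")  -> True   # 'f' ~ 'ph'
#   is_start_char_match("wa", "quá")    -> True   # 'w' ~ 'qu'
#   is_start_char_match("ko", "có")     -> True   # 'k' ~ 'c'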

def smart_alignment(source_words, target_words, target_confidences):
    """Align source words to groups of target words with a Needleman-Wunsch
    style DP, letting one source word absorb up to MAX_LOOKBACK consecutive
    target words (e.g. shorthand "ko" expanding to "không")."""
    n = len(source_words)
    m = len(target_words)
    MAX_LOOKBACK = 5
    dp = np.zeros((n + 1, m + 1))
    # Boundary gap penalties must match the -0.5 per-step gap cost used in
    # the recurrence below, or the traceback cannot recognize boundary moves.
    for i in range(n + 1): dp[i][0] = i * -0.5
    for j in range(m + 1): dp[0][j] = j * -0.5
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            src_word = source_words[i-1]
            best_score = dp[i-1][j] - 0.5       # delete src_word
            score_insert = dp[i][j-1] - 0.5     # insert a target word
            best_score = max(best_score, score_insert)
            # Try matching src_word against the last k target words.
            for k in range(1, min(j, MAX_LOOKBACK) + 1):
                segment_words = target_words[j-k : j]
                combined_tgt = " ".join(segment_words)
                sim = get_similarity(src_word, combined_tgt)
                group_bonus = 0.15 * k if k > 1 else 0
                start_char_bonus = 0.0
                if is_start_char_match(src_word, combined_tgt):
                    start_char_bonus = 0.5
                match_score = dp[i-1][j-k] + sim + group_bonus + start_char_bonus - 0.2
                if src_word.lower() == combined_tgt.lower():
                    match_score = dp[i-1][j-k] + 2.0    # exact match wins
                best_score = max(best_score, match_score)
            dp[i][j] = best_score
    # Traceback: replay the moves that produced dp[n][m], preferring the
    # longest group match that reproduces the cell's score.
    i, j = n, m
    aligned_results = []
    while i > 0 or j > 0:
        src_word = source_words[i-1] if i > 0 else ""
        current_score = dp[i][j]
        found_match = False
        max_k_check = min(j, MAX_LOOKBACK)
        if i > 0 and j > 0:
            for k in range(max_k_check, 0, -1):
                prev_score = dp[i-1][j-k]
                segment_words = target_words[j-k : j]
                combined_tgt = " ".join(segment_words)
                sim = get_similarity(src_word, combined_tgt)
                group_bonus = 0.15 * k if k > 1 else 0
                start_char_bonus = 0.0
                if is_start_char_match(src_word, combined_tgt):
                    start_char_bonus = 0.5
                match_score = prev_score + sim + group_bonus + start_char_bonus - 0.2
                if src_word.lower() == combined_tgt.lower():
                    match_score = prev_score + 2.0
                if abs(current_score - match_score) < 0.001:
                    confs = target_confidences[j-k : j]
                    avg_conf = sum(confs)/len(confs) if confs else 0.0
                    type_tag = 'equal' if (k == 1 and src_word.lower() == combined_tgt.lower()) else 'replace'
                    aligned_results.append({
                        "original": src_word,
                        "corrected": combined_tgt,
                        "confidence": avg_conf * 100,
                        "type": type_tag
                    })
                    i -= 1
                    j -= k
                    found_match = True
                    break
        if found_match:
            continue
        # Deletion: a source word with no counterpart in the output.
        del_score = dp[i-1][j] - 0.5 if i > 0 else -999
        if i > 0 and abs(current_score - del_score) < 0.001:
            aligned_results.append({
                "original": src_word,
                "corrected": "",
                "confidence": 0.0,
                "type": "delete"
            })
            i -= 1
            continue
        # Insertion: a target word the model added.
        tgt_word = target_words[j-1] if j > 0 else ""
        conf = target_confidences[j-1] if j > 0 else 0.0
        aligned_results.append({
            "original": "",
            "corrected": tgt_word,
            "confidence": conf * 100,
            "type": "insert"
        })
        j -= 1
    aligned_results.reverse()
    return aligned_results
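
# Minimal sanity check for the aligner (illustrative; the confidences are
# made-up values and the target words stand in for real model output):
#   smart_alignment(["ko", "bit"], ["không", "biết"], [0.95, 0.90])
# should yield two 'replace' entries pairing "ko" with "không" and
# "bit" with "biết".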

def process_with_confidence(text, mode):
    """Run the selected model on `text` and return the generated string plus
    a word-level alignment annotated with confidence scores."""
    if mode not in loaded_models:
        raise ValueError(f"Model {mode} not loaded.")
    m_info = loaded_models[mode]
    tokenizer = m_info["tokenizer"]
    model = m_info["model"]
    fw = m_info["framework"]
    # Generate with per-step scores so token probabilities can be recovered.
    if fw == "pt":
        inputs = tokenizer(text, return_tensors="pt").to(device_pt)
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, return_dict_in_generate=True, output_scores=True
            )
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        ).cpu().numpy()
        generated_tokens = outputs.sequences[0].cpu().numpy()
    else:
        inputs = tokenizer(text, return_tensors="tf")
        outputs = model.generate(
            **inputs, max_new_tokens=256, return_dict_in_generate=True, output_scores=True
        )
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        ).numpy()
        generated_tokens = outputs.sequences[0].numpy()
    # Trim leading and trailing special tokens. A None id in the set is
    # harmless, since generated ids are always ints.
    special_tokens = {tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id}
    start_index = 0
    while start_index < len(generated_tokens) and generated_tokens[start_index] in special_tokens:
        start_index += 1
    end_index = len(generated_tokens)
    for i in range(start_index, len(generated_tokens)):
        if generated_tokens[i] in special_tokens:
            end_index = i
            break
    output_ids = generated_tokens[start_index:end_index]
    full_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    target_words = full_text.split()
    if not target_words:
        return full_text, []
    # Map each token to the whitespace-separated word it lands in by decoding
    # the prefix up to that token. transition_scores[0][i] is assumed to line
    # up with output_ids[i], which holds when exactly one decoder start token
    # was trimmed above (the usual case for these seq2seq checkpoints).
    token_to_word_map = []
    for i, token_id in enumerate(output_ids):
        if i >= len(transition_scores[0]): break
        prob = np.exp(transition_scores[0][i])
        decoded_up_to_here = tokenizer.decode(output_ids[:i+1], skip_special_tokens=True)
        words_so_far = decoded_up_to_here.split()
        word_index = len(words_so_far) - 1 if words_so_far else 0
        token_to_word_map.append({'prob': prob, 'word_index': word_index})
    # Average the token probabilities within each word.
    word_confidences_map = {}
    for item in token_to_word_map:
        idx = item['word_index']
        if idx not in word_confidences_map: word_confidences_map[idx] = []
        word_confidences_map[idx].append(item['prob'])
    target_confidences = []
    for i in range(len(target_words)):
        if i in word_confidences_map:
            target_confidences.append(float(np.mean(word_confidences_map[i])))
        else:
            target_confidences.append(0.0)
    input_words = text.split()
    aligned_data = smart_alignment(input_words, target_words, target_confidences)
    return full_text, aligned_data
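
# Expected return shape (illustrative output and confidence values, assuming
# the diacritics model restores "toi di hoc" to "tôi đi học"):
#   process_with_confidence("toi di hoc", "diacritics")
#   -> ("tôi đi học", [{"original": "toi", "corrected": "tôi",
#                       "confidence": 97.1, "type": "replace"}, ...])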

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/correct', methods=['POST'])
def correct_text():
    data = request.get_json()
    input_text = data.get('text', '')
    mode = data.get('mode', 'correction')
    if not input_text.strip():
        return jsonify({"result": "", "alignment": []})
    try:
        generated_text, aligned_data = process_with_confidence(input_text, mode)
        return jsonify({
            "result": generated_text,
            "alignment": aligned_data
        })
    except Exception as e:
        print(f"Error: {e}")
        return jsonify({"error": str(e)}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=False)
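
# Example request against the /correct endpoint (assuming a local run on the
# default port above):
#   curl -X POST http://localhost:7860/correct \
#        -H "Content-Type: application/json" \
#        -d '{"text": "toi ko bit", "mode": "correction"}'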