yammdd committed on
Commit e2c3629 · verified · 1 Parent(s): b004c22

Update app.py

Files changed (1)
  1. app.py +158 -152
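
Note: the functional change in this commit is that both model paths in MODELS_CONFIG now carry the yammdd/ namespace (the dict is also reflowed over multiple lines). The namespace matters because from_pretrained resolves a local directory first and otherwise treats the string as a Hugging Face Hub repo id, and a user-owned repo must be addressed as namespace/name. A minimal sketch of the distinction, assuming the yammdd/vietnamese-error-correction repo is publicly readable:

# A bare "vietnamese-error-correction" would only load from a local folder of
# that name; the namespaced id below resolves on the Hugging Face Hub.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("yammdd/vietnamese-error-correction")
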
app.py CHANGED
@@ -1,153 +1,159 @@
-from flask import Flask, render_template, request, jsonify
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM
-import torch
-import tensorflow as tf
-import numpy as np
-import os
-
-app = Flask(__name__)
-
-MODELS_CONFIG = {
-    "correction": {"path": "vietnamese-error-correction", "framework": "pt"},
-    "diacritics": {"path": "vietnamese-diacritic-restoration-v2", "framework": "tf"}
-}
-
-loaded_models = {}
-
-print("Initializing models...")
-device_pt = "cuda" if torch.cuda.is_available() else "cpu"
-
-for mode, config in MODELS_CONFIG.items():
-    path = config["path"]
-    fw = config["framework"]
-    try:
-        print(f"Loading model {mode} ({fw}) from {path}...")
-        tokenizer = AutoTokenizer.from_pretrained(path)
-
-        if fw == "pt":
-            model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device_pt)
-        else:
-            model = TFAutoModelForSeq2SeqLM.from_pretrained(path)
-
-        loaded_models[mode] = {
-            "tokenizer": tokenizer,
-            "model": model,
-            "framework": fw
-        }
-        print(f"Model {mode} is ready!")
-    except Exception as e:
-        print(f"Error loading model {mode}: {e}")
-
-def process_with_confidence(text, mode):
-    if mode not in loaded_models:
-        raise ValueError(f"Model {mode} has not been loaded.")
-
-    m_info = loaded_models[mode]
-    tokenizer = m_info["tokenizer"]
-    model = m_info["model"]
-    fw = m_info["framework"]
-
-    if fw == "pt":
-        inputs = tokenizer(text, return_tensors="pt").to(device_pt)
-    else:
-        inputs = tokenizer(text, return_tensors="tf")
-
-    if fw == "pt":
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=256,
-                return_dict_in_generate=True,
-                output_scores=True
-            )
-        transition_scores = model.compute_transition_scores(
-            outputs.sequences, outputs.scores, normalize_logits=True
-        )
-        transition_scores = transition_scores.cpu().numpy()
-        generated_tokens = outputs.sequences[0].cpu().numpy()
-    else:
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=256,
-            return_dict_in_generate=True,
-            output_scores=True
-        )
-        transition_scores = model.compute_transition_scores(
-            outputs.sequences, outputs.scores, normalize_logits=True
-        )
-        transition_scores = transition_scores.numpy()
-        generated_tokens = outputs.sequences[0].numpy()
-
-    special_tokens = {tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id}
-
-    start_index = 0
-    while start_index < len(generated_tokens) and generated_tokens[start_index] in special_tokens:
-        start_index += 1
-
-    end_index = len(generated_tokens)
-    for i in range(start_index, len(generated_tokens)):
-        if generated_tokens[i] in special_tokens:
-            end_index = i
-            break
-
-    output_ids = generated_tokens[start_index:end_index]
-    full_text = tokenizer.decode(output_ids, skip_special_tokens=True)
-    target_words = full_text.split()
-
-    if not target_words:
-        return full_text, []
-
-    token_to_word_map = []
-    for i, token_id in enumerate(output_ids):
-        if i >= len(transition_scores[0]): break
-
-        log_prob = transition_scores[0][i]
-        prob = np.exp(log_prob)
-
-        decoded_up_to_here = tokenizer.decode(output_ids[:i+1], skip_special_tokens=True)
-        words_so_far = decoded_up_to_here.split()
-        word_index = len(words_so_far) - 1 if words_so_far else 0
-
-        token_to_word_map.append({'prob': prob, 'word_index': word_index})
-
-    word_confidences = {}
-    for item in token_to_word_map:
-        idx = item['word_index']
-        if idx not in word_confidences: word_confidences[idx] = []
-        word_confidences[idx].append(item['prob'])
-
-    confidence_list = []
-    for i in range(len(target_words)):
-        if i in word_confidences:
-            probs = word_confidences[i]
-            confidence_list.append(float(np.mean(probs)))
-        else:
-            confidence_list.append(0.0)
-
-    return full_text, confidence_list
-
-@app.route('/')
-def index():
-    return render_template('index.html')
-
-@app.route('/correct', methods=['POST'])
-def correct_text():
-    data = request.get_json()
-    input_text = data.get('text', '')
-    mode = data.get('mode', 'correction')
-
-    if not input_text.strip():
-        return jsonify({"result": "", "confidences": []})
-
-    try:
-        generated_text, confidences = process_with_confidence(input_text, mode)
-        return jsonify({
-            "result": generated_text,
-            "confidences": confidences
-        })
-    except Exception as e:
-        print(f"Error: {e}")
-        return jsonify({"error": str(e)}), 500
-
-if __name__ == '__main__':
+from flask import Flask, render_template, request, jsonify
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM
+import torch
+import tensorflow as tf
+import numpy as np
+import os
+
+app = Flask(__name__)
+
+MODELS_CONFIG = {
+    "correction": {
+        "path": "yammdd/vietnamese-error-correction",
+        "framework": "pt"
+    },
+    "diacritics": {
+        "path": "yammdd/vietnamese-diacritic-restoration-v2",
+        "framework": "tf"
+    }
+}
+
+loaded_models = {}
+
+print("Initializing models...")
+device_pt = "cuda" if torch.cuda.is_available() else "cpu"
+
+for mode, config in MODELS_CONFIG.items():
+    path = config["path"]
+    fw = config["framework"]
+    try:
+        print(f"Loading model {mode} ({fw}) from {path}...")
+        tokenizer = AutoTokenizer.from_pretrained(path)
+
+        if fw == "pt":
+            model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device_pt)
+        else:
+            model = TFAutoModelForSeq2SeqLM.from_pretrained(path)
+
+        loaded_models[mode] = {
+            "tokenizer": tokenizer,
+            "model": model,
+            "framework": fw
+        }
+        print(f"Model {mode} is ready!")
+    except Exception as e:
+        print(f"Error loading model {mode}: {e}")
+
+def process_with_confidence(text, mode):
+    if mode not in loaded_models:
+        raise ValueError(f"Model {mode} has not been loaded.")
+
+    m_info = loaded_models[mode]
+    tokenizer = m_info["tokenizer"]
+    model = m_info["model"]
+    fw = m_info["framework"]
+
+    if fw == "pt":
+        inputs = tokenizer(text, return_tensors="pt").to(device_pt)
+    else:
+        inputs = tokenizer(text, return_tensors="tf")
+
+    if fw == "pt":
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=256,
+                return_dict_in_generate=True,
+                output_scores=True
+            )
+        transition_scores = model.compute_transition_scores(
+            outputs.sequences, outputs.scores, normalize_logits=True
+        )
+        transition_scores = transition_scores.cpu().numpy()
+        generated_tokens = outputs.sequences[0].cpu().numpy()
+    else:
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=256,
+            return_dict_in_generate=True,
+            output_scores=True
+        )
+        transition_scores = model.compute_transition_scores(
+            outputs.sequences, outputs.scores, normalize_logits=True
+        )
+        transition_scores = transition_scores.numpy()
+        generated_tokens = outputs.sequences[0].numpy()
+
+    special_tokens = {tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id}
+
+    start_index = 0
+    while start_index < len(generated_tokens) and generated_tokens[start_index] in special_tokens:
+        start_index += 1
+
+    end_index = len(generated_tokens)
+    for i in range(start_index, len(generated_tokens)):
+        if generated_tokens[i] in special_tokens:
+            end_index = i
+            break
+
+    output_ids = generated_tokens[start_index:end_index]
+    full_text = tokenizer.decode(output_ids, skip_special_tokens=True)
+    target_words = full_text.split()
+
+    if not target_words:
+        return full_text, []
+
+    token_to_word_map = []
+    for i, token_id in enumerate(output_ids):
+        if i >= len(transition_scores[0]): break
+
+        log_prob = transition_scores[0][i]
+        prob = np.exp(log_prob)
+
+        decoded_up_to_here = tokenizer.decode(output_ids[:i+1], skip_special_tokens=True)
+        words_so_far = decoded_up_to_here.split()
+        word_index = len(words_so_far) - 1 if words_so_far else 0
+
+        token_to_word_map.append({'prob': prob, 'word_index': word_index})
+
+    word_confidences = {}
+    for item in token_to_word_map:
+        idx = item['word_index']
+        if idx not in word_confidences: word_confidences[idx] = []
+        word_confidences[idx].append(item['prob'])
+
+    confidence_list = []
+    for i in range(len(target_words)):
+        if i in word_confidences:
+            probs = word_confidences[i]
+            confidence_list.append(float(np.mean(probs)))
+        else:
+            confidence_list.append(0.0)
+
+    return full_text, confidence_list
+
+@app.route('/')
+def index():
+    return render_template('index.html')
+
+@app.route('/correct', methods=['POST'])
+def correct_text():
+    data = request.get_json()
+    input_text = data.get('text', '')
+    mode = data.get('mode', 'correction')
+
+    if not input_text.strip():
+        return jsonify({"result": "", "confidences": []})
+
+    try:
+        generated_text, confidences = process_with_confidence(input_text, mode)
+        return jsonify({
+            "result": generated_text,
+            "confidences": confidences
+        })
+    except Exception as e:
+        print(f"Error: {e}")
+        return jsonify({"error": str(e)}), 500
+
+if __name__ == '__main__':
     app.run(host='0.0.0.0', port=7860, debug=False)
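
With the Space running, the /correct route can be exercised directly. Below is a minimal client sketch, assuming a local run on the port configured above; the sample sentence and the requests dependency are illustrative, while the route, the text and mode request fields, and the result and confidences response fields all come from the handler in this file.

# Minimal client sketch for the /correct endpoint (assumes app.py is running
# locally on port 7860, as configured above).
import requests

resp = requests.post(
    "http://localhost:7860/correct",
    json={"text": "toi di hoc", "mode": "diacritics"},  # hypothetical sample input
    timeout=120,
)
resp.raise_for_status()
payload = resp.json()

# process_with_confidence returns one value per whitespace-separated word:
# the mean of the exponentiated transition scores of that word's tokens.
for word, conf in zip(payload["result"].split(), payload["confidences"]):
    print(f"{word}\t{conf:.2f}")

Since the confidences align one-to-one with result.split(), a front end can use the low-scoring words as the ones to highlight.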