jing-ju commited on
Commit
7a80146
·
verified ·
1 Parent(s): 9ecf72c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +313 -195
app.py CHANGED
@@ -1,32 +1,13 @@
1
- import os
2
- import math
3
- import re
4
- from typing import List, Optional
5
-
6
- import gradio as gr
7
  import torch
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
9
 
10
- # Import tương thích nhiều phiên bản:
11
- try:
12
- # Nhiều bản đặt ở đây
13
- from transformers.quantizers import CompressedTensorsQuantizationConfig
14
- except Exception:
15
- try:
16
- # Một số bản export ở root (phòng hờ)
17
- from transformers import CompressedTensorsQuantizationConfig # type: ignore
18
- except Exception:
19
- CompressedTensorsQuantizationConfig = None # sẽ fallback qua dict
20
-
21
-
22
- # =========================
23
- # CẤU HÌNH MẶC ĐỊNH
24
- # =========================
25
- # Model mặc định: nhẹ hơn và phù hợp hơn cho CPU Free
26
- DEFAULT_MODEL = "tencent/Hunyuan-MT-7B-fp8"
27
- MODEL_NAME = os.getenv("MODEL_NAME", DEFAULT_MODEL)
28
 
29
- # Tham số sinh gợi ý (giữ thấp để tránh quá tải CPU)
30
  GEN_KW = dict(
31
  max_new_tokens=256,
32
  top_k=20,
@@ -36,188 +17,325 @@ GEN_KW = dict(
36
  do_sample=True,
37
  )
38
 
39
- # Giới hạn token đầu vào mỗi lượt để tránh OOM/timeout trên CPU
40
- # (tổng input ≲ 900–1000 token trên CPU Free cho an toàn)
41
- MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "800"))
42
-
43
- # =========================
44
- # TẢI MODEL & TOKENIZER
45
- # =========================
46
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
47
-
48
- # Ghi đè config lượng tử hóa để tránh lỗi "ignore NoneType" trên một số bản fp8
49
- ctq = CompressedTensorsQuantizationConfig(
50
- quantization_method="fp8",
51
- ignore=[], # chìa khóa tránh TypeError: 'NoneType' object is not iterable
52
- )
53
-
54
- model = AutoModelForCausalLM.from_pretrained(
55
- MODEL_NAME,
56
- trust_remote_code=True,
57
- quantization_config=ctq,
58
- )
59
- DEVICE = getattr(model, "device", torch.device("cpu"))
60
-
61
- # =========================
62
- # TIỆN ÍCH CHUẨN HÓA NGÔN NGỮ
63
- # =========================
64
- # Map tên ngôn ngữ phổ biến -> tên tiếng Anh để nhúng vào prompt (đơn giản hóa)
65
- LANG_ALIASES = {
66
- # Vietnamese
67
- "vi": "Vietnamese", "vie": "Vietnamese",
68
- "vietnamese": "Vietnamese", "tiếng việt": "Vietnamese",
69
- # Chinese
70
- "zh": "Chinese", "chi": "Chinese", "zho": "Chinese",
71
- "chinese": "Chinese", "tiếng trung": "Chinese", "hán ngữ": "Chinese",
72
- "mandarin": "Chinese",
73
- # English
74
- "en": "English", "eng": "English", "tiếng anh": "English", "english": "English",
75
- # Japanese
76
- "ja": "Japanese", "jpn": "Japanese", "tiếng nhật": "Japanese", "japanese": "Japanese",
77
- # Korean
78
- "ko": "Korean", "kor": "Korean", "tiếng hàn": "Korean", "korean": "Korean",
79
- # French
80
- "fr": "French", "fra": "French", "fre": "French", "tiếng pháp": "French", "french": "French",
81
- # German
82
- "de": "German", "deu": "German", "ger": "German", "tiếng đức": "German", "german": "German",
83
- # Spanish
84
- "es": "Spanish", "spa": "Spanish", "tiếng tây ban nha": "Spanish", "spanish": "Spanish",
85
- # Thai
86
- "th": "Thai", "tha": "Thai", "tiếng thái": "Thai", "thai": "Thai",
87
- # Indonesian
88
- "id": "Indonesian", "ind": "Indonesian", "tiếng indonesia": "Indonesian", "indonesian": "Indonesian",
89
- # Malay
90
- "ms": "Malay", "msa": "Malay", "tiếng malaysia": "Malay", "malay": "Malay",
91
- # Portuguese
92
- "pt": "Portuguese", "por": "Portuguese", "tiếng bồ đào nha": "Portuguese", "portuguese": "Portuguese",
93
- # Russian
94
- "ru": "Russian", "rus": "Russian", "tiếng nga": "Russian", "russian": "Russian",
95
  }
96
 
97
- def normalize_lang_name(s: Optional[str]) -> Optional[str]:
98
- if not s:
99
- return None
100
- key = s.strip().lower()
101
- return LANG_ALIASES.get(key, s.strip())
102
 
103
- # =========================
104
- # CHIA ĐOẠN THEO TOKEN
105
- # =========================
106
- def chunk_text_by_tokens(text: str, max_tokens: int) -> List[str]:
107
- """
108
- Chia văn bản thành các đoạn dựa vào số token của tokenizer để tránh vượt ngưỡng input.
109
- Ưu tiên cắt theo dấu câu. Nếu đoạn vẫn dài, cắt tiếp theo token.
110
- """
111
- # Tách theo các dấu câu lớn trước
112
- rough_parts = re.split(r"(?<=[\.!?。!?])\s+", text.strip())
113
- chunks = []
114
- buf = ""
115
 
116
- def token_len(s: str) -> int:
117
- return tokenizer(s, add_special_tokens=False, return_length=True)["length"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- for part in rough_parts:
120
- candidate = (buf + " " + part).strip() if buf else part
121
- if token_len(candidate) <= max_tokens:
122
- buf = candidate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  else:
124
- if buf:
125
- chunks.append(buf)
126
- buf = ""
127
-
128
- # Nếu part tự thân đã quá dài, cắt tiếp theo token
129
- if token_len(part) <= max_tokens:
130
- buf = part
 
 
 
 
131
  else:
132
- # Cắt theo token “cứng”
133
- ids = tokenizer(part, add_special_tokens=False)["input_ids"]
134
- for i in range(0, len(ids), max_tokens):
135
- piece_ids = ids[i:i + max_tokens]
136
- piece = tokenizer.decode(piece_ids, skip_special_tokens=True)
137
- chunks.append(piece)
138
- buf = ""
139
-
140
- if buf:
141
- chunks.append(buf)
142
-
143
- # Loại bỏ rỗng
144
- return [c for c in chunks if c.strip()]
145
-
146
- # =========================
147
- # CORE TRANSLATION (SỬ DỤNG CHAT TEMPLATE)
148
- # =========================
149
- @torch.inference_mode()
150
- def translate_text(
151
- text: str,
152
- target_lang: str,
153
- source_lang: Optional[str] = None,
154
- ) -> str:
155
- target = normalize_lang_name(target_lang) or "Vietnamese"
156
- src = normalize_lang_name(source_lang)
157
 
158
- # Xây prompt: thể thêm nguồn nếu người dùng cung cấp, còn không để model tự đoán
159
- if src:
160
- sys_prompt = f"Translate the following segment from {src} into {target}, without additional explanation."
 
 
 
 
 
 
 
 
161
  else:
162
- sys_prompt = f"Translate the following segment into {target}, without additional explanation."
163
-
164
- pieces = chunk_text_by_tokens(text, MAX_INPUT_TOKENS)
165
- outputs = []
166
-
167
- for piece in pieces:
168
- messages = [{"role": "user", "content": f"{sys_prompt}\n\n{piece}"}]
169
- inputs = tokenizer.apply_chat_template(
170
- messages, tokenize=True, add_generation_prompt=False, return_tensors="pt"
 
 
 
 
 
 
 
 
 
 
171
  )
172
- out_ids = model.generate(inputs.to(DEVICE), **GEN_KW)
173
- out_text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
174
- outputs.append(out_text.strip())
175
-
176
- return "\n".join(outputs).strip()
177
-
178
- def translate_batch(
179
- texts: List[str],
180
- target_lang: str,
181
- source_lang: Optional[str] = None,
182
- ) -> List[str]:
183
- return [translate_text(t, target_lang, source_lang) for t in texts]
184
 
185
- # =========================
186
- # GRADIO UI + API
187
- # =========================
188
- LANG_CHOICES = sorted(list(set(LANG_ALIASES.values())))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
- with gr.Blocks() as demo:
191
- gr.Markdown(
192
- "## Hunyuan-MT (fp8) — Multilingual Translation (Trial on CPU)\n"
193
- "Bản HF Spaces Free (CPU) — tốc độ chậm, đã có chia đoạn tự động theo token."
194
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
- with gr.Tab("Single"):
197
- src = gr.Textbox(label="Văn bản nguồn", lines=10, placeholder="Dán văn bản cần dịch…")
198
- with gr.Row():
199
- src_lang = gr.Textbox(label="Ngôn ngữ nguồn (tùy chọn, ví dụ: Vietnamese/Chinese/English…)", placeholder="Để trống nếu không chắc")
200
- tgt_lang = gr.Dropdown(label="Ngôn ngữ đích", choices=LANG_CHOICES, value="Vietnamese")
201
- out = gr.Textbox(label="Bản dịch", lines=10)
202
- btn = gr.Button("Dịch")
203
- btn.click(fn=translate_text, inputs=[src, tgt_lang, src_lang], outputs=out, api_name="translate_text")
204
 
205
- with gr.Tab("Batch"):
206
- src_list = gr.Textbox(
207
- label="Danh sách câu (mỗi dòng 1 câu/đoạn ngắn)",
208
- lines=10,
209
- placeholder="Mỗi dòng một câu/đoạn…"
210
- )
211
- with gr.Row():
212
- src_lang_b = gr.Textbox(label="Ngôn ngữ nguồn (tuỳ chọn)", placeholder="Để trống nếu không chắc")
213
- tgt_lang_b = gr.Dropdown(label="Ngôn ngữ đích", choices=LANG_CHOICES, value="Vietnamese")
214
- out_list = gr.Textbox(label="Kết quả (mỗi dòng tương ứng 1 đầu vào)", lines=10)
215
- def _batch_wrapper(texts_raw: str, tgt: str, src_: Optional[str]):
216
- texts = [x for x in texts_raw.splitlines() if x.strip()]
217
- results = translate_batch(texts, tgt, src_)
218
- return "\n".join(results)
219
- btn_b = gr.Button("Dịch Batch")
220
- btn_b.click(fn=_batch_wrapper, inputs=[src_list, tgt_lang_b, src_lang_b], outputs=out_list, api_name="translate_batch")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
- # Giới hạn tải cho demo
223
- demo.queue(concurrency_count=1, max_size=2).launch()
 
 
 
 
 
 
 
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import gradio as gr
4
+ import re
5
 
6
+ # Environment variables
7
+ MODEL_NAME = os.getenv("MODEL_NAME", "tencent/Hunyuan-MT-7B-fp8")
8
+ MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "800"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Generation parameters optimized for CPU
11
  GEN_KW = dict(
12
  max_new_tokens=256,
13
  top_k=20,
 
17
  do_sample=True,
18
  )
19
 
20
+ # Language mapping for normalization
21
+ LANGUAGE_MAPPING = {
22
+ "vi": "Vietnamese",
23
+ "vietnamese": "Vietnamese",
24
+ "tiếng việt": "Vietnamese",
25
+ "zh": "Chinese",
26
+ "chinese": "Chinese",
27
+ "tiếng trung": "Chinese",
28
+ "中文": "Chinese",
29
+ "en": "English",
30
+ "english": "English",
31
+ "tiếng anh": "English",
32
+ "ja": "Japanese",
33
+ "japanese": "Japanese",
34
+ "tiếng nhật": "Japanese",
35
+ "日本語": "Japanese",
36
+ "ko": "Korean",
37
+ "korean": "Korean",
38
+ "tiếng hàn": "Korean",
39
+ "한국어": "Korean",
40
+ "fr": "French",
41
+ "french": "French",
42
+ "tiếng pháp": "French",
43
+ "de": "German",
44
+ "german": "German",
45
+ "tiếng đức": "German",
46
+ "es": "Spanish",
47
+ "spanish": "Spanish",
48
+ "tiếng tây ban nha": "Spanish",
49
+ "th": "Thai",
50
+ "thai": "Thai",
51
+ "tiếng thái": "Thai",
52
+ "id": "Indonesian",
53
+ "indonesian": "Indonesian",
54
+ "tiếng indonesia": "Indonesian",
55
+ "ms": "Malay",
56
+ "malay": "Malay",
57
+ "tiếng malaysia": "Malay",
58
+ "pt": "Portuguese",
59
+ "portuguese": "Portuguese",
60
+ "tiếng bồ đào nha": "Portuguese",
61
+ "ru": "Russian",
62
+ "russian": "Russian",
63
+ "tiếng nga": "Russian",
 
 
 
 
 
 
 
 
 
 
 
 
64
  }
65
 
66
+ SUPPORTED_LANGUAGES = [
67
+ "Vietnamese", "Chinese", "English", "Japanese", "Korean",
68
+ "French", "German", "Spanish", "Thai", "Indonesian",
69
+ "Malay", "Portuguese", "Russian"
70
+ ]
71
 
72
+ def normalize_language(lang):
73
+ """Normalize language name"""
74
+ if not lang:
75
+ return None
76
+ lang_lower = lang.strip().lower()
77
+ return LANGUAGE_MAPPING.get(lang_lower, lang.strip())
 
 
 
 
 
 
78
 
79
+ def load_model():
80
+ """Load model and tokenizer with fp8 quantization config"""
81
+ print(f"Loading model: {MODEL_NAME}")
82
+
83
+ # Load tokenizer
84
+ tokenizer = AutoTokenizer.from_pretrained(
85
+ MODEL_NAME,
86
+ trust_remote_code=True
87
+ )
88
+
89
+ # Create quantization config for fp8
90
+ try:
91
+ from transformers.quantizers import CompressedTensorsQuantizationConfig
92
+ quantization_config = CompressedTensorsQuantizationConfig(
93
+ quantization_method="fp8",
94
+ ignore=[]
95
+ )
96
+ except ImportError:
97
+ # Fallback to dict format
98
+ quantization_config = {
99
+ "quantization_method": "fp8",
100
+ "ignore": []
101
+ }
102
+
103
+ # Load model with quantization config
104
+ model = AutoModelForCausalLM.from_pretrained(
105
+ MODEL_NAME,
106
+ trust_remote_code=True,
107
+ quantization_config=quantization_config,
108
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
109
+ )
110
+
111
+ return tokenizer, model
112
 
113
+ def chunk_text_by_tokens(text, tokenizer, max_tokens):
114
+ """Split text into chunks based on token count"""
115
+ if not text.strip():
116
+ return []
117
+
118
+ # First, try splitting by sentence delimiters
119
+ sentences = re.split(r'[.!?。!?]', text)
120
+ chunks = []
121
+ current_chunk = ""
122
+
123
+ for sentence in sentences:
124
+ sentence = sentence.strip()
125
+ if not sentence:
126
+ continue
127
+
128
+ test_chunk = current_chunk + " " + sentence if current_chunk else sentence
129
+
130
+ # Estimate token length
131
+ try:
132
+ token_count = len(tokenizer.encode(test_chunk, add_special_tokens=False))
133
+ except:
134
+ token_count = len(test_chunk.split()) * 1.3 # rough estimation
135
+
136
+ if token_count <= max_tokens:
137
+ current_chunk = test_chunk
138
  else:
139
+ if current_chunk:
140
+ chunks.append(current_chunk.strip())
141
+
142
+ # If single sentence is too long, split it forcefully
143
+ if len(tokenizer.encode(sentence, add_special_tokens=False)) > max_tokens:
144
+ tokens = tokenizer.encode(sentence, add_special_tokens=False)
145
+ for i in range(0, len(tokens), max_tokens):
146
+ chunk_tokens = tokens[i:i + max_tokens]
147
+ chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
148
+ chunks.append(chunk_text)
149
+ current_chunk = ""
150
  else:
151
+ current_chunk = sentence
152
+
153
+ if current_chunk:
154
+ chunks.append(current_chunk.strip())
155
+
156
+ return chunks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ def translate_text_chunk(text, target_lang, source_lang, tokenizer, model):
159
+ """Translate a single chunk of text"""
160
+ target_lang = normalize_language(target_lang)
161
+ source_lang = normalize_language(source_lang) if source_lang else None
162
+
163
+ if not target_lang:
164
+ return "Error: Invalid target language"
165
+
166
+ # Create prompt
167
+ if source_lang:
168
+ prompt = f"Translate the following segment from {source_lang} into {target_lang}, without additional explanation.\n\n{text}"
169
  else:
170
+ prompt = f"Translate the following segment into {target_lang}, without additional explanation.\n\n{text}"
171
+
172
+ # Apply chat template
173
+ messages = [{"role": "user", "content": prompt}]
174
+ input_text = tokenizer.apply_chat_template(
175
+ messages,
176
+ tokenize=False,
177
+ add_generation_prompt=True
178
+ )
179
+
180
+ # Tokenize
181
+ inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
182
+
183
+ # Generate
184
+ with torch.no_grad():
185
+ outputs = model.generate(
186
+ **inputs,
187
+ **GEN_KW,
188
+ pad_token_id=tokenizer.eos_token_id
189
  )
190
+
191
+ # Decode
192
+ response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
193
+ return response.strip()
 
 
 
 
 
 
 
 
194
 
195
+ def translate_single(text, target_lang, source_lang, tokenizer, model):
196
+ """Translate text with automatic chunking"""
197
+ if not text.strip():
198
+ return "Please enter text to translate."
199
+
200
+ if not target_lang:
201
+ return "Please select a target language."
202
+
203
+ try:
204
+ # Split into chunks
205
+ chunks = chunk_text_by_tokens(text, tokenizer, MAX_INPUT_TOKENS)
206
+
207
+ if not chunks:
208
+ return "No valid text to translate."
209
+
210
+ # Translate each chunk
211
+ translations = []
212
+ for chunk in chunks:
213
+ translation = translate_text_chunk(chunk, target_lang, source_lang, tokenizer, model)
214
+ translations.append(translation)
215
+
216
+ return " ".join(translations)
217
+
218
+ except Exception as e:
219
+ return f"Translation error: {str(e)}"
220
 
221
+ def translate_batch(text_lines, target_lang, source_lang, tokenizer, model):
222
+ """Translate multiple lines of text"""
223
+ if not text_lines.strip():
224
+ return "Please enter text lines to translate."
225
+
226
+ if not target_lang:
227
+ return "Please select a target language."
228
+
229
+ lines = [line.strip() for line in text_lines.split('\n') if line.strip()]
230
+
231
+ if not lines:
232
+ return "No valid text lines to translate."
233
+
234
+ try:
235
+ results = []
236
+ for line in lines:
237
+ translation = translate_single(line, target_lang, source_lang, tokenizer, model)
238
+ results.append(translation)
239
+
240
+ return '\n'.join(results)
241
+
242
+ except Exception as e:
243
+ return f"Batch translation error: {str(e)}"
244
 
245
+ # Load model and tokenizer
246
+ print("Initializing model...")
247
+ tokenizer, model = load_model()
248
+ device = model.device
249
+ print(f"Model loaded on device: {device}")
 
 
 
250
 
251
+ # Create Gradio interface
252
+ with gr.Blocks(title="Hunyuan-MT Multi-language Translation") as demo:
253
+ gr.Markdown("# 🌍 Hunyuan-MT Multi-language Translation")
254
+ gr.Markdown(f"**Model**: {MODEL_NAME}")
255
+ gr.Markdown("⚠️ **Note**: Running on Free CPU - translation may be slow and length is limited.")
256
+
257
+ with gr.Tabs():
258
+ with gr.TabItem("Single Translation"):
259
+ with gr.Row():
260
+ with gr.Column():
261
+ input_text = gr.Textbox(
262
+ label="Text to translate",
263
+ placeholder="Enter your text here...",
264
+ lines=5
265
+ )
266
+ target_lang = gr.Dropdown(
267
+ choices=SUPPORTED_LANGUAGES,
268
+ label="Target Language",
269
+ value="Vietnamese"
270
+ )
271
+ source_lang = gr.Textbox(
272
+ label="Source Language (optional)",
273
+ placeholder="Leave empty for auto-detection"
274
+ )
275
+ translate_btn = gr.Button("Translate", variant="primary")
276
+
277
+ with gr.Column():
278
+ output_text = gr.Textbox(
279
+ label="Translation",
280
+ lines=5,
281
+ interactive=False
282
+ )
283
+
284
+ translate_btn.click(
285
+ fn=lambda text, tgt, src: translate_single(text, tgt, src, tokenizer, model),
286
+ inputs=[input_text, target_lang, source_lang],
287
+ outputs=output_text,
288
+ api_name="translate_text"
289
+ )
290
+
291
+ with gr.TabItem("Batch Translation"):
292
+ with gr.Row():
293
+ with gr.Column():
294
+ batch_input = gr.Textbox(
295
+ label="Text lines to translate (one per line)",
296
+ placeholder="Line 1\nLine 2\nLine 3...",
297
+ lines=8
298
+ )
299
+ batch_target_lang = gr.Dropdown(
300
+ choices=SUPPORTED_LANGUAGES,
301
+ label="Target Language",
302
+ value="Vietnamese"
303
+ )
304
+ batch_source_lang = gr.Textbox(
305
+ label="Source Language (optional)",
306
+ placeholder="Leave empty for auto-detection"
307
+ )
308
+ batch_translate_btn = gr.Button("Translate Batch", variant="primary")
309
+
310
+ with gr.Column():
311
+ batch_output = gr.Textbox(
312
+ label="Batch Translation Results",
313
+ lines=8,
314
+ interactive=False
315
+ )
316
+
317
+ batch_translate_btn.click(
318
+ fn=lambda text, tgt, src: translate_batch(text, tgt, src, tokenizer, model),
319
+ inputs=[batch_input, batch_target_lang, batch_source_lang],
320
+ outputs=batch_output,
321
+ api_name="translate_batch"
322
+ )
323
+
324
+ gr.Markdown("### API Usage")
325
+ gr.Markdown("""
326
+ ```python
327
+ from gradio_client import Client
328
+
329
+ client = Client("YOUR_SPACE_URL")
330
+
331
+ # Single translation
332
+ result = client.predict("你好", "Vietnamese", None, api_name="/translate_text")
333
+
334
+ # Batch translation
335
+ result = client.predict("你好\\n再见", "Vietnamese", None, api_name="/translate_batch")
336
+ ```
337
+ """)
338
 
339
+ # Launch the app
340
+ if __name__ == "__main__":
341
+ demo.queue(concurrency_count=1, max_size=2).launch()