jing-ju commited on
Commit
9089433
·
verified ·
1 Parent(s): e79dba8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -118
app.py CHANGED
@@ -1,122 +1,149 @@
1
- # app.py — HF Spaces Free (CPU), Hunyuan-MT 7B-fp8, đa ngôn ngữ, chia đoạn, UI + API
2
- import os, re
3
- from typing import List, Optional
4
-
5
  import gradio as gr
6
- import torch
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
-
9
- # ===== Cấu hình =====
10
- DEFAULT_MODEL = "tencent/Hunyuan-MT-7B-fp8" # đổi bằng env MODEL_NAME nếu muốn
11
- MODEL_NAME = os.getenv("MODEL_NAME", DEFAULT_MODEL)
12
-
13
- GEN_KW = dict( # tham số sinh nhẹ cho CPU
14
- max_new_tokens=256,
15
- top_k=20,
16
- top_p=0.6,
17
- repetition_penalty=1.05,
18
- temperature=0.7,
19
- do_sample=True,
20
- )
21
-
22
- MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "800")) # giới hạn input mỗi mảnh
23
-
24
- # ===== Load tokenizer & model (fp8 bằng dict quantization_config) =====
25
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
26
- quant_cfg = {"quantization_method": "fp8", "ignore": []} # tránh lỗi ignore=None
27
- model = AutoModelForCausalLM.from_pretrained(
28
- MODEL_NAME,
29
- trust_remote_code=True,
30
- quantization_config=quant_cfg,
31
- )
32
- DEVICE = getattr(model, "device", torch.device("cpu"))
33
-
34
- # ===== Chuẩn hóa tên ngôn ngữ =====
35
- LANG_ALIASES = {
36
- "vi": "Vietnamese", "vie": "Vietnamese", "vietnamese": "Vietnamese", "tiếng việt": "Vietnamese",
37
- "zh": "Chinese", "chi": "Chinese", "zho": "Chinese", "chinese": "Chinese", "tiếng trung": "Chinese", "hán ngữ": "Chinese", "mandarin": "Chinese",
38
- "en": "English", "eng": "English", "tiếng anh": "English", "english": "English",
39
- "ja": "Japanese", "jpn": "Japanese", "tiếng nhật": "Japanese", "japanese": "Japanese",
40
- "ko": "Korean", "kor": "Korean", "tiếng hàn": "Korean", "korean": "Korean",
41
- "fr": "French", "fra": "French", "fre": "French", "tiếng pháp": "French", "french": "French",
42
- "de": "German", "deu": "German", "ger": "German", "tiếng đức": "German", "german": "German",
43
- "es": "Spanish", "spa": "Spanish", "tiếng tây ban nha": "Spanish", "spanish": "Spanish",
44
- "th": "Thai", "tha": "Thai", "tiếng thái": "Thai", "thai": "Thai",
45
- "id": "Indonesian", "ind": "Indonesian", "tiếng indonesia": "Indonesian", "indonesian": "Indonesian",
46
- "ms": "Malay", "msa": "Malay", "tiếng malaysia": "Malay", "malay": "Malay",
47
- "pt": "Portuguese", "por": "Portuguese", "tiếng bồ đào nha": "Portuguese", "portuguese": "Portuguese",
48
- "ru": "Russian", "rus": "Russian", "tiếng nga": "Russian", "russian": "Russian",
49
- }
50
- LANG_CHOICES = sorted(set(LANG_ALIASES.values()))
51
- def norm_lang(s: Optional[str]) -> Optional[str]:
52
- if not s: return None
53
- k = s.strip().lower()
54
- return LANG_ALIASES.get(k, s.strip())
55
-
56
- # ===== Chia văn bản theo token =====
57
- def chunk_by_tokens(text: str, max_tokens: int) -> List[str]:
58
- text = text.strip()
59
- if not text: return []
60
- rough = re.split(r"(?<=[\.!?。!?])\s+", text)
61
- chunks, buf = [], ""
62
- def tok_len(s: str) -> int:
63
- return tokenizer(s, add_special_tokens=False, return_length=True)["length"]
64
- for part in rough:
65
- cand = (buf + " " + part).strip() if buf else part
66
- if tok_len(cand) <= max_tokens:
67
- buf = cand
68
- else:
69
- if buf: chunks.append(buf); buf = ""
70
- if tok_len(part) <= max_tokens:
71
- buf = part
72
- else:
73
- ids = tokenizer(part, add_special_tokens=False)["input_ids"]
74
- for i in range(0, len(ids), max_tokens):
75
- piece = tokenizer.decode(ids[i:i+max_tokens], skip_special_tokens=True)
76
- if piece.strip(): chunks.append(piece.strip())
77
- if buf: chunks.append(buf)
78
- return [c for c in chunks if c.strip()]
79
-
80
- # ===== Core translate (chat template) =====
81
- @torch.inference_mode()
82
- def translate_text(text: str, target_lang: str, source_lang: Optional[str]=None) -> str:
83
- tgt = norm_lang(target_lang) or "Vietnamese"
84
- src = norm_lang(source_lang)
85
- sys_prompt = (f"Translate the following segment from {src} into {tgt}, without additional explanation."
86
- if src else
87
- f"Translate the following segment into {tgt}, without additional explanation.")
88
- outs = []
89
- for piece in chunk_by_tokens(text, MAX_INPUT_TOKENS):
90
- msgs = [{"role":"user","content": f"{sys_prompt}\n\n{piece}"}]
91
- inputs = tokenizer.apply_chat_template(msgs, tokenize=True, add_generation_prompt=False, return_tensors="pt")
92
- out_ids = model.generate(inputs.to(DEVICE), **GEN_KW)
93
- outs.append(tokenizer.decode(out_ids[0], skip_special_tokens=True).strip())
94
- return "\n".join(outs).strip()
95
-
96
- def translate_batch(texts: List[str], target_lang: str, source_lang: Optional[str]=None) -> List[str]:
97
- return [translate_text(t, target_lang, source_lang) for t in texts]
98
-
99
- # ===== Gradio UI + API =====
100
- with gr.Blocks() as demo:
101
- gr.Markdown("## Hunyuan-MT 7B-fp8 — Multilingual Translation (HF Free CPU)\nChia đoạn theo token, UI + API (Gradio).")
102
-
103
- with gr.Tab("Single"):
104
- src = gr.Textbox(label="Văn bản nguồn", lines=10, placeholder="Dán văn bản cần dịch…")
 
 
 
 
 
 
 
 
 
105
  with gr.Row():
106
- src_lang = gr.Textbox(label="Ngôn ngữ nguồn (tùy chọn)", placeholder="Ví dụ: Vietnamese/Chinese/English…")
107
- tgt_lang = gr.Dropdown(label="Ngôn ngữ đích", choices=LANG_CHOICES, value="Vietnamese")
108
- out = gr.Textbox(label="Bản dịch", lines=10)
109
- gr.Button("Dịch").click(translate_text, inputs=[src, tgt_lang, src_lang], outputs=out, api_name="translate_text")
 
 
 
 
110
 
111
- with gr.Tab("Batch"):
112
- src_list = gr.Textbox(label="Mỗi dòng 1 câu/đoạn", lines=10)
113
  with gr.Row():
114
- src_lang_b = gr.Textbox(label="Ngôn ngữ nguồn (tùy chọn)")
115
- tgt_lang_b = gr.Dropdown(label="Ngôn ngữ đích", choices=LANG_CHOICES, value="Vietnamese")
116
- out_list = gr.Textbox(label="Kết quả (mỗi dòng tương ứng 1 đầu vào)", lines=10)
117
- def _batch(txts_raw: str, tgt: str, src_: Optional[str]):
118
- texts = [x for x in txts_raw.splitlines() if x.strip()]
119
- return "\n".join(translate_batch(texts, tgt, src_))
120
- gr.Button("Dịch Batch").click(_batch, inputs=[src_list, tgt_lang_b, src_lang_b], outputs=out_list, api_name="translate_batch")
121
-
122
- demo.queue(concurrency_count=1, max_size=2).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
 
 
 
2
  import gradio as gr
3
+ from huggingface_hub import InferenceClient
4
+
5
+ # -------- Settings --------
6
+ DEFAULT_MODEL = os.getenv("HYMT_MODEL", "tencent/Hunyuan-MT-7B-fp8")
7
+ HF_TOKEN = os.getenv("HF_TOKEN", None) # thể để trống (ẩn danh, sẽ bị rate-limit)
8
+
9
+ # Ngôn ngữ được model hỗ trợ (trích từ model card)
10
+ LANGS = [
11
+ ("Chinese (简体中文)", "zh"),
12
+ ("Traditional Chinese (繁體中文)", "zh-Hant"),
13
+ ("Cantonese (粤语)", "yue"),
14
+ ("English (English)", "en"),
15
+ ("Vietnamese (Tiếng Việt)", "vi"),
16
+ ("Japanese (日本語)", "ja"),
17
+ ("Korean (한국어)", "ko"),
18
+ ("Thai (ไทย)", "th"),
19
+ ("French (Français)", "fr"),
20
+ ("Spanish (Español)", "es"),
21
+ ("Portuguese (Português)", "pt"),
22
+ ("Italian (Italiano)", "it"),
23
+ ("German (Deutsch)", "de"),
24
+ ("Russian (Русский)", "ru"),
25
+ ("Arabic (العربية)", "ar"),
26
+ ("Turkish (Türkçe)", "tr"),
27
+ ("Indonesian (Bahasa Indonesia)", "id"),
28
+ ("Malay (Bahasa Melayu)", "ms"),
29
+ ("Filipino (Filipino)", "tl"),
30
+ ("Hindi (हिन्दी)", "hi"),
31
+ ("Polish (Polski)", "pl"),
32
+ ("Czech (Čeština)", "cs"),
33
+ ("Dutch (Nederlands)", "nl"),
34
+ ("Khmer (ភាសាខ្មែរ)", "km"),
35
+ ("Burmese (မြန်မာ)", "my"),
36
+ ("Persian (فارسی)", "fa"),
37
+ ("Gujarati (ગુજરાતી)", "gu"),
38
+ ("Urdu (اردو)", "ur"),
39
+ ("Telugu (తెలుగు)", "te"),
40
+ ("Marathi (मराठी)", "mr"),
41
+ ("Hebrew (עברית)", "he"),
42
+ ("Bengali (বাংলা)", "bn"),
43
+ ("Tamil (தமிழ்)", "ta"),
44
+ ("Ukrainian (Українська)", "uk"),
45
+ ("Tibetan (བོད་ཡིག)", "bo"),
46
+ ("Kazakh (Қазақша)", "kk"),
47
+ ("Mongolian (Монгол)", "mn"),
48
+ ("Uyghur (ئۇيغۇرچە)", "ug"),
49
+ ]
50
+
51
+ ZH_CODES = {"zh", "zh-Hant", "yue"}
52
+
53
+ def build_prompt(src_lang: str, tgt_lang: str, text: str) -> str:
54
+ """
55
+ Theo gợi ý prompt trong model card:
56
+ - ZH <=> XX: dùng template tiếng Trung
57
+ - XX <=> XX (không có ZH): dùng template tiếng Anh
58
+ """
59
+ if src_lang in ZH_CODES or tgt_lang in ZH_CODES:
60
+ # Template ZH <=> XX
61
+ return f"把下面的文本翻译成{tgt_lang},不要额外解释。\n\n{text.strip()}"
62
+ else:
63
+ # Template XX <=> XX (không ZH)
64
+ return f"Translate the following segment into {tgt_lang}, without additional explanation.\n\n{text.strip()}"
65
+
66
+ def call_hf_inference(model: str, prompt: str) -> str:
67
+ """
68
+ Gọi Serverless Inference API (text-generation).
69
+ Không cần GPU trên Space. Có thể dùng ẩn danh hoặc set HF_TOKEN trong Secrets.
70
+ """
71
+ client = InferenceClient(token=HF_TOKEN)
72
+ # Tham số khuyến nghị từ model card
73
+ try:
74
+ out = client.text_generation(
75
+ model=model,
76
+ prompt=prompt,
77
+ max_new_tokens=512,
78
+ temperature=0.7,
79
+ top_p=0.6,
80
+ repetition_penalty=1.05,
81
+ stream=False,
82
+ # truncate không bật để tránh cắt prompt
83
+ )
84
+ return out.strip()
85
+ except Exception as e:
86
+ return f"[Lỗi] Không thể gọi Inference API: {e}"
87
+
88
+ def translate(text: str, src: str, tgt: str, model_choice: str):
89
+ if not text or not text.strip():
90
+ return "Vui lòng nhập nội dung cần dịch."
91
+ if src == tgt:
92
+ return text.strip()
93
+
94
+ prompt = build_prompt(src, tgt, text)
95
+ # Lưu ý: Hunyuan-MT là causal LM định hướng prompt, không yêu cầu định dạng chat đặc biệt
96
+ result = call_hf_inference(model_choice, prompt)
97
+ return result
98
+
99
+ def ui():
100
+ with gr.Blocks(title="Hunyuan-MT Translation (HF Inference API)", fill_height=True) as demo:
101
+ gr.Markdown(
102
+ """
103
+ # Tencent Hunyuan-MT (Serverless)
104
+ Chạy trên **Hugging Face Space (CPU free)** bằng **Serverless Inference API**.
105
+ - Chọn mô hình `tencent/Hunyuan-MT-7B` hoặc `tencent/Hunyuan-MT-7B-fp8`.
106
+ - Chọn ngôn ngữ nguồn/đích rồi bấm **Dịch**.
107
+ > Gợi ý: vào *Settings → Repository secrets* thêm `HF_TOKEN` để tăng hạn mức.
108
+ """
109
+ )
110
+
111
  with gr.Row():
112
+ model_choice = gr.Dropdown(
113
+ choices=[
114
+ "tencent/Hunyuan-MT-7B-fp8",
115
+ "tencent/Hunyuan-MT-7B",
116
+ ],
117
+ value=DEFAULT_MODEL,
118
+ label="Model (Serverless)"
119
+ )
120
 
 
 
121
  with gr.Row():
122
+ src = gr.Dropdown(choices=[l for l, _ in LANGS], value="English (English)", label="Nguồn")
123
+ tgt = gr.Dropdown(choices=[l for l, _ in LANGS], value="Vietnamese (Tiếng Việt)", label="Đích")
124
+
125
+ # Map label -> code cho back-end
126
+ label2code = {label: code for label, code in LANGS}
127
+
128
+ def _on_translate(text, src_label, tgt_label, model_id):
129
+ src_code = label2code[src_label]
130
+ tgt_code = label2code[tgt_label]
131
+ return translate(text, src_code, tgt_code, model_id)
132
+
133
+ inp = gr.Textbox(label="Nội dung cần dịch", lines=8, placeholder="Nhập văn bản…")
134
+ btn = gr.Button("Dịch", variant="primary")
135
+ out = gr.Textbox(label="Kết quả", lines=8)
136
+
137
+ btn.click(_on_translate, [inp, src, tgt, model_choice], [out])
138
+
139
+ gr.Markdown(
140
+ """
141
+ #### Lưu ý
142
+ - Đây là demo qua **Serverless Inference API** nên tốc độ/phản hồi phụ thuộc hạn mức serverless.
143
+ - Với lượng lớn/nhanh hơn, hãy nâng cấp phần cứng (GPU) hoặc tự triển khai TGI/vLLM.
144
+ """
145
+ )
146
+ return demo
147
+
148
+ if __name__ == "__main__":
149
+ ui().launch()