AI-Translates / app.py
jing-ju's picture
Update app.py
9ecf72c verified
raw
history blame
8.8 kB
import os
import math
import re
from typing import List, Optional
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# ✅ Import tương thích nhiều phiên bản:
try:
# Nhiều bản đặt ở đây
from transformers.quantizers import CompressedTensorsQuantizationConfig
except Exception:
try:
# Một số bản export ở root (phòng hờ)
from transformers import CompressedTensorsQuantizationConfig # type: ignore
except Exception:
CompressedTensorsQuantizationConfig = None # sẽ fallback qua dict
# =========================
# CẤU HÌNH MẶC ĐỊNH
# =========================
# Model mặc định: nhẹ hơn và phù hợp hơn cho CPU Free
DEFAULT_MODEL = "tencent/Hunyuan-MT-7B-fp8"
MODEL_NAME = os.getenv("MODEL_NAME", DEFAULT_MODEL)
# Tham số sinh gợi ý (giữ thấp để tránh quá tải CPU)
GEN_KW = dict(
max_new_tokens=256,
top_k=20,
top_p=0.6,
repetition_penalty=1.05,
temperature=0.7,
do_sample=True,
)
# Giới hạn token đầu vào mỗi lượt để tránh OOM/timeout trên CPU
# (tổng input ≲ 900–1000 token trên CPU Free cho an toàn)
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "800"))
# =========================
# TẢI MODEL & TOKENIZER
# =========================
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Ghi đè config lượng tử hóa để tránh lỗi "ignore NoneType" trên một số bản fp8
ctq = CompressedTensorsQuantizationConfig(
quantization_method="fp8",
ignore=[], # chìa khóa tránh TypeError: 'NoneType' object is not iterable
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
quantization_config=ctq,
)
DEVICE = getattr(model, "device", torch.device("cpu"))
# =========================
# TIỆN ÍCH CHUẨN HÓA NGÔN NGỮ
# =========================
# Map tên ngôn ngữ phổ biến -> tên tiếng Anh để nhúng vào prompt (đơn giản hóa)
LANG_ALIASES = {
# Vietnamese
"vi": "Vietnamese", "vie": "Vietnamese",
"vietnamese": "Vietnamese", "tiếng việt": "Vietnamese",
# Chinese
"zh": "Chinese", "chi": "Chinese", "zho": "Chinese",
"chinese": "Chinese", "tiếng trung": "Chinese", "hán ngữ": "Chinese",
"mandarin": "Chinese",
# English
"en": "English", "eng": "English", "tiếng anh": "English", "english": "English",
# Japanese
"ja": "Japanese", "jpn": "Japanese", "tiếng nhật": "Japanese", "japanese": "Japanese",
# Korean
"ko": "Korean", "kor": "Korean", "tiếng hàn": "Korean", "korean": "Korean",
# French
"fr": "French", "fra": "French", "fre": "French", "tiếng pháp": "French", "french": "French",
# German
"de": "German", "deu": "German", "ger": "German", "tiếng đức": "German", "german": "German",
# Spanish
"es": "Spanish", "spa": "Spanish", "tiếng tây ban nha": "Spanish", "spanish": "Spanish",
# Thai
"th": "Thai", "tha": "Thai", "tiếng thái": "Thai", "thai": "Thai",
# Indonesian
"id": "Indonesian", "ind": "Indonesian", "tiếng indonesia": "Indonesian", "indonesian": "Indonesian",
# Malay
"ms": "Malay", "msa": "Malay", "tiếng malaysia": "Malay", "malay": "Malay",
# Portuguese
"pt": "Portuguese", "por": "Portuguese", "tiếng bồ đào nha": "Portuguese", "portuguese": "Portuguese",
# Russian
"ru": "Russian", "rus": "Russian", "tiếng nga": "Russian", "russian": "Russian",
}
def normalize_lang_name(s: Optional[str]) -> Optional[str]:
if not s:
return None
key = s.strip().lower()
return LANG_ALIASES.get(key, s.strip())
# =========================
# CHIA ĐOẠN THEO TOKEN
# =========================
def chunk_text_by_tokens(text: str, max_tokens: int) -> List[str]:
"""
Chia văn bản thành các đoạn dựa vào số token của tokenizer để tránh vượt ngưỡng input.
Ưu tiên cắt theo dấu câu. Nếu đoạn vẫn dài, cắt tiếp theo token.
"""
# Tách theo các dấu câu lớn trước
rough_parts = re.split(r"(?<=[\.!?。!?])\s+", text.strip())
chunks = []
buf = ""
def token_len(s: str) -> int:
return tokenizer(s, add_special_tokens=False, return_length=True)["length"]
for part in rough_parts:
candidate = (buf + " " + part).strip() if buf else part
if token_len(candidate) <= max_tokens:
buf = candidate
else:
if buf:
chunks.append(buf)
buf = ""
# Nếu part tự thân đã quá dài, cắt tiếp theo token
if token_len(part) <= max_tokens:
buf = part
else:
# Cắt theo token “cứng”
ids = tokenizer(part, add_special_tokens=False)["input_ids"]
for i in range(0, len(ids), max_tokens):
piece_ids = ids[i:i + max_tokens]
piece = tokenizer.decode(piece_ids, skip_special_tokens=True)
chunks.append(piece)
buf = ""
if buf:
chunks.append(buf)
# Loại bỏ rỗng
return [c for c in chunks if c.strip()]
# =========================
# CORE TRANSLATION (SỬ DỤNG CHAT TEMPLATE)
# =========================
@torch.inference_mode()
def translate_text(
text: str,
target_lang: str,
source_lang: Optional[str] = None,
) -> str:
target = normalize_lang_name(target_lang) or "Vietnamese"
src = normalize_lang_name(source_lang)
# Xây prompt: có thể thêm nguồn nếu người dùng cung cấp, còn không để model tự đoán
if src:
sys_prompt = f"Translate the following segment from {src} into {target}, without additional explanation."
else:
sys_prompt = f"Translate the following segment into {target}, without additional explanation."
pieces = chunk_text_by_tokens(text, MAX_INPUT_TOKENS)
outputs = []
for piece in pieces:
messages = [{"role": "user", "content": f"{sys_prompt}\n\n{piece}"}]
inputs = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=False, return_tensors="pt"
)
out_ids = model.generate(inputs.to(DEVICE), **GEN_KW)
out_text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
outputs.append(out_text.strip())
return "\n".join(outputs).strip()
def translate_batch(
texts: List[str],
target_lang: str,
source_lang: Optional[str] = None,
) -> List[str]:
return [translate_text(t, target_lang, source_lang) for t in texts]
# =========================
# GRADIO UI + API
# =========================
LANG_CHOICES = sorted(list(set(LANG_ALIASES.values())))
with gr.Blocks() as demo:
gr.Markdown(
"## Hunyuan-MT (fp8) — Multilingual Translation (Trial on CPU)\n"
"Bản HF Spaces Free (CPU) — tốc độ chậm, đã có chia đoạn tự động theo token."
)
with gr.Tab("Single"):
src = gr.Textbox(label="Văn bản nguồn", lines=10, placeholder="Dán văn bản cần dịch…")
with gr.Row():
src_lang = gr.Textbox(label="Ngôn ngữ nguồn (tùy chọn, ví dụ: Vietnamese/Chinese/English…)", placeholder="Để trống nếu không chắc")
tgt_lang = gr.Dropdown(label="Ngôn ngữ đích", choices=LANG_CHOICES, value="Vietnamese")
out = gr.Textbox(label="Bản dịch", lines=10)
btn = gr.Button("Dịch")
btn.click(fn=translate_text, inputs=[src, tgt_lang, src_lang], outputs=out, api_name="translate_text")
with gr.Tab("Batch"):
src_list = gr.Textbox(
label="Danh sách câu (mỗi dòng 1 câu/đoạn ngắn)",
lines=10,
placeholder="Mỗi dòng là một câu/đoạn…"
)
with gr.Row():
src_lang_b = gr.Textbox(label="Ngôn ngữ nguồn (tuỳ chọn)", placeholder="Để trống nếu không chắc")
tgt_lang_b = gr.Dropdown(label="Ngôn ngữ đích", choices=LANG_CHOICES, value="Vietnamese")
out_list = gr.Textbox(label="Kết quả (mỗi dòng tương ứng 1 đầu vào)", lines=10)
def _batch_wrapper(texts_raw: str, tgt: str, src_: Optional[str]):
texts = [x for x in texts_raw.splitlines() if x.strip()]
results = translate_batch(texts, tgt, src_)
return "\n".join(results)
btn_b = gr.Button("Dịch Batch")
btn_b.click(fn=_batch_wrapper, inputs=[src_list, tgt_lang_b, src_lang_b], outputs=out_list, api_name="translate_batch")
# Giới hạn tải cho demo
demo.queue(concurrency_count=1, max_size=2).launch()