# Author: ogaith
# Last change: "Update app.py" (commit ade444b)
import os
import io
import tempfile
from typing import List
import re
import gradio as gr
from huggingface_hub import snapshot_download
import ctranslate2
import sentencepiece as spm
import hanlp
# ====== CONFIGURE HERE ======
# Public Hugging Face repo of your CTranslate2 model
MODEL_REPO = "ogaith/zhen-ctranslate2"
# Paths (inside the model repo) to the SentencePiece models; they are joined
# onto MODEL_DIR below, so they are relative to the downloaded snapshot.
SRC_SPM = "source.spm"  # source-side (Chinese) model, used to encode the input
TGT_SPM = "target.spm"  # target-side (English) model, used to decode the output
# Local example file that lives in the SAME repo as this app.py
EXAMPLE_FILE = "example_corpus.txt"
# ============================
# Everything below runs once, at import time, before the UI is built.
# Download the model once into the Space cache (or local cache when running locally).
# snapshot_download returns the local directory holding the downloaded snapshot.
MODEL_DIR = snapshot_download(MODEL_REPO)
# Load CT2 translator + SentencePiece + HanLP once at startup; the request
# handlers reuse these module-level objects instead of reloading per call.
# device="auto" leaves the CPU/GPU choice to CTranslate2.
translator = ctranslate2.Translator(MODEL_DIR, device="auto")
sp_src = spm.SentencePieceProcessor(os.path.join(MODEL_DIR, SRC_SPM))
sp_tgt = spm.SentencePieceProcessor(os.path.join(MODEL_DIR, TGT_SPM))
# HanLP: Chinese segmenter (adjust to your preferred HanLP pipeline if needed).
# The string is a HanLP pretrained-model identifier; presumably the weights are
# fetched into HanLP's own cache on first run — TODO confirm.
hanlp_tok = hanlp.load('PKU_NAME_MERGED_SIX_MONTHS_CONVSEG')
# Patterns used by preprocess_source, compiled once at import time.
_CJK_BEFORE_DIGIT = re.compile(r'([\u4e00-\u9fff])(\d)')
_DIGIT_BEFORE_CJK = re.compile(r'(\d)([\u4e00-\u9fff])')
_WHITESPACE_RUN = re.compile(r'\s+')


def preprocess_source(text: str) -> str:
    """Normalise one Chinese source line before SentencePiece encoding.

    Steps:
      1. strip surrounding whitespace (blank input short-circuits to "");
      2. run the HanLP segmenter and glue its tokens back together;
      3. insert a space at every boundary between a CJK ideograph
         (U+4E00-U+9FFF) and a digit, in either order;
      4. collapse whitespace runs to a single space and trim.

    NOTE(review): the HanLP tokens are re-joined with "" (no separator), which
    presumably reconstructs the original characters. Confirm this matches how
    the training data was prepared; a " " join would feed segmented text to
    the model instead.

    Args:
        text: One raw source line.

    Returns:
        The normalised line, or "" for blank / whitespace-only input.
    """
    stripped = text.strip()
    if not stripped:
        return ""

    glued = "".join(hanlp_tok(stripped))
    # CJK->digit first, then digit->CJK, so "1中2" becomes "1 中 2".
    for pattern in (_CJK_BEFORE_DIGIT, _DIGIT_BEFORE_CJK):
        glued = pattern.sub(r'\1 \2', glued)
    return _WHITESPACE_RUN.sub(' ', glued).strip()
def translate_lines(lines: List[str], beam_size: int, max_len: int, batch_size: int) -> List[str]:
    """Translate a list of CN lines -> EN lines using CT2.

    Source side: HanLP preprocessing + SentencePiece encode.
    Target side: SentencePiece decode ONLY (no HanLP).

    Lines that are empty after preprocessing are never sent to the model.
    They come back as "", so the output always has exactly one entry per
    input line, in the same order.

    Args:
        lines: Source sentences, one per element.
        beam_size: Beam width; coerced with int() because UI sliders may
            deliver floats.
        max_len: Maximum decoding length (target tokens); coerced with int().
        batch_size: Number of sentences per translate_batch call; coerced with
            int() and clamped to at least 1 (range() rejects a float or a
            zero step).

    Returns:
        One English string per input line.
    """
    beam_size = int(beam_size)
    max_len = int(max_len)
    batch_size = max(1, int(batch_size))

    out_lines = [""] * len(lines)

    # (original line index, source subword pieces) for every line worth
    # translating. Blank lines are skipped: an empty source would still make
    # the model generate a (spurious) hypothesis.
    pending = []
    for idx, line in enumerate(lines):
        pre = preprocess_source(line)
        pieces = sp_src.encode(pre, out_type=str) if pre else []
        if pieces:
            pending.append((idx, pieces))

    for start in range(0, len(pending), batch_size):
        chunk = pending[start:start + batch_size]
        results = translator.translate_batch(
            [pieces for _, pieces in chunk],
            beam_size=beam_size,
            max_decoding_length=max_len,
        )
        # results are returned in the same order as the submitted batch.
        for (idx, _), result in zip(chunk, results):
            out_lines[idx] = sp_tgt.decode(result.hypotheses[0])
    return out_lines
def to_temp_txt(content: str) -> str:
    """Write content to a temporary .txt file and return its path for download.

    The file is written as UTF-8 and is NOT deleted on close (delete=False),
    so the returned path stays valid for the UI to serve. The ``with`` block
    guarantees the handle is closed even if the write fails.

    Args:
        content: Text to write.

    Returns:
        Filesystem path of the created ``.txt`` file.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", suffix=".txt", delete=False
    ) as tmp:
        tmp.write(content)
    return tmp.name
def run_on_uploaded(file_bytes, beam_size, max_len, batch_size):
    """Handle user-uploaded .txt (UTF-8). Returns: downloadable file, preview, status msg.

    Args:
        file_bytes: Raw contents of the upload (the upload widget is
            configured with type="binary"), or None when nothing was uploaded.
        beam_size: Forwarded to translate_lines.
        max_len: Forwarded to translate_lines.
        batch_size: Forwarded to translate_lines.

    Returns:
        A 3-tuple ``(path, preview, status)``. ``path`` is the translated
        .txt file, or None on error. ``preview`` holds the first 10
        translated lines, or a cleared-value update on error. ``status`` is a
        user-facing message.
    """
    if file_bytes is None:
        return None, gr.update(value=None), "Please upload a .txt file."
    try:
        # "utf-8-sig" decodes plain UTF-8 exactly like "utf-8" but also drops
        # a leading byte-order mark (common in files saved by Windows
        # editors). str.strip() does not remove U+FEFF, so it would otherwise
        # be fed to the tokenizers as part of the first sentence.
        text = file_bytes.decode("utf-8-sig")
    except UnicodeDecodeError:
        # We expect UTF-8; if not, inform the user clearly.
        return None, gr.update(value=None), "Encoding error. Make sure the file is UTF-8."

    lines = text.splitlines()
    outs = translate_lines(lines, beam_size, max_len, batch_size)
    out_txt = "\n".join(outs) + ("\n" if outs else "")
    path = to_temp_txt(out_txt)
    # Short preview: first 10 lines to avoid clutter
    preview = "\n".join(outs[:10])
    return path, preview, f"Translated {len(outs)} lines."
def run_on_example(beam_size, max_len, batch_size):
    """Translate the local example_corpus.txt that ships with this repo.

    EXAMPLE_FILE is looked up first relative to the current working directory
    (the original behaviour). If it is not there, the directory containing
    this script is tried, so the button also works when the app is launched
    from a different working directory.

    Args:
        beam_size: Forwarded to translate_lines.
        max_len: Forwarded to translate_lines.
        batch_size: Forwarded to translate_lines.

    Returns:
        A 3-tuple ``(path, preview, status)``, shaped like run_on_uploaded's
        result.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidates = (EXAMPLE_FILE, os.path.join(script_dir, EXAMPLE_FILE))
    # isfile (not exists) so a directory with that name is reported as missing
    # instead of crashing in open().
    example_path = next((p for p in candidates if os.path.isfile(p)), None)
    if example_path is None:
        return None, gr.update(value=None), f"File '{EXAMPLE_FILE}' not found in the repo."

    # "utf-8-sig" tolerates a leading BOM, matching the upload path.
    with open(example_path, "r", encoding="utf-8-sig") as f:
        lines = f.read().splitlines()

    outs = translate_lines(lines, beam_size, max_len, batch_size)
    out_txt = "\n".join(outs) + ("\n" if outs else "")
    path = to_temp_txt(out_txt)
    # Short preview: first 10 lines to avoid clutter
    preview = "\n".join(outs[:10])
    return path, preview, f"Translated {len(outs)} lines from '{EXAMPLE_FILE}'."
with gr.Blocks() as demo:
    gr.Markdown("# 🇨🇳→🇬🇧 TXT Translation (CTranslate2 + HanLP + SentencePiece)")
    gr.Markdown(
        "Upload a UTF-8 `.txt` with **one Chinese sentence per line** and download the English `.txt` output.\n\n"
        f"Or click to translate the bundled **`{EXAMPLE_FILE}`** in this repository."
    )

    # Decoding controls shared by both translation actions below.
    with gr.Row():
        beam = gr.Slider(1, 8, value=4, step=1, label="Beam size")
        max_len = gr.Slider(16, 512, value=256, step=1, label="Max decoding length")
        bs = gr.Slider(1, 128, value=32, step=1, label="Batch size")

    # --- Uploaded-file workflow ---
    gr.Markdown("### Translate an uploaded file")
    with gr.Row():
        # type="binary" makes the handler receive the raw file contents.
        inp = gr.File(label="Upload .txt (UTF-8)", file_count="single", type="binary")
        btn_upload = gr.Button("Translate uploaded file")
    out_file_upload = gr.File(label="Download translations (.txt)")
    out_preview_upload = gr.Textbox(label="Preview (first 10 lines)", lines=10)
    status_upload = gr.Markdown()
    btn_upload.click(
        run_on_uploaded,
        [inp, beam, max_len, bs],
        [out_file_upload, out_preview_upload, status_upload],
    )

    # --- Bundled example workflow ---
    gr.Markdown("---")
    gr.Markdown(f"### Translate the repository example file (`{EXAMPLE_FILE}`)")
    # Label derived from EXAMPLE_FILE so it stays correct if the config changes.
    btn_example = gr.Button(f"Translate {EXAMPLE_FILE}")
    out_file_example = gr.File(label="Download example translations (.txt)")
    out_preview_example = gr.Textbox(label="Example preview (first 10)", lines=10)
    status_example = gr.Markdown()
    btn_example.click(
        run_on_example,
        [beam, max_len, bs],
        [out_file_example, out_preview_example, status_example],
    )

# Start the server only when this file is executed as a script
# (`python app.py`; Hugging Face Spaces runs the Gradio app this way — verify
# for your SDK version). Importing the module no longer launches a server as
# a side effect.
if __name__ == "__main__":
    demo.launch()