# app.py — Gradio app with decoding kept consistent with the Colab notebook
import os
import gradio as gr
import torch
from transformers import (
AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, GenerationConfig
)
# Hugging Face model repo id; override with the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "hasmar03/mt5_id2md")
# >>>>>> CHANGE THESE TO MATCH YOUR TRAINING DATA <<<<<<
ID2MD_PREFIX = "translate Indonesian to Mandar: "
MD2ID_PREFIX = "translate Mandar to Indonesian: "
# If you trained with different task tokens (e.g. "id2md: " / "md2id: " or ">>md<< "),
# change the strings above so they match 100% exactly.
# Default decoding parameters (kept identical to the Colab notebook).
DEFAULT_DECODE = dict(
    num_beams=5,
    length_penalty=1.0,
    no_repeat_ngram_size=3,
    early_stopping=True,
    max_new_tokens=128,
)
# Lazily-initialized singletons: filled on first use by get_pipe().
pipe = None
gen_cfg = None
def get_pipe():
    """Build the text2text-generation pipeline once and reuse it afterwards.

    Lazily loads the tokenizer and model from MODEL_ID, attaches the repo's
    generation config when one exists, and caches the resulting pipeline in
    the module-level ``pipe`` global.
    """
    global pipe, gen_cfg
    if pipe is not None:
        return pipe

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

    # Prefer the generation_config published in the repo; fall back to the
    # model's built-in defaults when the repo does not ship one.
    try:
        gen_cfg = GenerationConfig.from_pretrained(MODEL_ID)
        model.generation_config = gen_cfg
    except Exception:
        gen_cfg = model.generation_config

    device_index = 0 if torch.cuda.is_available() else -1
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device_index,
    )
    return pipe
def build_prompt(text: str, direction: str):
    """Prepend the task prefix matching *direction* (a UI radio value)."""
    prefix = ID2MD_PREFIX if direction == "Indonesia → Mandar" else MD2ID_PREFIX
    return f"{prefix}{text}"
def translate(text: str, direction: str,
              num_beams: int, max_new_tokens: int,
              no_repeat_ngram_size: int, length_penalty: float,
              do_sample: bool, temperature: float, top_p: float, top_k: int):
    """Translate *text* in the requested direction and return the output string.

    Decoding arguments start from DEFAULT_DECODE and are overridden by the UI
    values. Sampling parameters (temperature/top_p/top_k) are only forwarded
    when ``do_sample`` is enabled, so deterministic beam search stays
    reproducible.
    """
    # Fix: skip model loading and inference entirely for empty/whitespace
    # input — the original ran the model on a bare task prefix.
    if not text or not text.strip():
        return ""

    p = get_pipe()
    prompt = build_prompt(text, direction)

    # Assemble generate() kwargs: defaults first, then UI overrides.
    gen_args = dict(DEFAULT_DECODE)
    gen_args.update(
        num_beams=int(num_beams),
        max_new_tokens=int(max_new_tokens),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        length_penalty=float(length_penalty),
    )
    if do_sample:
        gen_args.update(do_sample=True, temperature=float(temperature),
                        top_p=float(top_p), top_k=int(top_k))
    else:
        gen_args.update(do_sample=False)

    return p(prompt, **gen_args)[0]["generated_text"]
# --- Gradio UI wiring ------------------------------------------------------
with gr.Blocks(title="Mandar ↔ Indonesia Translator") as demo:
    gr.Markdown("### Mandar ↔ Indonesia Translator")
    with gr.Row():
        # Direction selector; these strings are compared verbatim in
        # build_prompt(), so keep them in sync with that function.
        direction = gr.Radio(
            ["Indonesia → Mandar", "Mandar → Indonesia"],
            value="Indonesia → Mandar", label="Arah"
        )
    src = gr.Textbox(label="Teks sumber", lines=3, placeholder="Ketik teks…")
    btn = gr.Button("Terjemahkan")
    out = gr.Textbox(label="Hasil", lines=3)
    # Advanced decoding controls; initial values come from DEFAULT_DECODE so
    # the UI defaults match the Colab settings.
    with gr.Accordion("Advanced decoding", open=False):
        num_beams = gr.Slider(1, 10, value=DEFAULT_DECODE["num_beams"], step=1, label="num_beams")
        max_new_tokens = gr.Slider(16, 512, value=DEFAULT_DECODE["max_new_tokens"], step=8, label="max_new_tokens")
        no_repeat_ngram_size = gr.Slider(0, 10, value=DEFAULT_DECODE["no_repeat_ngram_size"], step=1, label="no_repeat_ngram_size")
        length_penalty = gr.Slider(0.0, 2.0, value=DEFAULT_DECODE["length_penalty"], step=0.1, label="length_penalty")
        # Sampling controls are only forwarded to generate() when the
        # checkbox is ticked (see translate()).
        do_sample = gr.Checkbox(False, label="Sampling (non-deterministic)")
        temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="top_p")
        top_k = gr.Slider(0, 100, value=50, step=5, label="top_k")
    # Input order here must match the translate() signature exactly.
    btn.click(
        translate,
        inputs=[src, direction, num_beams, max_new_tokens, no_repeat_ngram_size,
                length_penalty, do_sample, temperature, top_p, top_k],
        outputs=out,
        api_name="translate"
    )
# Queue requests so concurrent users are served without overlapping inference.
demo.queue()
if __name__ == "__main__":
    demo.launch()