# Source: Hugging Face Space "android-API" — app.py (uploaded by hasmar03,
# revision f665421, 4.04 kB). The original export captured the blob-view page
# chrome ("raw / history blame") as plain text; it is preserved here as a
# comment so the file remains valid Python.
# app.py — Gradio dengan decoding yang konsisten seperti di Colab
import os
import gradio as gr
import torch
from transformers import (
AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, GenerationConfig
)
MODEL_ID = os.getenv("MODEL_ID", "hasmar03/mt5_id2md")

# >>>>>> CHANGE THESE TO MATCH YOUR TRAINING DATA <<<<<<
ID2MD_PREFIX = "translate Indonesian to Mandar: "
MD2ID_PREFIX = "translate Mandar to Indonesian: "
# If training used different task tokens (e.g. "id2md: " / "md2id: " or
# ">>md<< "), edit the strings above so they match 100%.

# Default decoding parameters — kept identical to the Colab notebook so the
# Space reproduces its outputs.
DEFAULT_DECODE = {
    "num_beams": 5,
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "max_new_tokens": 128,
}

# Lazily-initialized singletons, populated on first call to get_pipe().
pipe = None
gen_cfg = None
def get_pipe():
    """Build (once) and return the shared text2text-generation pipeline.

    On first call, loads the tokenizer and model from MODEL_ID, tries to
    attach the repo's standalone generation_config, and caches the resulting
    pipeline in the module-level ``pipe`` global. Later calls return the
    cached pipeline unchanged.
    """
    global pipe, gen_cfg
    if pipe is not None:
        return pipe

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

    # Prefer the generation_config shipped with the repo, if one exists.
    try:
        gen_cfg = GenerationConfig.from_pretrained(MODEL_ID)
        model.generation_config = gen_cfg
    except Exception:
        # No standalone generation_config in the repo — fall back to the
        # config the model was loaded with.
        gen_cfg = model.generation_config

    device_index = 0 if torch.cuda.is_available() else -1
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device_index,
    )
    return pipe
def build_prompt(text: str, direction: str):
    """Prepend the task prefix that matches the chosen translation direction.

    ``direction`` is the verbatim Radio label from the UI; anything other
    than "Indonesia → Mandar" is treated as Mandar → Indonesian.
    """
    prefix = ID2MD_PREFIX if direction == "Indonesia → Mandar" else MD2ID_PREFIX
    return f"{prefix}{text}"
def translate(text: str, direction: str,
              num_beams: int, max_new_tokens: int,
              no_repeat_ngram_size: int, length_penalty: float,
              do_sample: bool, temperature: float, top_p: float, top_k: int):
    """Translate ``text`` in the given direction with the UI's decoding args.

    Parameters mirror the Gradio controls one-to-one; numeric values are
    coerced with int()/float() because sliders may deliver either type.

    Returns the generated translation string. Empty or whitespace-only input
    short-circuits to "" instead of invoking the model on a bare prefix.
    """
    # Guard: nothing to translate — avoid a pointless (and slow) model call.
    if not text or not text.strip():
        return ""

    p = get_pipe()
    prompt = build_prompt(text, direction)

    # Start from the Colab defaults, then override with the UI values.
    gen_args = dict(DEFAULT_DECODE)
    gen_args.update(
        num_beams=int(num_beams),
        max_new_tokens=int(max_new_tokens),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        length_penalty=float(length_penalty),
    )
    if do_sample:
        # Sampling path: temperature/top_p/top_k only matter here.
        gen_args.update(do_sample=True, temperature=float(temperature),
                        top_p=float(top_p), top_k=int(top_k))
    else:
        gen_args.update(do_sample=False)

    return p(prompt, **gen_args)[0]["generated_text"]
# --- Gradio UI -------------------------------------------------------------
# NOTE: component-creation order determines the rendered layout, and the
# order of the `inputs=` list must match translate()'s parameter order.
with gr.Blocks(title="Mandar ↔ Indonesia Translator") as demo:
    gr.Markdown("### Mandar ↔ Indonesia Translator")
    with gr.Row():
        # Direction labels are matched verbatim in build_prompt() —
        # do not rename the options without updating that function.
        direction = gr.Radio(
            ["Indonesia → Mandar", "Mandar → Indonesia"],
            value="Indonesia → Mandar", label="Arah"
        )
    src = gr.Textbox(label="Teks sumber", lines=3, placeholder="Ketik teks…")
    btn = gr.Button("Terjemahkan")
    out = gr.Textbox(label="Hasil", lines=3)
    # Advanced decoding controls, pre-filled from DEFAULT_DECODE so the UI
    # starts with the same settings used in the Colab notebook.
    with gr.Accordion("Advanced decoding", open=False):
        num_beams = gr.Slider(1, 10, value=DEFAULT_DECODE["num_beams"], step=1, label="num_beams")
        max_new_tokens = gr.Slider(16, 512, value=DEFAULT_DECODE["max_new_tokens"], step=8, label="max_new_tokens")
        no_repeat_ngram_size = gr.Slider(0, 10, value=DEFAULT_DECODE["no_repeat_ngram_size"], step=1, label="no_repeat_ngram_size")
        length_penalty = gr.Slider(0.0, 2.0, value=DEFAULT_DECODE["length_penalty"], step=0.1, label="length_penalty")
        # Sampling controls only take effect when this checkbox is on
        # (see translate()).
        do_sample = gr.Checkbox(False, label="Sampling (non-deterministic)")
        temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="top_p")
        top_k = gr.Slider(0, 100, value=50, step=5, label="top_k")
    # api_name="translate" exposes the handler as a named API endpoint
    # in addition to the button click.
    btn.click(
        translate,
        inputs=[src, direction, num_beams, max_new_tokens, no_repeat_ngram_size,
                length_penalty, do_sample, temperature, top_p, top_k],
        outputs=out,
        api_name="translate"
    )

# Enable request queuing so concurrent users are handled sequentially.
demo.queue()

if __name__ == "__main__":
    demo.launch()