hasmar03 committed on
Commit
f665421
·
verified ·
1 Parent(s): 1f1b115

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -26
app.py CHANGED
@@ -1,57 +1,112 @@
1
- # app.py — Gradio Blocks + REST API bawaan (api_name), lazy-load model
2
  import os
3
  import gradio as gr
4
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
 
 
5
 
6
# ===== Configuration =====
# Both settings can be overridden through environment variables of the same name.
MODEL_ID = os.environ.get("MODEL_ID", "hasmar03/mt5_id2md")
MAX_LEN = int(os.environ.get("MAX_LEN", "128"))
9
 
10
- # ===== Lazy loader =====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  pipe = None
 
 
12
  def get_pipe():
13
- global pipe
14
  if pipe is None:
15
  tok = AutoTokenizer.from_pretrained(MODEL_ID)
16
  mdl = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
 
 
 
 
 
 
 
 
17
  pipe = pipeline(
18
  "text2text-generation",
19
  model=mdl,
20
  tokenizer=tok,
21
- max_length=MAX_LEN,
22
  )
23
  return pipe
24
 
25
def _build_prompt(text: str, direction: str):
    """Prefix *text* with the task instruction that matches *direction*.

    ``direction`` may be a short code ("id2md" / "md2id") or the matching UI
    label ("Indonesia → Mandar" / "Mandar → Indonesia"). Any other value
    returns *text* unchanged.
    """
    # Adjust these prefixes to match the scheme used during training.
    if direction in ("id2md", "Indonesia → Mandar"):
        return f"translate Indonesian to Mandar: {text}"
    if direction in ("md2id", "Mandar → Indonesia"):
        return f"translate Mandar to Indonesian: {text}"
    return text
32
 
33
def translate_fn(text: str, arah: str):
    """Translate *text* in the direction named by *arah* and return the result.

    Builds the prompt with ``_build_prompt``, runs it through the pipeline
    returned by ``get_pipe`` and returns the ``generated_text`` of the first
    result.
    """
    generator = get_pipe()
    results = generator(_build_prompt(text, arah))
    return results[0]["generated_text"]
38
 
39
  with gr.Blocks(title="Mandar ↔ Indonesia Translator") as demo:
40
  gr.Markdown("### Mandar ↔ Indonesia Translator")
41
- arah = gr.Radio(
42
- ["Indonesia Mandar", "Mandar → Indonesia"],
43
- value="Indonesia → Mandar",
44
- label="Arah",
45
- )
46
  src = gr.Textbox(label="Teks sumber", lines=3, placeholder="Ketik teks…")
47
  btn = gr.Button("Terjemahkan")
48
  out = gr.Textbox(label="Hasil", lines=3)
49
- # api_name membuat REST endpoint: /api/predict/translate
50
- btn.click(translate_fn, inputs=[src, arah], outputs=out, api_name="translate")
51
 
52
- # Antrian (aman untuk Space)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  demo.queue()
54
 
55
- # Opsional: saat run lokal
56
  if __name__ == "__main__":
57
  demo.launch()
 
1
+ # app.py — Gradio dengan decoding yang konsisten seperti di Colab
2
  import os
3
  import gradio as gr
4
+ import torch
5
+ from transformers import (
6
+ AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, GenerationConfig
7
+ )
8
 
 
9
  MODEL_ID = os.getenv("MODEL_ID", "hasmar03/mt5_id2md")
 
10
 
11
# >>>>>> CHANGE THESE TO MATCH YOUR TRAINING DATA <<<<<<
ID2MD_PREFIX = "translate Indonesian to Mandar: "
MD2ID_PREFIX = "translate Mandar to Indonesian: "
# If you trained with other tokens (e.g. "id2md: " / "md2id: " or ">>md<< "),
# replace the strings above so they are 100% identical.

# Default decoding settings (keep them the same as in Colab)
DEFAULT_DECODE = {
    "num_beams": 5,
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "max_new_tokens": 128,
}
25
+
26
  pipe = None
27
+ gen_cfg = None
28
+
29
  def get_pipe():
30
+ global pipe, gen_cfg
31
  if pipe is None:
32
  tok = AutoTokenizer.from_pretrained(MODEL_ID)
33
  mdl = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
34
+
35
+ # Muat generation_config dari repo (jika ada)
36
+ try:
37
+ gen_cfg = GenerationConfig.from_pretrained(MODEL_ID)
38
+ mdl.generation_config = gen_cfg
39
+ except Exception:
40
+ gen_cfg = mdl.generation_config # fallback
41
+
42
  pipe = pipeline(
43
  "text2text-generation",
44
  model=mdl,
45
  tokenizer=tok,
46
+ device=0 if torch.cuda.is_available() else -1,
47
  )
48
  return pipe
49
 
50
def build_prompt(text: str, direction: str):
    """Prepend the training-time task prefix for *direction* to *text*.

    Args:
        text: Source sentence to translate.
        direction: "Indonesia → Mandar" (the label emitted by the direction
            radio) or the short code "id2md" selects the Indonesian → Mandar
            prefix; any other value selects the Mandar → Indonesian prefix.

    Returns:
        The prompt string: ``ID2MD_PREFIX`` or ``MD2ID_PREFIX`` followed by
        *text*.
    """
    # The label must be spelled exactly like the radio choice (arrow
    # included); without the arrow the comparison never matched and every
    # request was sent with the Mandar → Indonesian prefix.
    if direction in ("Indonesia → Mandar", "id2md"):
        return f"{ID2MD_PREFIX}{text}"
    return f"{MD2ID_PREFIX}{text}"
 
 
55
 
56
def translate(text: str, direction: str,
              num_beams: int, max_new_tokens: int,
              no_repeat_ngram_size: int, length_penalty: float,
              do_sample: bool, temperature: float, top_p: float, top_k: int):
    """Translate *text* with the decoding options chosen in the UI.

    Args:
        text: Source text. Empty, whitespace-only or ``None`` input returns
            ``""`` without loading or calling the model.
        direction: Direction label passed on to ``build_prompt``.
        num_beams, max_new_tokens, no_repeat_ngram_size, length_penalty:
            Generation arguments; they override the matching entries of
            ``DEFAULT_DECODE``.
        do_sample: When true, sampling is enabled and ``temperature``,
            ``top_p`` and ``top_k`` are forwarded as well.
        temperature, top_p, top_k: Sampling arguments; ignored unless
            ``do_sample`` is true.

    Returns:
        The ``generated_text`` of the first pipeline result.
    """
    # Nothing to translate: do not send a bare task prefix to the model.
    if not text or not text.strip():
        return ""

    p = get_pipe()
    prompt = build_prompt(text, direction)

    # Build the generate arguments: start from DEFAULT_DECODE, then override
    # with the values coming from the UI.
    gen_args = {
        **DEFAULT_DECODE,
        "num_beams": int(num_beams),
        "max_new_tokens": int(max_new_tokens),
        "no_repeat_ngram_size": int(no_repeat_ngram_size),
        "length_penalty": float(length_penalty),
    }
    if do_sample:
        gen_args.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        gen_args.update(do_sample=False)

    return p(prompt, **gen_args)[0]["generated_text"]
79
 
80
  with gr.Blocks(title="Mandar ↔ Indonesia Translator") as demo:
81
  gr.Markdown("### Mandar ↔ Indonesia Translator")
82
+ with gr.Row():
83
+ direction = gr.Radio(
84
+ ["Indonesia → Mandar", "Mandar → Indonesia"],
85
+ value="Indonesia → Mandar", label="Arah"
86
+ )
87
  src = gr.Textbox(label="Teks sumber", lines=3, placeholder="Ketik teks…")
88
  btn = gr.Button("Terjemahkan")
89
  out = gr.Textbox(label="Hasil", lines=3)
 
 
90
 
91
+ with gr.Accordion("Advanced decoding", open=False):
92
+ num_beams = gr.Slider(1, 10, value=DEFAULT_DECODE["num_beams"], step=1, label="num_beams")
93
+ max_new_tokens = gr.Slider(16, 512, value=DEFAULT_DECODE["max_new_tokens"], step=8, label="max_new_tokens")
94
+ no_repeat_ngram_size = gr.Slider(0, 10, value=DEFAULT_DECODE["no_repeat_ngram_size"], step=1, label="no_repeat_ngram_size")
95
+ length_penalty = gr.Slider(0.0, 2.0, value=DEFAULT_DECODE["length_penalty"], step=0.1, label="length_penalty")
96
+ do_sample = gr.Checkbox(False, label="Sampling (non-deterministic)")
97
+ temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="temperature")
98
+ top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="top_p")
99
+ top_k = gr.Slider(0, 100, value=50, step=5, label="top_k")
100
+
101
+ btn.click(
102
+ translate,
103
+ inputs=[src, direction, num_beams, max_new_tokens, no_repeat_ngram_size,
104
+ length_penalty, do_sample, temperature, top_p, top_k],
105
+ outputs=out,
106
+ api_name="translate"
107
+ )
108
+
109
  demo.queue()
110
 
 
111
  if __name__ == "__main__":
112
  demo.launch()