5AILingouCore commited on
Commit
ad5939d
·
verified ·
1 Parent(s): a317755

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +565 -565
app.py CHANGED
@@ -1,565 +1,565 @@
1
# app.py — Hugging Face Spaces (Gradio) all-in-one translation app
# - Single translate + history table
# - Batch translate (TXT/CSV) + download result CSV
# - Glossary CSV (src,tgt)
# - Model selector (m2m100 / opus-mt / nllb)
# - Safe limits for free CPU Spaces

import os
import io  # NOTE(review): unused in the visible code — kept in case unseen code relies on it
import csv
import time
import json  # NOTE(review): unused in the visible code — kept in case unseen code relies on it
import tempfile
from itertools import islice
from typing import Dict, Optional, List, Tuple, Any

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Silence tokenizers fork-related parallelism warnings in the Spaces worker.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Prefer GPU when present; the free Spaces tier is CPU-only.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# Model registry
# -------------------------
# NOTE:
# - "opus-mt" is fast on CPU (recommended for free tier speed)
# - "m2m100-418M" matches your current project
# - "nllb-600M" can be heavier (quality often good, CPU slower)
#
# Spec fields:
#   kind              - model family tag (informational only)
#   name              - HF checkpoint id per direction ("ja-en" / "en-ja")
#   lang              - tokenizer language codes (None = model does not use them)
#   needs_forced_bos  - multilingual models require forced_bos_token_id at generate()
#   supports_src_lang - tokenizer exposes a settable src_lang attribute
MODEL_SPECS: Dict[str, Dict[str, Any]] = {
    "m2m100-418M (multilingual, your current)": {
        "kind": "m2m100",
        "name": {"ja-en": "facebook/m2m100_418M", "en-ja": "facebook/m2m100_418M"},
        "lang": {"ja": "ja", "en": "en"},
        "needs_forced_bos": True,
        "supports_src_lang": True,
    },
    "opus-mt (fast CPU, ja<->en)": {
        "kind": "opus",
        "name": {"ja-en": "Helsinki-NLP/opus-mt-ja-en", "en-ja": "Helsinki-NLP/opus-mt-en-ja"},
        "lang": {"ja": None, "en": None},
        "needs_forced_bos": False,
        "supports_src_lang": False,
    },
    "nllb-600M (quality, heavier)": {
        "kind": "nllb",
        "name": {"ja-en": "facebook/nllb-200-distilled-600M", "en-ja": "facebook/nllb-200-distilled-600M"},
        "lang": {"ja": "jpn_Jpan", "en": "eng_Latn"},
        "needs_forced_bos": True,
        "supports_src_lang": True,
    },
}

# Cache: (model_key, direction) -> tokenizer / model, so each checkpoint is
# loaded at most once per process.
TOK_CACHE: Dict[Tuple[str, str], Any] = {}
MDL_CACHE: Dict[Tuple[str, str], Any] = {}

# -------------------------
# Safety limits (public space)
# -------------------------
MAX_SINGLE_CHARS = 4000  # single input max chars
MAX_BATCH_LINES = 200  # batch line cap
MAX_BATCH_CHARS_TOTAL = 20000  # batch total chars cap
DEFAULT_MAX_NEW_TOKENS = 256  # fallback generation length cap
67
-
68
- # -------------------------
69
- # Helpers
70
- # -------------------------
71
def detect_direction_by_text(text: str, prefer: str = "ja-en") -> str:
    """Pick a translation direction from the text content.

    If *text* contains any Japanese character (hiragana, katakana, or a CJK
    ideograph), the source language is Japanese -> ``"ja-en"``.  Otherwise the
    text is assumed to be English -> ``"en-ja"``.

    *prefer* (the direction selected in the UI) is kept for interface
    compatibility but no longer flips the result: the previous
    ``return "en-ja" if prefer == "ja-en" else "ja-en"`` returned the
    *opposite* of the preference, so English input with prefer="en-ja"
    was wrongly routed through the ja->en model.
    """
    for ch in text:
        # Hiragana/katakana block or CJK unified ideographs.
        if ("\u3040" <= ch <= "\u30ff") or ("\u4e00" <= ch <= "\u9fff"):
            return "ja-en"
    return "en-ja"
77
-
78
-
79
def read_glossary_csv(path: Optional[str]) -> Optional[List[List[str]]]:
    """Read a glossary CSV of ``src,tgt`` pairs.

    No header row is assumed; rows with fewer than two columns or an empty
    ``src`` cell are skipped.  Returns ``None`` when *path* is falsy or no
    usable rows were found.
    """
    if not path:
        return None
    rows: List[List[str]] = []
    # "utf-8-sig" transparently strips a leading BOM (common in CSVs exported
    # by Excel) so the first src term is not polluted; newline="" is the
    # documented way to open files for the csv module.
    with open(path, "r", encoding="utf-8-sig", newline="") as f:
        for r in csv.reader(f):
            if len(r) < 2:
                continue
            src = (r[0] or "").strip()
            tgt = (r[1] or "").strip()
            if src:
                rows.append([src, tgt])
    return rows or None
92
-
93
-
94
def apply_glossary(text: str, glossary: Optional[List[List[str]]]) -> str:
    """Apply literal src->tgt substitutions to *text*, in glossary order."""
    if not glossary:
        return text
    result = text
    for source_term, target_term in glossary:
        if not source_term:
            continue  # empty patterns would be no-ops; skip defensively
        result = result.replace(source_term, target_term)
    return result
102
-
103
-
104
def gen_kwargs_for_mode(conversation_mode: bool, base_beams: int) -> dict:
    """Build the generation kwargs for the chosen mode.

    Stable defaults for a public CPU Space:
    - Normal mode: deterministic beam search with the requested beam count.
    - Conversation mode: sampling for a more colloquial tone, with the beam
      count capped at 2 to keep latency and memory bounded.
    """
    beams = int(base_beams)
    if not conversation_mode:
        # Deterministic beam search — reproducible output.
        return {
            "do_sample": False,
            "num_beams": beams,
            "repetition_penalty": 1.05,
        }
    # Sampling path: modest temperature/top-p, tiny beam count for stability.
    return {
        "do_sample": True,
        "temperature": 0.75,
        "top_p": 0.85,
        "top_k": 40,
        "num_beams": min(2, max(1, beams)),
        "repetition_penalty": 1.08,
    }
124
-
125
-
126
- def _get_forced_bos_id(tokenizer, lang: str) -> Optional[int]:
127
- # M2M100: get_lang_id
128
- if hasattr(tokenizer, "get_lang_id"):
129
- try:
130
- return tokenizer.get_lang_id(lang)
131
- except Exception:
132
- pass
133
- # NLLB: lang_code_to_id
134
- if hasattr(tokenizer, "lang_code_to_id") and isinstance(getattr(tokenizer, "lang_code_to_id"), dict):
135
- if lang in tokenizer.lang_code_to_id:
136
- return tokenizer.lang_code_to_id[lang]
137
- # Fallback: token id
138
- try:
139
- return tokenizer.convert_tokens_to_ids(lang)
140
- except Exception:
141
- return None
142
-
143
-
144
def _load_model(model_key: str, direction: str):
    """Lazily load and cache the (tokenizer, model) pair for a direction.

    Pairs are keyed by (model_key, direction) in the module-level TOK_CACHE /
    MDL_CACHE dicts, so each checkpoint is downloaded and materialized at
    most once per process.
    """
    key = (model_key, direction)
    if key in TOK_CACHE:
        return TOK_CACHE[key], MDL_CACHE[key]

    checkpoint = MODEL_SPECS[model_key]["name"][direction]
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # fp16 halves GPU memory; CPU inference needs fp32 kernels.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16 if DEVICE.type == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
    )
    model.to(DEVICE).eval()

    TOK_CACHE[key] = tokenizer
    MDL_CACHE[key] = model
    return tokenizer, model
166
-
167
-
168
@torch.inference_mode()
def translate_one(
    model_key: str,
    direction: str,
    text: str,
    max_new_tokens: int,
    num_beams: int,
    conversation: bool,
) -> str:
    """Translate a single string with the selected model and direction.

    Inputs are truncated at 512 tokens.  Multilingual models (m2m100, nllb)
    additionally get their tokenizer src_lang set and a forced_bos_token_id
    so they emit the requested target language.
    """
    tok, mdl = _load_model(model_key, direction)
    spec = MODEL_SPECS[model_key]

    # Language tags per direction; opus-mt specs carry None (not needed).
    codes = spec["lang"]
    if direction == "ja-en":
        src_lang, tgt_lang = codes["ja"], codes["en"]
    else:
        src_lang, tgt_lang = codes["en"], codes["ja"]

    if src_lang and spec.get("supports_src_lang") and hasattr(tok, "src_lang"):
        tok.src_lang = src_lang

    encoded = tok(text, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)

    generate_kwargs = dict(encoded)
    generate_kwargs.update(
        max_new_tokens=int(max_new_tokens),
        no_repeat_ngram_size=3,
        length_penalty=1.05,
    )
    generate_kwargs.update(gen_kwargs_for_mode(bool(conversation), int(num_beams)))

    # Multilingual models must be told which language to emit.
    if spec.get("needs_forced_bos") and tgt_lang:
        bos_id = _get_forced_bos_id(tok, tgt_lang)
        if bos_id is not None:
            generate_kwargs["forced_bos_token_id"] = bos_id

    output_ids = mdl.generate(**generate_kwargs)
    return tok.batch_decode(output_ids, skip_special_tokens=True)[0]
208
-
209
-
210
- def _clamp_int(v: Any, lo: int, hi: int, default: int) -> int:
211
- try:
212
- x = int(v)
213
- return max(lo, min(hi, x))
214
- except Exception:
215
- return default
216
-
217
-
218
- def _history_to_table(history: List[Dict[str, str]]) -> List[List[str]]:
219
- # headers: time, direction, src, dst
220
- rows = []
221
- for item in history[-100:][::-1]: # show latest first, cap 100 rows
222
- rows.append([item["time"], item["direction"], item["src"], item["dst"]])
223
- return rows
224
-
225
-
226
- def _export_history(history: List[Dict[str, str]], fmt: str) -> str:
227
- tmpdir = tempfile.mkdtemp(prefix="history_")
228
- if fmt == "csv":
229
- path = os.path.join(tmpdir, "history.csv")
230
- with open(path, "w", newline="", encoding="utf-8-sig") as f:
231
- w = csv.writer(f)
232
- w.writerow(["time", "direction", "src", "dst"])
233
- for item in history:
234
- w.writerow([item["time"], item["direction"], item["src"], item["dst"]])
235
- return path
236
- else:
237
- path = os.path.join(tmpdir, "history.txt")
238
- with open(path, "w", encoding="utf-8") as f:
239
- for i, item in enumerate(history, 1):
240
- f.write(f"[{i}] {item['time']} | {item['direction']}\n")
241
- f.write(f"SRC: {item['src']}\n")
242
- f.write(f"DST: {item['dst']}\n")
243
- f.write("\n")
244
- return path
245
-
246
-
247
def _read_batch_lines(file_path: str) -> List[str]:
    """Load batch-translation source lines from a TXT or CSV file.

    Accepts:
    - ``.txt``: one item per line (blank lines dropped)
    - ``.csv``: first column is the source text; a first row that looks like
      a header ("src"/"source"/"text"/"input") is skipped

    The result is capped at MAX_BATCH_LINES rows and trimmed to a prefix
    whose total length fits MAX_BATCH_CHARS_TOTAL (public-Space protection).
    """
    lines: List[str] = []

    if (file_path or "").lower().endswith(".csv"):
        # utf-8-sig strips a leading BOM (Excel exports) so the header check
        # and the first source cell are not polluted; newline="" per csv docs.
        with open(file_path, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.reader(f)
            # +5 rows of slack lets blank rows / a header be skipped without
            # sacrificing payload rows.
            for row in islice(reader, MAX_BATCH_LINES + 5):
                if not row:
                    continue
                val = (row[0] or "").strip()
                if not val:
                    continue
                # naive header skip (first non-empty row only)
                if len(lines) == 0 and val.lower() in ("src", "source", "text", "input"):
                    continue
                lines.append(val)
                if len(lines) >= MAX_BATCH_LINES:
                    break
    else:
        # TXT path also uses utf-8-sig so a BOM does not stick to item 1.
        with open(file_path, "r", encoding="utf-8-sig") as f:
            for ln in islice(f, MAX_BATCH_LINES):
                ln = ln.rstrip("\n").strip()
                if ln:
                    lines.append(ln)

    # Enforce the total-character budget by keeping the longest safe prefix.
    total_chars = sum(len(x) for x in lines)
    if total_chars > MAX_BATCH_CHARS_TOTAL:
        kept: List[str] = []
        used = 0
        for s in lines:
            if used + len(s) > MAX_BATCH_CHARS_TOTAL:
                break
            kept.append(s)
            used += len(s)
        lines = kept

    return lines
292
-
293
-
294
- # -------------------------
295
- # Gradio handlers
296
- # -------------------------
297
def warmup(model_key: str) -> str:
    """Pre-load the ja-en model for *model_key* and report the elapsed time.

    Always returns a human-readable status string (this feeds a Markdown
    widget directly), so failures are caught rather than raised.
    """
    started = time.time()
    try:
        _load_model(model_key, "ja-en")
    except Exception as exc:
        return f"❌ Warmup failed: {exc}"
    return f"✅ Warmup OK ({time.time() - started:.2f}s) — model: {model_key}"
305
-
306
-
307
def do_translate(
    text: str,
    model_key: str,
    dir_choice: str,
    auto_on: bool,
    conversation_on: bool,
    glossary_path: Optional[str],
    max_new_tokens: int,
    num_beams: int,
    history: List[Dict[str, str]],
):
    """Single-translation Gradio handler.

    Returns a 6-tuple wired to the outputs:
    (output_text, status_markdown, history_state, history_table_rows,
     csv_export_button_update, txt_export_button_update).
    """
    text = (text or "").strip()
    if not text:
        return ("", "⚠️ テキストを入力してください。", history, _history_to_table(history),
                gr.update(visible=False), gr.update(visible=False))

    if len(text) > MAX_SINGLE_CHARS:
        return ("", f"⚠️ 入力が長すぎます(最大 {MAX_SINGLE_CHARS} 文字)。", history,
                _history_to_table(history), gr.update(visible=False), gr.update(visible=False))

    # Resolve direction: heuristic auto-detect, or the UI radio as-is.
    if auto_on:
        direction = detect_direction_by_text(text, prefer=dir_choice)
    else:
        direction = dir_choice

    prepared = apply_glossary(text, read_glossary_csv(glossary_path))

    # Clamp user-supplied generation knobs into safe ranges.
    tokens_cap = _clamp_int(max_new_tokens, 16, 512, DEFAULT_MAX_NEW_TOKENS)
    beam_count = _clamp_int(num_beams, 1, 6, 4)

    started = time.time()
    try:
        result = translate_one(
            model_key=model_key,
            direction=direction,
            text=prepared,
            max_new_tokens=tokens_cap,
            num_beams=beam_count,
            conversation=bool(conversation_on),
        )
    except Exception as exc:
        has_entries = bool(history)
        return ("", f"❌ 翻訳に失敗しました: {exc}", history, _history_to_table(history),
                gr.update(visible=has_entries), gr.update(visible=has_entries))

    elapsed = time.time() - started
    entry = {
        "time": time.strftime("%Y-%m-%d %H:%M:%S"),
        "direction": direction,
        "src": text,
        "dst": result,
    }
    history = (history or []) + [entry]

    status = f"✅ 完了:{elapsed:.2f}s|model: **{model_key}**|方向:**{direction}**|chars: {len(text)}"
    # Export buttons become visible once at least one entry exists.
    return (result, status, history, _history_to_table(history),
            gr.update(visible=True), gr.update(visible=True))
359
-
360
-
361
def clear_all(history: List[Dict[str, str]]):
    """Reset input, status, history, and the batch widgets to empty.

    Returns a 9-tuple wired to: (src, info, history_state, history_table,
    dl_hist_csv, dl_hist_txt, dl_batch, batch_status, batch_preview).
    """
    return (
        "",
        "🧹 クリアしました。",
        [],  # emptied history state
        [],  # emptied history table
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        "",
        "",
    )
364
-
365
-
366
def export_history_csv(history: List[Dict[str, str]]):
    """Return the path of a CSV dump of *history*, or None when empty."""
    return _export_history(history, "csv") if history else None
370
-
371
-
372
def export_history_txt(history: List[Dict[str, str]]):
    """Return the path of a TXT dump of *history*, or None when empty."""
    return _export_history(history, "txt") if history else None
376
-
377
-
378
def do_batch(
    batch_file_path: Optional[str],
    model_key: str,
    conversation_on: bool,
    glossary_path: Optional[str],
    max_new_tokens: int,
    num_beams: int,
):
    """Batch-translation Gradio handler (generator).

    Yields 4-tuples wired to the outputs:
    (status_markdown, preview_markdown, download_button_update, csv_path).
    Intermediate yields stream progress to the UI; only the final yield
    carries the preview text and the result-CSV path.
    """
    if not batch_file_path:
        yield "⚠️ バッチファイル(TXT/CSV)を選択してください。", "", gr.update(visible=False), None
        return

    lines = _read_batch_lines(batch_file_path)
    total = len(lines)
    if total == 0:
        yield "⚠️ 読み取れる行がありません(空/制限超過の可能性)。", "", gr.update(visible=False), None
        return

    glossary = read_glossary_csv(glossary_path)
    # Clamp user-supplied generation knobs into safe ranges.
    max_new_tokens = _clamp_int(max_new_tokens, 16, 512, DEFAULT_MAX_NEW_TOKENS)
    num_beams = _clamp_int(num_beams, 1, 6, 4)

    t0 = time.time()
    rows: List[Tuple[str, str, str]] = []  # (direction, src, dst)

    yield f"⏳ バッチ翻訳中… 0/{total} (0%)", "", gr.update(visible=False), None

    for i, src in enumerate(lines, 1):
        # Direction is auto-detected per line, so mixed-language files work.
        direction = detect_direction_by_text(src, prefer="ja-en")
        src_processed = apply_glossary(src, glossary)

        try:
            dst = translate_one(
                model_key=model_key,
                direction=direction,
                text=src_processed,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                conversation=bool(conversation_on),
            )
        except Exception as e:
            # A single bad line must not abort the batch; record the error inline.
            dst = f"[ERROR] {e}"

        rows.append((direction, src, dst))

        # Throttle progress updates: first line, every 5th line, and the last.
        if i == 1 or i % 5 == 0 or i == total:
            pct = int(i * 100 / total)
            yield f"⏳ バッチ翻訳中… {i}/{total} ({pct}%)", "", gr.update(visible=False), None

    # Markdown preview, capped at the first 50 rows to keep the page light.
    preview_lines = []
    for idx, (direction, s, d) in enumerate(rows[:50], 1):
        preview_lines.append(f"**{idx}. ({direction})**\n- SRC: {s}\n- DST: {d}\n")
    preview = "\n".join(preview_lines)
    if total > 50:
        preview += f"\n…(プレビューは先頭50行まで。全{total}行はCSVでダウンロード)"

    # Write the full result CSV (utf-8-sig so Excel detects the encoding).
    tmpdir = tempfile.mkdtemp(prefix="batch_")
    out_path = os.path.join(tmpdir, "batch_result.csv")
    with open(out_path, "w", newline="", encoding="utf-8-sig") as f:
        w = csv.writer(f)
        w.writerow(["direction", "src", "dst"])
        for direction, s, d in rows:
            w.writerow([direction, s, d])

    used = time.time() - t0
    done_msg = f"✅ バッチ完了:{used:.2f}s|行数:{total}(最大{MAX_BATCH_LINES}行 / 合計{MAX_BATCH_CHARS_TOTAL}文字まで)"
    yield done_msg, preview, gr.update(visible=True), out_path
447
-
448
-
449
# -------------------------
# UI
# -------------------------
CUSTOM_CSS = """
.gradio-container { max-width: 1100px !important; }
.header-title { font-size: 34px; font-weight: 900; letter-spacing: .4px; margin: 6px 0 4px; }
.subtle { opacity: 0.9; }
.badge { display: inline-block; padding: 2px 10px; border-radius: 999px; border: 1px solid rgba(120,120,120,.35); font-size: 12px; }
"""

with gr.Blocks(title="Linguo Core — Translation Space", css=CUSTOM_CSS) as demo:
    gr.HTML("<div class='header-title'>Linguo Core — Translation</div>")
    gr.Markdown(
        "<span class='badge'>HF Spaces</span> <span class='badge'>Public-safe</span> "
        "<span class='badge'>Glossary CSV</span> <span class='badge'>History</span> <span class='badge'>Batch</span>",
        elem_classes=["subtle"],
    )

    # Per-session history: List[Dict] with keys time/direction/src/dst.
    history_state = gr.State([])

    with gr.Row():
        model_key = gr.Dropdown(
            choices=list(MODEL_SPECS.keys()),
            value="m2m100-418M (multilingual, your current)",
            label="Model(無料CPUなら opus-mt が速い)",
        )
        warm = gr.Button("Warmup(初回ロード)")

    warm_info = gr.Markdown("")

    with gr.Row():
        direction = gr.Radio(["ja-en", "en-ja"], value="ja-en", label="Direction")
        auto = gr.Checkbox(value=True, label="Auto detect (日本語が含まれたら ja-en)")
        conversation = gr.Checkbox(value=False, label="Conversation mode(口語寄せ)")

    info = gr.Markdown("翻訳待機中…")

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            src = gr.Textbox(lines=10, label="Input", placeholder="翻訳したい文章を入力…")
            with gr.Row():
                btn = gr.Button("Translate", variant="primary")
                btn_clear = gr.Button("Clear")
        with gr.Column(scale=1):
            dst = gr.Textbox(lines=10, label="Output", show_copy_button=True)

    with gr.Accordion("Glossary / Advanced / History / Batch", open=False):
        file_gloss = gr.File(label="Glossary CSV(src,tgt)", file_count="single", type="filepath")

        with gr.Row():
            max_len = gr.Slider(16, 512, DEFAULT_MAX_NEW_TOKENS, step=16, label="max_new_tokens")
            beams = gr.Slider(1, 6, 4, step=1, label="num_beams(通常モード向け)")

        gr.Markdown("### History(直近100件表示 / エクスポート可)")
        history_table = gr.Dataframe(
            headers=["time", "direction", "src", "dst"],
            datatype=["str", "str", "str", "str"],
            row_count=0,
            col_count=(4, "fixed"),
            wrap=True,
            interactive=False,
            value=[],
            label="History",
        )
        with gr.Row():
            btn_clear_history = gr.Button("Clear history")
            # Hidden until at least one translation exists (do_translate toggles them).
            dl_hist_csv = gr.DownloadButton("Download history CSV", visible=False)
            dl_hist_txt = gr.DownloadButton("Download history TXT", visible=False)

        gr.Markdown("### Batch(TXT/CSV:1行=1件 / 公開Space保護で最大200行)")
        batch_file = gr.File(label="Batch file (TXT/CSV UTF-8)", file_count="single", type="filepath")
        btn_batch = gr.Button("Run batch translate")
        batch_status = gr.Markdown("")
        batch_preview = gr.Markdown("")
        dl_batch = gr.DownloadButton("Download batch_result.csv", visible=False)

    # Events
    warm.click(warmup, inputs=[model_key], outputs=[warm_info], queue=True)

    # Button click and Enter-submit share the same handler and wiring.
    btn.click(
        do_translate,
        inputs=[src, model_key, direction, auto, conversation, file_gloss, max_len, beams, history_state],
        outputs=[dst, info, history_state, history_table, dl_hist_csv, dl_hist_txt],
        queue=True,
    )
    src.submit(
        do_translate,
        inputs=[src, model_key, direction, auto, conversation, file_gloss, max_len, beams, history_state],
        outputs=[dst, info, history_state, history_table, dl_hist_csv, dl_hist_txt],
        queue=True,
    )

    # "Clear" empties only the input box; history is preserved, so the export
    # buttons stay visible when entries exist.
    btn_clear.click(
        lambda h: ("", "🧹 入力をクリアしました。", h, _history_to_table(h), gr.update(visible=bool(h)), gr.update(visible=bool(h))),
        inputs=[history_state],
        outputs=[src, info, history_state, history_table, dl_hist_csv, dl_hist_txt],
        queue=False,
    )

    btn_clear_history.click(
        clear_all,
        inputs=[history_state],
        outputs=[src, info, history_state, history_table, dl_hist_csv, dl_hist_txt, dl_batch, batch_status, batch_preview],
        queue=False,
    )

    # DownloadButton pattern: clicking generates the file and feeds the path
    # back into the same button.
    dl_hist_csv.click(export_history_csv, inputs=[history_state], outputs=[dl_hist_csv], queue=False)
    dl_hist_txt.click(export_history_txt, inputs=[history_state], outputs=[dl_hist_txt], queue=False)

    # NOTE(review): dl_batch appears twice in outputs — one slot receives the
    # visibility gr.update, the other the file path, matching do_batch's
    # 4-tuple yields. Confirm the installed Gradio version accepts duplicate
    # output components; otherwise merge into one gr.update(value=..., visible=...).
    btn_batch.click(
        do_batch,
        inputs=[batch_file, model_key, conversation, file_gloss, max_len, beams],
        outputs=[batch_status, batch_preview, dl_batch, dl_batch],
        queue=True,
    )

# Single-worker queue: serialize generation so a public CPU Space stays responsive.
demo.queue(max_size=16, default_concurrency_limit=1).launch()
 
1
+ # app.py — Hugging Face Spaces (Gradio) "全部入り" 翻訳アプリ
2
+ # - Single translate + history table
3
+ # - Batch translate (TXT/CSV) + download result CSV
4
+ # - Glossary CSV (src,tgt)
5
+ # - Model selector (m2m100 / opus-mt / nllb)
6
+ # - Safe limits for free CPU Spaces
7
+
8
+ import os
9
+ import io
10
+ import csv
11
+ import time
12
+ import json
13
+ import tempfile
14
+ from itertools import islice
15
+ from typing import Dict, Optional, List, Tuple, Any
16
+
17
+ import gradio as gr
18
+ import torch
19
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
20
+
21
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
22
+
23
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+ # -------------------------
26
+ # Model registry
27
+ # -------------------------
28
+ # NOTE:
29
+ # - "opus-mt" is fast on CPU (recommended for free tier speed)
30
+ # - "m2m100-418M" matches your current project
31
+ # - "nllb-600M" can be heavier (quality often good, CPU slower)
32
+ MODEL_SPECS: Dict[str, Dict[str, Any]] = {
33
+ "m2m100-418M (multilingual, your current)": {
34
+ "kind": "m2m100",
35
+ "name": {"ja-en": "facebook/m2m100_418M", "en-ja": "facebook/m2m100_418M"},
36
+ "lang": {"ja": "ja", "en": "en"},
37
+ "needs_forced_bos": True,
38
+ "supports_src_lang": True,
39
+ },
40
+ "opus-mt (fast CPU, ja<->en)": {
41
+ "kind": "opus",
42
+ "name": {"ja-en": "Helsinki-NLP/opus-mt-ja-en", "en-ja": "Helsinki-NLP/opus-mt-en-ja"},
43
+ "lang": {"ja": None, "en": None},
44
+ "needs_forced_bos": False,
45
+ "supports_src_lang": False,
46
+ },
47
+ "nllb-600M (quality, heavier)": {
48
+ "kind": "nllb",
49
+ "name": {"ja-en": "facebook/nllb-200-distilled-600M", "en-ja": "facebook/nllb-200-distilled-600M"},
50
+ "lang": {"ja": "jpn_Jpan", "en": "eng_Latn"},
51
+ "needs_forced_bos": True,
52
+ "supports_src_lang": True,
53
+ },
54
+ }
55
+
56
+ # Cache: (model_key, direction) -> (tokenizer, model)
57
+ TOK_CACHE: Dict[Tuple[str, str], Any] = {}
58
+ MDL_CACHE: Dict[Tuple[str, str], Any] = {}
59
+
60
+ # -------------------------
61
+ # Safety limits (public space)
62
+ # -------------------------
63
+ MAX_SINGLE_CHARS = 4000 # single input max chars
64
+ MAX_BATCH_LINES = 200 # batch line cap
65
+ MAX_BATCH_CHARS_TOTAL = 20000 # batch total chars cap
66
+ DEFAULT_MAX_NEW_TOKENS = 256
67
+
68
+ # -------------------------
69
+ # Helpers
70
+ # -------------------------
71
+ def detect_direction_by_text(text: str, prefer: str = "ja-en") -> str:
72
+ """Simple heuristic: Japanese char => ja-en else en-ja."""
73
+ for ch in text:
74
+ if ("\u3040" <= ch <= "\u30ff") or ("\u4e00" <= ch <= "\u9fff"):
75
+ return "ja-en"
76
+ return "en-ja" if prefer == "ja-en" else "ja-en"
77
+
78
+
79
+ def read_glossary_csv(path: Optional[str]) -> Optional[List[List[str]]]:
80
+ """Read glossary CSV (src,tgt). UTF-8. No header assumed."""
81
+ if not path:
82
+ return None
83
+ rows: List[List[str]] = []
84
+ with open(path, "r", encoding="utf-8") as f:
85
+ for r in csv.reader(f):
86
+ if len(r) >= 2:
87
+ src = (r[0] or "").strip()
88
+ tgt = (r[1] or "").strip()
89
+ if src:
90
+ rows.append([src, tgt])
91
+ return rows or None
92
+
93
+
94
+ def apply_glossary(text: str, glossary: Optional[List[List[str]]]) -> str:
95
+ if not glossary:
96
+ return text
97
+ out = text
98
+ for src, tgt in glossary:
99
+ if src:
100
+ out = out.replace(src, tgt)
101
+ return out
102
+
103
+
104
+ def gen_kwargs_for_mode(conversation_mode: bool, base_beams: int) -> dict:
105
+ """
106
+ Stable defaults for public CPU:
107
+ - Normal: deterministic beam search
108
+ - Conversation: slightly more colloquial (beam-sampling) but still stable
109
+ """
110
+ if conversation_mode:
111
+ return dict(
112
+ do_sample=True,
113
+ temperature=0.75,
114
+ top_p=0.85,
115
+ top_k=40,
116
+ num_beams=max(1, min(2, int(base_beams))), # keep it small for stability
117
+ repetition_penalty=1.08,
118
+ )
119
+ return dict(
120
+ do_sample=False,
121
+ num_beams=int(base_beams),
122
+ repetition_penalty=1.05,
123
+ )
124
+
125
+
126
+ def _get_forced_bos_id(tokenizer, lang: str) -> Optional[int]:
127
+ # M2M100: get_lang_id
128
+ if hasattr(tokenizer, "get_lang_id"):
129
+ try:
130
+ return tokenizer.get_lang_id(lang)
131
+ except Exception:
132
+ pass
133
+ # NLLB: lang_code_to_id
134
+ if hasattr(tokenizer, "lang_code_to_id") and isinstance(getattr(tokenizer, "lang_code_to_id"), dict):
135
+ if lang in tokenizer.lang_code_to_id:
136
+ return tokenizer.lang_code_to_id[lang]
137
+ # Fallback: token id
138
+ try:
139
+ return tokenizer.convert_tokens_to_ids(lang)
140
+ except Exception:
141
+ return None
142
+
143
+
144
+ def _load_model(model_key: str, direction: str):
145
+ """Lazy load + cache."""
146
+ cache_key = (model_key, direction)
147
+ if cache_key in TOK_CACHE:
148
+ return TOK_CACHE[cache_key], MDL_CACHE[cache_key]
149
+
150
+ spec = MODEL_SPECS[model_key]
151
+ model_name = spec["name"][direction]
152
+
153
+ tok = AutoTokenizer.from_pretrained(model_name)
154
+
155
+ dtype = torch.float16 if DEVICE.type == "cuda" else torch.float32
156
+ mdl = AutoModelForSeq2SeqLM.from_pretrained(
157
+ model_name,
158
+ torch_dtype=dtype,
159
+ low_cpu_mem_usage=True,
160
+ )
161
+ mdl.to(DEVICE).eval()
162
+
163
+ TOK_CACHE[cache_key] = tok
164
+ MDL_CACHE[cache_key] = mdl
165
+ return tok, mdl
166
+
167
+
168
+ @torch.inference_mode()
169
+ def translate_one(
170
+ model_key: str,
171
+ direction: str,
172
+ text: str,
173
+ max_new_tokens: int,
174
+ num_beams: int,
175
+ conversation: bool,
176
+ ) -> str:
177
+ tok, mdl = _load_model(model_key, direction)
178
+ spec = MODEL_SPECS[model_key]
179
+
180
+ # language tags (if supported)
181
+ src_lang = spec["lang"]["ja" if direction == "ja-en" else "en"]
182
+ tgt_lang = spec["lang"]["en" if direction == "ja-en" else "ja"]
183
+
184
+ if spec.get("supports_src_lang") and hasattr(tok, "src_lang") and src_lang:
185
+ tok.src_lang = src_lang
186
+
187
+ inputs = tok(text, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
188
+
189
+ gen_opts = gen_kwargs_for_mode(bool(conversation), int(num_beams))
190
+
191
+ # forced BOS for multilingual models
192
+ forced_id = None
193
+ if spec.get("needs_forced_bos") and tgt_lang:
194
+ forced_id = _get_forced_bos_id(tok, tgt_lang)
195
+
196
+ generate_kwargs = dict(
197
+ **inputs,
198
+ max_new_tokens=int(max_new_tokens),
199
+ no_repeat_ngram_size=3,
200
+ length_penalty=1.05,
201
+ **gen_opts,
202
+ )
203
+ if forced_id is not None:
204
+ generate_kwargs["forced_bos_token_id"] = forced_id
205
+
206
+ out_ids = mdl.generate(**generate_kwargs)
207
+ return tok.batch_decode(out_ids, skip_special_tokens=True)[0]
208
+
209
+
210
+ def _clamp_int(v: Any, lo: int, hi: int, default: int) -> int:
211
+ try:
212
+ x = int(v)
213
+ return max(lo, min(hi, x))
214
+ except Exception:
215
+ return default
216
+
217
+
218
+ def _history_to_table(history: List[Dict[str, str]]) -> List[List[str]]:
219
+ # headers: time, direction, src, dst
220
+ rows = []
221
+ for item in history[-100:][::-1]: # show latest first, cap 100 rows
222
+ rows.append([item["time"], item["direction"], item["src"], item["dst"]])
223
+ return rows
224
+
225
+
226
+ def _export_history(history: List[Dict[str, str]], fmt: str) -> str:
227
+ tmpdir = tempfile.mkdtemp(prefix="history_")
228
+ if fmt == "csv":
229
+ path = os.path.join(tmpdir, "history.csv")
230
+ with open(path, "w", newline="", encoding="utf-8-sig") as f:
231
+ w = csv.writer(f)
232
+ w.writerow(["time", "direction", "src", "dst"])
233
+ for item in history:
234
+ w.writerow([item["time"], item["direction"], item["src"], item["dst"]])
235
+ return path
236
+ else:
237
+ path = os.path.join(tmpdir, "history.txt")
238
+ with open(path, "w", encoding="utf-8") as f:
239
+ for i, item in enumerate(history, 1):
240
+ f.write(f"[{i}] {item['time']} | {item['direction']}\n")
241
+ f.write(f"SRC: {item['src']}\n")
242
+ f.write(f"DST: {item['dst']}\n")
243
+ f.write("\n")
244
+ return path
245
+
246
+
247
+ def _read_batch_lines(file_path: str) -> List[str]:
248
+ """
249
+ Accept:
250
+ - .txt: 1 line = 1 item
251
+ - .csv: use first column as src (ignores header if it looks like header)
252
+ """
253
+ lower = (file_path or "").lower()
254
+ lines: List[str] = []
255
+
256
+ if lower.endswith(".csv"):
257
+ with open(file_path, "r", encoding="utf-8") as f:
258
+ r = csv.reader(f)
259
+ for row in islice(r, MAX_BATCH_LINES + 5):
260
+ if not row:
261
+ continue
262
+ val = (row[0] or "").strip()
263
+ if not val:
264
+ continue
265
+ # naive header skip
266
+ if len(lines) == 0 and val.lower() in ("src", "source", "text", "input"):
267
+ continue
268
+ lines.append(val)
269
+ if len(lines) >= MAX_BATCH_LINES:
270
+ break
271
+ else:
272
+ with open(file_path, "r", encoding="utf-8") as f:
273
+ for ln in islice(f, MAX_BATCH_LINES):
274
+ ln = ln.rstrip("\n").strip()
275
+ if ln:
276
+ lines.append(ln)
277
+
278
+ # total chars guard
279
+ total_chars = sum(len(x) for x in lines)
280
+ if total_chars > MAX_BATCH_CHARS_TOTAL:
281
+ # shrink until safe
282
+ kept = []
283
+ c = 0
284
+ for s in lines:
285
+ if c + len(s) > MAX_BATCH_CHARS_TOTAL:
286
+ break
287
+ kept.append(s)
288
+ c += len(s)
289
+ lines = kept
290
+
291
+ return lines
292
+
293
+
294
+ # -------------------------
295
+ # Gradio handlers
296
+ # -------------------------
297
+ def warmup(model_key: str) -> str:
298
+ t0 = time.time()
299
+ try:
300
+ _load_model(model_key, "ja-en")
301
+ used = time.time() - t0
302
+ return f"✅ Warmup OK ({used:.2f}s) — model: {model_key}"
303
+ except Exception as e:
304
+ return f"❌ Warmup failed: {e}"
305
+
306
+
307
+ def do_translate(
308
+ text: str,
309
+ model_key: str,
310
+ dir_choice: str,
311
+ auto_on: bool,
312
+ conversation_on: bool,
313
+ glossary_path: Optional[str],
314
+ max_new_tokens: int,
315
+ num_beams: int,
316
+ history: List[Dict[str, str]],
317
+ ):
318
+ text = (text or "").strip()
319
+ if not text:
320
+ return "", "⚠️ テキストを入力してください。", history, _history_to_table(history), gr.update(visible=False), gr.update(visible=False)
321
+
322
+ if len(text) > MAX_SINGLE_CHARS:
323
+ return "", f"⚠️ 入力が長すぎます(最大 {MAX_SINGLE_CHARS} 文字)。", history, _history_to_table(history), gr.update(visible=False), gr.update(visible=False)
324
+
325
+ direction = detect_direction_by_text(text, prefer=dir_choice) if auto_on else dir_choice
326
+ glossary = read_glossary_csv(glossary_path)
327
+ src_processed = apply_glossary(text, glossary)
328
+
329
+ max_new_tokens = _clamp_int(max_new_tokens, 16, 512, DEFAULT_MAX_NEW_TOKENS)
330
+ num_beams = _clamp_int(num_beams, 1, 6, 4)
331
+
332
+ t0 = time.time()
333
+ try:
334
+ out = translate_one(
335
+ model_key=model_key,
336
+ direction=direction,
337
+ text=src_processed,
338
+ max_new_tokens=max_new_tokens,
339
+ num_beams=num_beams,
340
+ conversation=bool(conversation_on),
341
+ )
342
+ used = time.time() - t0
343
+
344
+ item = {
345
+ "time": time.strftime("%Y-%m-%d %H:%M:%S"),
346
+ "direction": direction,
347
+ "src": text,
348
+ "dst": out,
349
+ }
350
+ history = (history or []) + [item]
351
+ table = _history_to_table(history)
352
+
353
+ info = f"✅ 完了:{used:.2f}s|model: **{model_key}**|方向:**{direction}**|chars: {len(text)}"
354
+ # show export buttons when history exists
355
+ return out, info, history, table, gr.update(visible=True), gr.update(visible=True)
356
+ except Exception as e:
357
+ info = f"❌ 翻訳に失敗しました: {e}"
358
+ return "", info, history, _history_to_table(history), gr.update(visible=bool(history)), gr.update(visible=bool(history))
359
+
360
+
361
+ def clear_all(history: List[Dict[str, str]]):
362
+ history = []
363
+ return "", "🧹 クリアしました。", history, [], gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), "", ""
364
+
365
+
366
+ def export_history_csv(history: List[Dict[str, str]]):
367
+ if not history:
368
+ return None
369
+ return _export_history(history, "csv")
370
+
371
+
372
+ def export_history_txt(history: List[Dict[str, str]]):
373
+ if not history:
374
+ return None
375
+ return _export_history(history, "txt")
376
+
377
+
378
def do_batch(
    batch_file_path: Optional[str],
    model_key: str,
    conversation_on: bool,
    glossary_path: Optional[str],
    max_new_tokens: int,
    num_beams: int,
):
    """Generator event handler for batch translation of an uploaded TXT/CSV file.

    Each ``yield`` is a 4-tuple matching the event's outputs:
    (status markdown, preview markdown, download-button update, result CSV path).
    Intermediate yields stream progress to the UI; the final yield reveals the
    download button together with the written result CSV.
    """
    # Guard: no file selected — show a warning and stop the generator.
    if not batch_file_path:
        yield "⚠️ バッチファイル(TXT/CSV)を選択してください。", "", gr.update(visible=False), None
        return

    lines = _read_batch_lines(batch_file_path)
    total = len(lines)
    # Zero readable lines: either an empty file or the helper dropped everything.
    # NOTE(review): _read_batch_lines appears to enforce the batch size limits
    # (the warning text mentions 制限超過) — confirm against its definition.
    if total == 0:
        yield "⚠️ 読み取れる行がありません(空/制限超過の可能性)。", "", gr.update(visible=False), None
        return

    glossary = read_glossary_csv(glossary_path)
    # Clamp generation parameters into CPU-safe ranges (bad/missing UI values fall back).
    max_new_tokens = _clamp_int(max_new_tokens, 16, 512, DEFAULT_MAX_NEW_TOKENS)
    num_beams = _clamp_int(num_beams, 1, 6, 4)

    t0 = time.time()
    rows: List[Tuple[str, str, str]] = []  # (direction, src, dst)

    # Initial progress tick before any model work starts.
    yield f"⏳ バッチ翻訳中… 0/{total} (0%)", "", gr.update(visible=False), None

    for i, src in enumerate(lines, 1):
        # Direction is auto-detected per line, defaulting to ja-en on ambiguity.
        direction = detect_direction_by_text(src, prefer="ja-en")
        src_processed = apply_glossary(src, glossary)

        try:
            dst = translate_one(
                model_key=model_key,
                direction=direction,
                text=src_processed,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                conversation=bool(conversation_on),
            )
        except Exception as e:
            # A failing line is recorded inline rather than aborting the whole batch.
            dst = f"[ERROR] {e}"

        rows.append((direction, src, dst))

        # Throttle UI progress updates: first line, every 5th line, and the last.
        if i == 1 or i % 5 == 0 or i == total:
            pct = int(i * 100 / total)
            yield f"⏳ バッチ翻訳中… {i}/{total} ({pct}%)", "", gr.update(visible=False), None

    # Preview (limit): only the first 50 rows are rendered as markdown.
    preview_lines = []
    for idx, (direction, s, d) in enumerate(rows[:50], 1):
        preview_lines.append(f"**{idx}. ({direction})**\n- SRC: {s}\n- DST: {d}\n")
    preview = "\n".join(preview_lines)
    if total > 50:
        preview += f"\n…(プレビューは先頭50行まで。全{total}行はCSVでダウンロード)"

    # Write result CSV into a fresh temp dir so concurrent runs don't collide.
    tmpdir = tempfile.mkdtemp(prefix="batch_")
    out_path = os.path.join(tmpdir, "batch_result.csv")
    # utf-8-sig adds a BOM so Excel opens the CSV as UTF-8 by default.
    with open(out_path, "w", newline="", encoding="utf-8-sig") as f:
        w = csv.writer(f)
        w.writerow(["direction", "src", "dst"])
        for direction, s, d in rows:
            w.writerow([direction, s, d])

    used = time.time() - t0
    done_msg = f"✅ バッチ完了:{used:.2f}s|行数:{total}(最大{MAX_BATCH_LINES}行 / 合計{MAX_BATCH_CHARS_TOTAL}文字まで)"
    # Final yield: reveal the download button and hand Gradio the file path.
    yield done_msg, preview, gr.update(visible=True), out_path
447
+
448
+
449
# -------------------------
# UI
# -------------------------
# Injected via gr.HTML below; widens the container and styles the header badges.
CUSTOM_CSS = """
.gradio-container { max-width: 1100px !important; }
.header-title { font-size: 34px; font-weight: 900; letter-spacing: .4px; margin: 6px 0 4px; }
.subtle { opacity: 0.9; }
.badge { display: inline-block; padding: 2px 10px; border-radius: 999px; border: 1px solid rgba(120,120,120,.35); font-size: 12px; }
"""

with gr.Blocks(title="Linguo Core — Translation Space") as demo:
    gr.HTML(f"<style>{CUSTOM_CSS}</style>")
    # Feature badges rendered as raw HTML spans styled by CUSTOM_CSS.
    gr.Markdown(
        "<span class='badge'>HF Spaces</span> <span class='badge'>Public-safe</span> "
        "<span class='badge'>Glossary CSV</span> <span class='badge'>History</span> <span class='badge'>Batch</span>",
        elem_classes=["subtle"],
    )

    # Per-session translation history: List[Dict] with time/direction/src/dst keys.
    history_state = gr.State([])  # List[Dict]

    with gr.Row():
        model_key = gr.Dropdown(
            choices=list(MODEL_SPECS.keys()),
            value="m2m100-418M (multilingual, your current)",
            label="Model(無料CPUなら opus-mt が速い)",
        )
        # Explicit warmup button so the first model load doesn't stall a translate click.
        warm = gr.Button("Warmup(初回ロード)")

    warm_info = gr.Markdown("")

    with gr.Row():
        direction = gr.Radio(["ja-en", "en-ja"], value="ja-en", label="Direction")
        auto = gr.Checkbox(value=True, label="Auto detect (日本語が含まれたら ja-en)")
        conversation = gr.Checkbox(value=False, label="Conversation mode(口語寄せ)")

    info = gr.Markdown("翻訳待機中…")

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            src = gr.Textbox(lines=10, label="Input", placeholder="翻訳したい文章を入力…")
            with gr.Row():
                btn = gr.Button("Translate", variant="primary")
                btn_clear = gr.Button("Clear")
        with gr.Column(scale=1):
            dst = gr.Textbox(lines=10, label="Output", show_copy_button=True)

    with gr.Accordion("Glossary / Advanced / History / Batch", open=False):
        # Glossary CSV with src,tgt columns, shared by single and batch translation.
        file_gloss = gr.File(label="Glossary CSV(src,tgt)", file_count="single", type="filepath")

        with gr.Row():
            max_len = gr.Slider(16, 512, DEFAULT_MAX_NEW_TOKENS, step=16, label="max_new_tokens")
            beams = gr.Slider(1, 6, 4, step=1, label="num_beams(通常モード向け)")

        gr.Markdown("### History(直近100件表示 / エクスポート可)")
        history_table = gr.Dataframe(
            headers=["time", "direction", "src", "dst"],
            datatype=["str", "str", "str", "str"],
            row_count=0,
            col_count=(4, "fixed"),
            wrap=True,
            interactive=False,
            value=[],
            label="History",
        )
        with gr.Row():
            btn_clear_history = gr.Button("Clear history")
            # Download buttons stay hidden until there is history to export.
            dl_hist_csv = gr.DownloadButton("Download history CSV", visible=False)
            dl_hist_txt = gr.DownloadButton("Download history TXT", visible=False)

        gr.Markdown("### Batch(TXT/CSV:1行=1件 / 公開Space保護で最大200行)")
        batch_file = gr.File(label="Batch file (TXT/CSV UTF-8)", file_count="single", type="filepath")
        btn_batch = gr.Button("Run batch translate")
        batch_status = gr.Markdown("")
        batch_preview = gr.Markdown("")
        dl_batch = gr.DownloadButton("Download batch_result.csv", visible=False)

    # Events
    warm.click(warmup, inputs=[model_key], outputs=[warm_info], queue=True)

    # Translate fires from both the button and Enter in the textbox,
    # with identical wiring so behavior matches either way.
    btn.click(
        do_translate,
        inputs=[src, model_key, direction, auto, conversation, file_gloss, max_len, beams, history_state],
        outputs=[dst, info, history_state, history_table, dl_hist_csv, dl_hist_txt],
        queue=True,
    )
    src.submit(
        do_translate,
        inputs=[src, model_key, direction, auto, conversation, file_gloss, max_len, beams, history_state],
        outputs=[dst, info, history_state, history_table, dl_hist_csv, dl_hist_txt],
        queue=True,
    )

    # "Clear" empties only the input; history is preserved, and the export
    # buttons stay visible iff there is history.
    btn_clear.click(
        lambda h: ("", "🧹 入力をクリアしました。", h, _history_to_table(h), gr.update(visible=bool(h)), gr.update(visible=bool(h))),
        inputs=[history_state],
        outputs=[src, info, history_state, history_table, dl_hist_csv, dl_hist_txt],
        queue=False,
    )

    btn_clear_history.click(
        clear_all,
        inputs=[history_state],
        outputs=[src, info, history_state, history_table, dl_hist_csv, dl_hist_txt, dl_batch, batch_status, batch_preview],
        queue=False,
    )

    # DownloadButton pattern: the click handler returns a file path back into
    # the same button, which sets the file it serves.
    dl_hist_csv.click(export_history_csv, inputs=[history_state], outputs=[dl_hist_csv], queue=False)
    dl_hist_txt.click(export_history_txt, inputs=[history_state], outputs=[dl_hist_txt], queue=False)

    # NOTE(review): dl_batch appears twice in outputs — do_batch's yields send
    # the visibility update in slot 3 and the file path in slot 4. Recent
    # Gradio versions may reject duplicate output components; a single
    # gr.update(visible=True, value=out_path) targeting dl_batch once would be
    # safer. TODO confirm against the pinned gradio version.
    btn_batch.click(
        do_batch,
        inputs=[batch_file, model_key, conversation, file_gloss, max_len, beams],
        outputs=[batch_status, batch_preview, dl_batch, dl_batch],
        queue=True,
    )

# Single-worker queue: free-tier CPU Space, so translations run one at a time.
demo.queue(max_size=16, default_concurrency_limit=1).launch()