Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,15 +1,18 @@
|
|
| 1 |
-
import os, torch
|
| 2 |
import gradio as gr
|
| 3 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
# --------- Config (override via Space Variables if you like) ----------
|
| 7 |
TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
|
| 8 |
MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
TOKENIZER_REV = os.getenv("TOKENIZER_REV", None) # e.g., "b1a2c3d"
|
| 12 |
-
MODEL_REV = os.getenv("MODEL_REV", None) # e.g., "e4f5a6b"
|
| 13 |
|
| 14 |
SRC_CODE = "eng_Latn"
|
| 15 |
HI_CODE = "hin_Deva"
|
|
@@ -20,27 +23,60 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
| 20 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 21 |
|
| 22 |
tok_kwargs = dict(trust_remote_code=True, use_fast=True)
|
| 23 |
-
if TOKENIZER_REV:
|
|
|
|
| 24 |
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, **tok_kwargs)
|
| 25 |
|
| 26 |
mdl_kwargs = dict(
|
| 27 |
trust_remote_code=True,
|
| 28 |
attn_implementation="eager",
|
| 29 |
low_cpu_mem_usage=True,
|
| 30 |
-
dtype=dtype,
|
| 31 |
)
|
| 32 |
-
if MODEL_REV:
|
|
|
|
|
|
|
| 33 |
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device)
|
| 34 |
model.eval()
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
ip = IndicProcessor(inference=True)
|
| 37 |
|
| 38 |
# -------------------- Inference helpers -------------------------------
|
| 39 |
@torch.inference_mode()
|
| 40 |
def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens: int,
|
| 41 |
temperature: float, top_p: float, top_k: int):
|
|
|
|
| 42 |
batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)
|
| 43 |
|
|
|
|
| 44 |
enc = tokenizer(
|
| 45 |
batch,
|
| 46 |
max_length=256,
|
|
@@ -50,8 +86,10 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
|
|
| 50 |
return_attention_mask=True,
|
| 51 |
).to(device)
|
| 52 |
|
|
|
|
| 53 |
do_sample = (temperature is not None) and (float(temperature) > 0)
|
| 54 |
|
|
|
|
| 55 |
outputs = model.generate(
|
| 56 |
**enc,
|
| 57 |
max_new_tokens=int(max_new_tokens),
|
|
@@ -62,16 +100,15 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
|
|
| 62 |
top_k=int(top_k) if do_sample else None,
|
| 63 |
use_cache=True,
|
| 64 |
early_stopping=False,
|
| 65 |
-
pad_token_id=
|
| 66 |
)
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
clean_up_tokenization_spaces=True,
|
| 73 |
-
)
|
| 74 |
|
|
|
|
| 75 |
final = ip.postprocess_batch(decoded, lang=tgt_code)
|
| 76 |
return final[0].strip()
|
| 77 |
|
|
@@ -79,15 +116,23 @@ def translate_dual(text, num_beams, max_new_tokens, temperature, top_p, top_k):
|
|
| 79 |
text = (text or "").strip()
|
| 80 |
if not text:
|
| 81 |
return "", ""
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
return hi, te
|
| 85 |
|
| 86 |
# -------------------- UI (professional, clean) ------------------------
|
| 87 |
-
THEME = gr.themes.Soft(
|
| 88 |
-
primary_hue="blue",
|
| 89 |
-
neutral_hue="slate",
|
| 90 |
-
).set(
|
| 91 |
body_background_fill="#0b1220",
|
| 92 |
body_text_color_subdued="#cbd5e1",
|
| 93 |
block_background_fill="#0f172a",
|
|
@@ -154,5 +199,5 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN鈫扝I / EN鈫扵E Translator"
|
|
| 154 |
|
| 155 |
gr.Markdown('<div class="footer">Model: law-ai/InLegalTrans-En2Indic-1B 路 Tokenizer: ai4bharat/indictrans2-en-indic-1B</div>')
|
| 156 |
|
| 157 |
-
#
|
| 158 |
demo.queue(max_size=48).launch()
|
|
|
|
| 1 |
+
import os, traceback, types, torch
|
| 2 |
import gradio as gr
|
| 3 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 4 |
+
|
| 5 |
+
# Robust import for IndicProcessor (fallback path)
|
| 6 |
+
try:
|
| 7 |
+
from IndicTransToolkit import IndicProcessor # preferred
|
| 8 |
+
except Exception:
|
| 9 |
+
from IndicTransToolkit.IndicTransToolkit import IndicProcessor # fallback
|
| 10 |
|
| 11 |
# --------- Config (override via Space Variables if you like) ----------
|
| 12 |
TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
|
| 13 |
MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
|
| 14 |
+
TOKENIZER_REV = os.getenv("TOKENIZER_REV", None) # optional pin
|
| 15 |
+
MODEL_REV = os.getenv("MODEL_REV", None) # optional pin
|
|
|
|
|
|
|
| 16 |
|
| 17 |
SRC_CODE = "eng_Latn"
|
| 18 |
HI_CODE = "hin_Deva"
|
|
|
|
| 23 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 24 |
|
| 25 |
tok_kwargs = dict(trust_remote_code=True, use_fast=True)
|
| 26 |
+
if TOKENIZER_REV:
|
| 27 |
+
tok_kwargs["revision"] = TOKENIZER_REV
|
| 28 |
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, **tok_kwargs)
|
| 29 |
|
| 30 |
mdl_kwargs = dict(
|
| 31 |
trust_remote_code=True,
|
| 32 |
attn_implementation="eager",
|
| 33 |
low_cpu_mem_usage=True,
|
| 34 |
+
dtype=dtype, # modern kw (no deprecation warning)
|
| 35 |
)
|
| 36 |
+
if MODEL_REV:
|
| 37 |
+
mdl_kwargs["revision"] = MODEL_REV
|
| 38 |
+
|
| 39 |
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device)
|
| 40 |
model.eval()
|
| 41 |
|
| 42 |
+
# Patch generation config safety
|
| 43 |
+
if getattr(model.generation_config, "pad_token_id", None) is None:
|
| 44 |
+
model.generation_config.pad_token_id = (
|
| 45 |
+
getattr(tokenizer, "pad_token_id", None) or getattr(tokenizer, "eos_token_id", 0)
|
| 46 |
+
)
|
| 47 |
+
if getattr(model.generation_config, "eos_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
|
| 48 |
+
model.generation_config.eos_token_id = tokenizer.eos_token_id
|
| 49 |
+
|
| 50 |
+
# ---- Runtime compatibility patch for newer Transformers beam search ----
|
| 51 |
+
# Newer versions call self.config.get_text_config().vocab_size.
|
| 52 |
+
# Some custom configs (IndicTransConfig) don't define these.
|
| 53 |
+
if not hasattr(model.config, "vocab_size") or model.config.vocab_size is None:
|
| 54 |
+
try:
|
| 55 |
+
model.config.vocab_size = getattr(tokenizer, "vocab_size", None) or len(tokenizer)
|
| 56 |
+
except Exception:
|
| 57 |
+
# Fallback to a safe default if tokenizer doesn't expose size
|
| 58 |
+
model.config.vocab_size = 64000
|
| 59 |
+
if not hasattr(model.config, "get_text_config") or not callable(getattr(model.config, "get_text_config", None)):
|
| 60 |
+
def _get_text_config(self):
|
| 61 |
+
return self
|
| 62 |
+
model.config.get_text_config = types.MethodType(_get_text_config, model.config)
|
| 63 |
+
|
| 64 |
+
# Mirror into generation_config as well (some codepaths read there)
|
| 65 |
+
try:
|
| 66 |
+
model.generation_config.vocab_size = model.config.vocab_size
|
| 67 |
+
except Exception:
|
| 68 |
+
pass
|
| 69 |
+
|
| 70 |
ip = IndicProcessor(inference=True)
|
| 71 |
|
| 72 |
# -------------------- Inference helpers -------------------------------
|
| 73 |
@torch.inference_mode()
|
| 74 |
def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens: int,
|
| 75 |
temperature: float, top_p: float, top_k: int):
|
| 76 |
+
# Preprocess via IndicTransToolkit
|
| 77 |
batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)
|
| 78 |
|
| 79 |
+
# Tokenize
|
| 80 |
enc = tokenizer(
|
| 81 |
batch,
|
| 82 |
max_length=256,
|
|
|
|
| 86 |
return_attention_mask=True,
|
| 87 |
).to(device)
|
| 88 |
|
| 89 |
+
# Sampling toggles
|
| 90 |
do_sample = (temperature is not None) and (float(temperature) > 0)
|
| 91 |
|
| 92 |
+
# Generate
|
| 93 |
outputs = model.generate(
|
| 94 |
**enc,
|
| 95 |
max_new_tokens=int(max_new_tokens),
|
|
|
|
| 100 |
top_k=int(top_k) if do_sample else None,
|
| 101 |
use_cache=True,
|
| 102 |
early_stopping=False,
|
| 103 |
+
pad_token_id=model.generation_config.pad_token_id,
|
| 104 |
)
|
| 105 |
|
| 106 |
+
# Decode
|
| 107 |
+
decoded = tokenizer.batch_decode(
|
| 108 |
+
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
| 109 |
+
)
|
|
|
|
|
|
|
| 110 |
|
| 111 |
+
# Postprocess back to target script
|
| 112 |
final = ip.postprocess_batch(decoded, lang=tgt_code)
|
| 113 |
return final[0].strip()
|
| 114 |
|
|
|
|
| 116 |
text = (text or "").strip()
|
| 117 |
if not text:
|
| 118 |
return "", ""
|
| 119 |
+
|
| 120 |
+
try:
|
| 121 |
+
hi = _translate_to_lang(text, HI_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
|
| 122 |
+
except Exception as e:
|
| 123 |
+
print("HI ERROR:\n", traceback.format_exc())
|
| 124 |
+
hi = f"鈿狅笍 Hindi translation failed: {type(e).__name__}: {str(e).splitlines()[-1]}"
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
te = _translate_to_lang(text, TE_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
|
| 128 |
+
except Exception as e:
|
| 129 |
+
print("TE ERROR:\n", traceback.format_exc())
|
| 130 |
+
te = f"鈿狅笍 Telugu translation failed: {type(e).__name__}: {str(e).splitlines()[-1]}"
|
| 131 |
+
|
| 132 |
return hi, te
|
| 133 |
|
| 134 |
# -------------------- UI (professional, clean) ------------------------
|
| 135 |
+
THEME = gr.themes.Soft(primary_hue="blue", neutral_hue="slate").set(
|
|
|
|
|
|
|
|
|
|
| 136 |
body_background_fill="#0b1220",
|
| 137 |
body_text_color_subdued="#cbd5e1",
|
| 138 |
block_background_fill="#0f172a",
|
|
|
|
| 199 |
|
| 200 |
gr.Markdown('<div class="footer">Model: law-ai/InLegalTrans-En2Indic-1B 路 Tokenizer: ai4bharat/indictrans2-en-indic-1B</div>')
|
| 201 |
|
| 202 |
+
# Keep queue to enable buffering; omit unsupported args on older Gradio
|
| 203 |
demo.queue(max_size=48).launch()
|