SagarVelamuri commited on
Commit
c6914e7
verified
1 Parent(s): 9e61cae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -23
app.py CHANGED
@@ -1,15 +1,18 @@
1
- import os, torch
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
- from IndicTransToolkit import IndicProcessor # https://github.com/VarunGumma/IndicTransToolkit
 
 
 
 
 
5
 
6
  # --------- Config (override via Space Variables if you like) ----------
7
  TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
8
  MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
9
-
10
- # (Optional) pin revisions to avoid surprise upstream changes
11
- TOKENIZER_REV = os.getenv("TOKENIZER_REV", None) # e.g., "b1a2c3d"
12
- MODEL_REV = os.getenv("MODEL_REV", None) # e.g., "e4f5a6b"
13
 
14
  SRC_CODE = "eng_Latn"
15
  HI_CODE = "hin_Deva"
@@ -20,27 +23,60 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
21
 
22
  tok_kwargs = dict(trust_remote_code=True, use_fast=True)
23
- if TOKENIZER_REV: tok_kwargs["revision"] = TOKENIZER_REV
 
24
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, **tok_kwargs)
25
 
26
  mdl_kwargs = dict(
27
  trust_remote_code=True,
28
  attn_implementation="eager",
29
  low_cpu_mem_usage=True,
30
- dtype=dtype, # <- fixes the torch_dtype deprecation warning
31
  )
32
- if MODEL_REV: mdl_kwargs["revision"] = MODEL_REV
 
 
33
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device)
34
  model.eval()
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  ip = IndicProcessor(inference=True)
37
 
38
  # -------------------- Inference helpers -------------------------------
39
  @torch.inference_mode()
40
  def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens: int,
41
  temperature: float, top_p: float, top_k: int):
 
42
  batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)
43
 
 
44
  enc = tokenizer(
45
  batch,
46
  max_length=256,
@@ -50,8 +86,10 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
50
  return_attention_mask=True,
51
  ).to(device)
52
 
 
53
  do_sample = (temperature is not None) and (float(temperature) > 0)
54
 
 
55
  outputs = model.generate(
56
  **enc,
57
  max_new_tokens=int(max_new_tokens),
@@ -62,16 +100,15 @@ def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens:
62
  top_k=int(top_k) if do_sample else None,
63
  use_cache=True,
64
  early_stopping=False,
65
- pad_token_id=tokenizer.pad_token_id or 0,
66
  )
67
 
68
- with tokenizer.as_target_tokenizer():
69
- decoded = tokenizer.batch_decode(
70
- outputs.detach().cpu().tolist(),
71
- skip_special_tokens=True,
72
- clean_up_tokenization_spaces=True,
73
- )
74
 
 
75
  final = ip.postprocess_batch(decoded, lang=tgt_code)
76
  return final[0].strip()
77
 
@@ -79,15 +116,23 @@ def translate_dual(text, num_beams, max_new_tokens, temperature, top_p, top_k):
79
  text = (text or "").strip()
80
  if not text:
81
  return "", ""
82
- hi = _translate_to_lang(text, HI_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
83
- te = _translate_to_lang(text, TE_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
 
 
 
 
 
 
 
 
 
 
 
84
  return hi, te
85
 
86
  # -------------------- UI (professional, clean) ------------------------
87
- THEME = gr.themes.Soft(
88
- primary_hue="blue",
89
- neutral_hue="slate",
90
- ).set(
91
  body_background_fill="#0b1220",
92
  body_text_color_subdued="#cbd5e1",
93
  block_background_fill="#0f172a",
@@ -154,5 +199,5 @@ with gr.Blocks(theme=THEME, css=CUSTOM_CSS, title="EN鈫扝I / EN鈫扵E Translator"
154
 
155
  gr.Markdown('<div class="footer">Model: law-ai/InLegalTrans-En2Indic-1B 路 Tokenizer: ai4bharat/indictrans2-en-indic-1B</div>')
156
 
157
- # IMPORTANT: remove unsupported arg; keep queue to enable request buffering
158
  demo.queue(max_size=48).launch()
 
1
+ import os, traceback, types, torch
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+ # Robust import for IndicProcessor (fallback path)
6
+ try:
7
+ from IndicTransToolkit import IndicProcessor # preferred
8
+ except Exception:
9
+ from IndicTransToolkit.IndicTransToolkit import IndicProcessor # fallback
10
 
11
  # --------- Config (override via Space Variables if you like) ----------
12
  TOKENIZER_ID = os.getenv("TOKENIZER_ID", "ai4bharat/indictrans2-en-indic-1B")
13
  MODEL_ID = os.getenv("MODEL_ID", "law-ai/InLegalTrans-En2Indic-1B")
14
+ TOKENIZER_REV = os.getenv("TOKENIZER_REV", None) # optional pin
15
+ MODEL_REV = os.getenv("MODEL_REV", None) # optional pin
 
 
16
 
17
  SRC_CODE = "eng_Latn"
18
  HI_CODE = "hin_Deva"
 
23
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
24
 
25
  tok_kwargs = dict(trust_remote_code=True, use_fast=True)
26
+ if TOKENIZER_REV:
27
+ tok_kwargs["revision"] = TOKENIZER_REV
28
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, **tok_kwargs)
29
 
30
  mdl_kwargs = dict(
31
  trust_remote_code=True,
32
  attn_implementation="eager",
33
  low_cpu_mem_usage=True,
34
+ dtype=dtype, # modern kw (no deprecation warning)
35
  )
36
+ if MODEL_REV:
37
+ mdl_kwargs["revision"] = MODEL_REV
38
+
39
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID, **mdl_kwargs).to(device)
40
  model.eval()
41
 
42
+ # Patch generation config safety
43
+ if getattr(model.generation_config, "pad_token_id", None) is None:
44
+ model.generation_config.pad_token_id = (
45
+ getattr(tokenizer, "pad_token_id", None) or getattr(tokenizer, "eos_token_id", 0)
46
+ )
47
+ if getattr(model.generation_config, "eos_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
48
+ model.generation_config.eos_token_id = tokenizer.eos_token_id
49
+
50
+ # ---- Runtime compatibility patch for newer Transformers beam search ----
51
+ # Newer versions call self.config.get_text_config().vocab_size.
52
+ # Some custom configs (IndicTransConfig) don't define these.
53
+ if not hasattr(model.config, "vocab_size") or model.config.vocab_size is None:
54
+ try:
55
+ model.config.vocab_size = getattr(tokenizer, "vocab_size", None) or len(tokenizer)
56
+ except Exception:
57
+ # Fallback to a safe default if tokenizer doesn't expose size
58
+ model.config.vocab_size = 64000
59
+ if not hasattr(model.config, "get_text_config") or not callable(getattr(model.config, "get_text_config", None)):
60
+ def _get_text_config(self):
61
+ return self
62
+ model.config.get_text_config = types.MethodType(_get_text_config, model.config)
63
+
64
+ # Mirror into generation_config as well (some codepaths read there)
65
+ try:
66
+ model.generation_config.vocab_size = model.config.vocab_size
67
+ except Exception:
68
+ pass
69
+
70
  ip = IndicProcessor(inference=True)
71
 
72
  # -------------------- Inference helpers -------------------------------
73
  @torch.inference_mode()
74
  def _translate_to_lang(text: str, tgt_code: str, num_beams: int, max_new_tokens: int,
75
  temperature: float, top_p: float, top_k: int):
76
+ # Preprocess via IndicTransToolkit
77
  batch = ip.preprocess_batch([text], src_lang=SRC_CODE, tgt_lang=tgt_code)
78
 
79
+ # Tokenize
80
  enc = tokenizer(
81
  batch,
82
  max_length=256,
 
86
  return_attention_mask=True,
87
  ).to(device)
88
 
89
+ # Sampling toggles
90
  do_sample = (temperature is not None) and (float(temperature) > 0)
91
 
92
+ # Generate
93
  outputs = model.generate(
94
  **enc,
95
  max_new_tokens=int(max_new_tokens),
 
100
  top_k=int(top_k) if do_sample else None,
101
  use_cache=True,
102
  early_stopping=False,
103
+ pad_token_id=model.generation_config.pad_token_id,
104
  )
105
 
106
+ # Decode
107
+ decoded = tokenizer.batch_decode(
108
+ outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
109
+ )
 
 
110
 
111
+ # Postprocess back to target script
112
  final = ip.postprocess_batch(decoded, lang=tgt_code)
113
  return final[0].strip()
114
 
 
116
  text = (text or "").strip()
117
  if not text:
118
  return "", ""
119
+
120
+ try:
121
+ hi = _translate_to_lang(text, HI_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
122
+ except Exception as e:
123
+ print("HI ERROR:\n", traceback.format_exc())
124
+ hi = f"鈿狅笍 Hindi translation failed: {type(e).__name__}: {str(e).splitlines()[-1]}"
125
+
126
+ try:
127
+ te = _translate_to_lang(text, TE_CODE, num_beams, max_new_tokens, temperature, top_p, top_k)
128
+ except Exception as e:
129
+ print("TE ERROR:\n", traceback.format_exc())
130
+ te = f"鈿狅笍 Telugu translation failed: {type(e).__name__}: {str(e).splitlines()[-1]}"
131
+
132
  return hi, te
133
 
134
  # -------------------- UI (professional, clean) ------------------------
135
+ THEME = gr.themes.Soft(primary_hue="blue", neutral_hue="slate").set(
 
 
 
136
  body_background_fill="#0b1220",
137
  body_text_color_subdued="#cbd5e1",
138
  block_background_fill="#0f172a",
 
199
 
200
  gr.Markdown('<div class="footer">Model: law-ai/InLegalTrans-En2Indic-1B 路 Tokenizer: ai4bharat/indictrans2-en-indic-1B</div>')
201
 
202
+ # Keep queue to enable buffering; omit unsupported args on older Gradio
203
  demo.queue(max_size=48).launch()