Fatitommy committed on
Commit
deec1f2
Β·
verified Β·
1 Parent(s): 3459a45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -15
app.py CHANGED
@@ -1,9 +1,9 @@
1
  """
2
  VoiceAura Translation API
3
  Models:
4
- 1. SLPG/English_to_Urdu_Unsupervised_MT (en β†’ ur)
5
- 2. SLPG/Punjabi_Shahmukhi_to_Gurmukhi (pa-s β†’ pa-g)
6
- 3. SLPG/Punjabi_Gurmukhi_to_Shahmukhi (pa-g β†’ pa-s)
7
  """
8
 
9
  from fastapi import FastAPI
@@ -13,6 +13,11 @@ import os, requests, argparse, torch
13
 
14
  # βœ… PyTorch 2.6 fix
15
  torch.serialization.add_safe_globals([argparse.Namespace])
 
 
 
 
 
16
 
17
  app = FastAPI()
18
  app.add_middleware(
@@ -22,7 +27,7 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- # ── Model URLs ───────────────────────────────────────────
26
  MODELS_CONFIG = {
27
  "en-ur": {
28
  "files": {
@@ -32,6 +37,7 @@ MODELS_CONFIG = {
32
  },
33
  "dir": "models/en_ur",
34
  "checkpoint": "checkpoint_8_96000.pt",
 
35
  "instance": None,
36
  },
37
  "pa-s-pa-g": {
@@ -42,6 +48,18 @@ MODELS_CONFIG = {
42
  },
43
  "dir": "models/pa_s_pa_g",
44
  "checkpoint": "checkpoint_5_78000.pt",
 
 
 
 
 
 
 
 
 
 
 
 
45
  "instance": None,
46
  },
47
  }
@@ -60,22 +78,18 @@ def download_file(url: str, path: str):
60
  f.write(chunk)
61
  print(f"[βœ“] Done: {path}")
62
 
63
- def patched_torch_load(*args, **kwargs):
64
- kwargs["weights_only"] = False
65
- return _original_torch_load(*args, **kwargs)
66
-
67
- _original_torch_load = torch.load
68
 
69
  def load_model(pair: str):
70
  cfg = MODELS_CONFIG[pair]
71
  if cfg["instance"] is not None:
72
  return cfg["instance"]
73
 
74
- # Download files
75
  for fname, url in cfg["files"].items():
76
  download_file(url, os.path.join(cfg["dir"], fname))
77
 
78
- # Patch torch.load for fairseq
79
  torch.load = patched_torch_load
80
  from fairseq.models.transformer import TransformerModel
81
  model = TransformerModel.from_pretrained(
@@ -89,13 +103,13 @@ def load_model(pair: str):
89
  print(f"[βœ“] Model ready: {pair}")
90
  return model
91
 
92
- # ── Startup β€” load all models ────────────────────────────
93
  @app.on_event("startup")
94
  async def startup():
95
  for pair in MODELS_CONFIG:
96
  load_model(pair)
97
 
98
- # ── Endpoints ────────────────────────────────────────────
99
  class Req(BaseModel):
100
  text: str
101
  from_lang: str = "en"
@@ -104,7 +118,7 @@ class Req(BaseModel):
104
  @app.get("/")
105
  def root():
106
  loaded = {k: MODELS_CONFIG[k]["instance"] is not None for k in MODELS_CONFIG}
107
- return {"status": "VoiceAura API βœ“", "models": loaded}
108
 
109
  @app.post("/translate")
110
  def translate(req: Req):
@@ -117,9 +131,15 @@ def translate(req: Req):
117
  return {"success": False, "translation": f"⚠️ Pair '{pair}' not supported."}
118
 
119
  try:
 
120
  model = load_model(pair)
121
  result = model.translate(req.text.strip())
 
 
 
 
 
122
  return {"success": True, "translation": result, "pair": pair}
123
  except Exception as e:
124
  print(f"Error [{pair}]: {e}")
125
- return {"success": False, "translation": str(e)}
 
1
  """
2
  VoiceAura Translation API
3
  Models:
4
+ 1. SLPG/English_to_Urdu_Unsupervised_MT (en β†’ ur)
5
+ 2. SLPG/Punjabi_Shahmukhi_to_Gurmukhi_Transliteration (pa-s β†’ pa-g)
6
+ 3. SLPG/Punjabi_Gurmukhi_to_Shahmukhi_Transliteration (pa-g β†’ pa-s)
7
  """
8
 
9
  from fastapi import FastAPI
 
13
 
14
  # βœ… PyTorch 2.6 fix
15
  torch.serialization.add_safe_globals([argparse.Namespace])
16
# PyTorch >= 2.6 defaults torch.load(weights_only=True), which rejects the
# argparse.Namespace objects embedded in fairseq checkpoints (the file also
# registers argparse.Namespace via add_safe_globals). Keep a handle to the
# real loader so the wrapper below can delegate to it.
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """torch.load wrapper that defaults ``weights_only`` to ``False``.

    fairseq calls ``torch.load`` without ``weights_only``; supplying the
    default via ``setdefault`` (instead of a hard override, as before) still
    honors an explicit ``weights_only=True`` from a caller.

    NOTE(security): ``weights_only=False`` performs full pickle loading,
    which can execute arbitrary code — only load trusted checkpoints.
    """
    kwargs.setdefault("weights_only", False)
    return _original_torch_load(*args, **kwargs)
21
 
22
  app = FastAPI()
23
  app.add_middleware(
 
27
  allow_headers=["*"],
28
  )
29
 
30
+ # ── Model configs ────────────────────────────────────────
31
  MODELS_CONFIG = {
32
  "en-ur": {
33
  "files": {
 
37
  },
38
  "dir": "models/en_ur",
39
  "checkpoint": "checkpoint_8_96000.pt",
40
+ "detokenize": False,
41
  "instance": None,
42
  },
43
  "pa-s-pa-g": {
 
48
  },
49
  "dir": "models/pa_s_pa_g",
50
  "checkpoint": "checkpoint_5_78000.pt",
51
+ "detokenize": True,
52
+ "instance": None,
53
+ },
54
+ "pa-g-pa-s": {
55
+ "files": {
56
+ "checkpoint_13_129000.pt": "https://huggingface.co/SLPG/Punjabi_Gurmukhi_to_Shahmukhi_Transliteration/resolve/main/checkpoint_13_129000.pt",
57
+ "dict.pa.txt": "https://huggingface.co/SLPG/Punjabi_Gurmukhi_to_Shahmukhi_Transliteration/resolve/main/dict.pa.txt",
58
+ "dict.pk.txt": "https://huggingface.co/SLPG/Punjabi_Gurmukhi_to_Shahmukhi_Transliteration/resolve/main/dict.pk.txt",
59
+ },
60
+ "dir": "models/pa_g_pa_s",
61
+ "checkpoint": "checkpoint_13_129000.pt",
62
+ "detokenize": True,
63
  "instance": None,
64
  },
65
  }
 
78
  f.write(chunk)
79
  print(f"[βœ“] Done: {path}")
80
 
81
def detokenize(text: str) -> str:
    """Delete SentencePiece/BPE word-boundary markers (▁) from *text*.

    Returns the cleaned string with leading/trailing whitespace stripped.
    NOTE(review): markers are removed, not replaced by spaces — confirm the
    transliteration models are character-level; a word-level SentencePiece
    output would have its words fused together by this.
    """
    pieces = text.split("▁")
    return "".join(pieces).strip()
 
 
84
 
85
  def load_model(pair: str):
86
  cfg = MODELS_CONFIG[pair]
87
  if cfg["instance"] is not None:
88
  return cfg["instance"]
89
 
 
90
  for fname, url in cfg["files"].items():
91
  download_file(url, os.path.join(cfg["dir"], fname))
92
 
 
93
  torch.load = patched_torch_load
94
  from fairseq.models.transformer import TransformerModel
95
  model = TransformerModel.from_pretrained(
 
103
  print(f"[βœ“] Model ready: {pair}")
104
  return model
105
 
106
+ # ── Startup ──────────────────────────────────────────────
107
  @app.on_event("startup")
108
  async def startup():
109
  for pair in MODELS_CONFIG:
110
  load_model(pair)
111
 
112
+ # ── API ──────────────────────────────────────────────────
113
  class Req(BaseModel):
114
  text: str
115
  from_lang: str = "en"
 
118
@app.get("/")
def root():
    """Health endpoint: report which translation models are loaded."""
    models_loaded = {}
    for name, cfg in MODELS_CONFIG.items():
        models_loaded[name] = cfg["instance"] is not None
    return {"status": "VoiceAura API ✓", "models_loaded": models_loaded}
122
 
123
  @app.post("/translate")
124
  def translate(req: Req):
 
131
  return {"success": False, "translation": f"⚠️ Pair '{pair}' not supported."}
132
 
133
  try:
134
+ cfg = MODELS_CONFIG[pair]
135
  model = load_model(pair)
136
  result = model.translate(req.text.strip())
137
+
138
+ # Detokenize if needed (Punjabi models)
139
+ if cfg["detokenize"]:
140
+ result = detokenize(result)
141
+
142
  return {"success": True, "translation": result, "pair": pair}
143
  except Exception as e:
144
  print(f"Error [{pair}]: {e}")
145
+ return {"success": False, "translation": str(e)}