# Repository: nivra-ai-agent / indicTrans3Infer.py
# Author: datdevsteve
# Last commit: "Update indicTrans3Infer.py" (d57fadf, verified)
# ==========================================================
# IndicTrans3 Text Translation Inference
# (Stable wrapper for HF Space usage)
# ==========================================================
from gradio_client import Client
import threading
# ----------------------------------------------------------
# Lazy singleton client (HF Spaces cold-start safe)
# ----------------------------------------------------------
_client = None
_client_lock = threading.Lock()


def get_indictrans_client():
    """Return the shared IndicTrans3 Gradio client, building it on first use.

    The check-and-build runs under ``_client_lock`` so that concurrent first
    callers construct at most one ``Client``. If construction raises, the
    exception propagates and ``_client`` remains ``None``, so a later call
    will attempt construction again.
    """
    global _client
    with _client_lock:
        if _client is not None:
            return _client
        _client = Client("ai4bharat/IndicTrans3-beta")
        return _client
# ----------------------------------------------------------
# BCP-47 → IndicTrans language labels
# ----------------------------------------------------------
LANGUAGE_MAP = {
"hi": "Hindi",
"bn": "Bengali",
"te": "Telugu",
"mr": "Marathi",
"ta": "Tamil",
"gu": "Gujarati",
"kn": "Kannada",
"ml": "Malayalam",
"pa": "Punjabi",
"or": "Odia",
"ur": "Urdu",
"en": "English",
}
def normalize_lang(code: str) -> str:
    """Reduce a language tag to its lowercase primary subtag.

    Both the BCP-47 ``-`` separator and the POSIX/Android-style ``_``
    separator are accepted, and surrounding whitespace is ignored. An empty
    or ``None`` code yields ``""`` (which no language table will contain)
    instead of raising.

    >>> normalize_lang("hi-IN")
    'hi'
    >>> normalize_lang("ta-IN")
    'ta'
    >>> normalize_lang("pa_IN")
    'pa'
    >>> normalize_lang(" en-US ")
    'en'

    Args:
        code: A language tag such as ``"hi-IN"``, ``"hi_IN"`` or ``"hi"``.

    Returns:
        The primary language subtag in lowercase, or ``""`` for empty input.
    """
    if not code:
        return ""
    # Treat "_" like "-" so locale strings such as "hi_IN" resolve to "hi"
    # rather than the unmatched "hi_in".
    return code.strip().replace("_", "-").split("-", 1)[0].lower()
# ----------------------------------------------------------
# PUBLIC TRANSLATION FUNCTION (USED BY app.py)
# ----------------------------------------------------------
def translate_text(text: str, target_lang_code: str) -> str:
    """Translate ``text`` into the language named by ``target_lang_code``.

    Best-effort wrapper around the IndicTrans3 Gradio endpoint. It never
    raises. The original ``text`` is returned unchanged when:

    - ``text`` is not a string, or is empty / whitespace-only;
    - ``target_lang_code`` is not a string;
    - the target is English or is not listed in ``LANGUAGE_MAP``;
    - the remote call fails or produces an empty / non-string result.

    Args:
        text: Source text to translate.
        target_lang_code: BCP-47 tag or bare code of the target language,
            e.g. ``"hi-IN"`` or ``"hi"``.

    Returns:
        The stripped translation, or ``text`` unchanged on any fallback.
    """
    # Guard: nothing to translate. A non-string is returned as-is instead of
    # raising on ``.strip()``.
    if not isinstance(text, str) or not text.strip():
        return text

    # Guard: normalization runs outside the try block below, so reject
    # non-string codes here to keep the "never raise" contract.
    if not isinstance(target_lang_code, str):
        return text
    iso = normalize_lang(target_lang_code)

    # Guard: English needs no translation; unknown codes are unsupported.
    if iso == "en" or iso not in LANGUAGE_MAP:
        return text
    target_lang = LANGUAGE_MAP[iso]

    try:
        client = get_indictrans_client()

        # NOTE(review): the original author states IndicTrans3 requires the
        # chat history as a list of [user, assistant] pairs; one empty pair
        # keeps every call stateless. Not verifiable from this file.
        safe_history = [["", ""]]

        # NOTE(review): confirm via ``client.view_api()`` that "/user"
        # returns the translation itself, not only updated chat state.
        result = client.predict(
            user_message=text,
            history=safe_history,
            target_lang=target_lang,
            api_name="/user",
        )

        # Expected shape (per the original author): (translated_text,
        # updated_history). A bare string is also accepted so that a
        # single-output endpoint is not truncated to its first character.
        if isinstance(result, str):
            translated_text = result
        elif isinstance(result, (list, tuple)) and result:
            translated_text = result[0]
        else:
            translated_text = None

        if not isinstance(translated_text, str) or not translated_text.strip():
            return text
        return translated_text.strip()

    except Exception as e:
        # Deliberate best-effort boundary: log and fall back to the input.
        print(f"❌ IndicTrans3 translation error: {e}")
        return text