# app.py — Gradio demo for the multilingual transliteration model.
# (Hugging Face Space upload by Abhishek11k, commit 724838e)
import gradio as gr
import ctranslate2
import transformers
import os
# Model artifact locations, kept relative so the app works inside an HF Space.
MODEL_DIR = "models"
TOKENIZER_DIR = "models"

# Warn early if the CTranslate2 conversion step has not been run yet;
# load_model() below will fall back to downloading from the Hub.
if not os.path.exists(MODEL_DIR):
    print("Warning: CT2 Model not found. Please run src/optimize.py")
# Load Global resources
def load_model():
    """Populate the module-level ``translator`` and ``tokenizer`` globals.

    Prefers a local CTranslate2-converted model in ``MODEL_DIR``; otherwise
    downloads the base mBART-50 model from the Hugging Face Hub and wraps it
    so it exposes the same ``translate_batch`` interface the rest of the app
    expects. On any failure both globals are set to ``None``.
    """
    global translator, tokenizer
    try:
        # 1. Try to load CTranslate2 model (Optimized Local)
        if os.path.exists(os.path.join(MODEL_DIR, "model.bin")):
            print("Loading CTranslate2 model from local storage...")
            translator = ctranslate2.Translator(MODEL_DIR)
            tokenizer = transformers.MBart50TokenizerFast.from_pretrained(TOKENIZER_DIR)
        # 2. Fallback: Load from Hugging Face Hub
        else:
            print("Local weights not found. Downloading fallback model from HF Hub (facebook/mbart-large-50)...")
            from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
            base_model_id = "facebook/mbart-large-50-many-to-many-mmt"
            tokenizer = MBart50TokenizerFast.from_pretrained(base_model_id)
            hf_model = MBartForConditionalGeneration.from_pretrained(base_model_id)

            # Create a simple wrapper to make hf_model act like a CT2 translator for the existing code
            class TransformersWrapper:
                def __init__(self, model, tokenizer):
                    self.model = model
                    self.tokenizer = tokenizer

                def translate_batch(self, source_tokens, target_prefix):
                    # Convert tokens back to text for transformers
                    # (CT2 callers pass lists of subword token strings, not raw text).
                    text = [self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s)) for s in source_tokens]
                    encoded = self.tokenizer(text, return_tensors="pt", padding=True)
                    # Get target lang code
                    # target_prefix mirrors CT2's shape: one single-element prefix per input.
                    forced_bos_token_id = self.tokenizer.lang_code_to_id[target_prefix[0][0]]
                    generated_tokens = self.model.generate(
                        **encoded,
                        forced_bos_token_id=forced_bos_token_id
                    )
                    # Wrap in a result object that mimics CT2 output
                    class Result:
                        def __init__(self, tokens): self.hypotheses = [tokens]
                    return [Result(self.tokenizer.convert_ids_to_tokens(g)) for g in generated_tokens]

            translator = TransformersWrapper(hf_model, tokenizer)
            print("Fallback model loaded successfully.")
    except Exception as e:
        # Best-effort loader: leave both globals as None so transliterate()
        # can report a friendly error instead of crashing at import time.
        print(f"Error loading model: {e}")
        translator = None
        tokenizer = None
# Initialize the global translator/tokenizer at import time.
load_model()

# Source language is always English/romanized input for this demo.
if tokenizer:
    tokenizer.src_lang = "en_XX"

# UI language name -> mBART-50 language code used as the decoder prefix.
LANG_CODES = {
    "Hindi": "hi_IN",
    "Bengali": "bn_IN",
    "Tamil": "ta_IN"
}
def transliterate(text, target_language):
    """Transliterate English/romanized ``text`` into the requested script.

    Args:
        text: Input string in Roman script.
        target_language: One of the keys of ``LANG_CODES`` ("Hindi",
            "Bengali", "Tamil").

    Returns:
        The transliterated string, or a short error message when the model
        is unavailable, the input is empty, or the language is unknown.
    """
    if not translator or not text:
        return "Model not loaded or empty input."

    target_code = LANG_CODES.get(target_language)
    if target_code is None:
        return "Invalid Language"

    # Encode to ids, then to subword token strings — the form translate_batch expects.
    source_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))

    # Force decoding to start with the target language code.
    batch_results = translator.translate_batch(
        [source_tokens],
        target_prefix=[[target_code]]
    )

    # Take the best hypothesis and decode it back to plain text.
    best_tokens = batch_results[0].hypotheses[0]
    best_ids = tokenizer.convert_tokens_to_ids(best_tokens)
    return tokenizer.decode(best_ids, skip_special_tokens=True)
def create_demo():
    """Build and return the Gradio Blocks UI for the transliteration demo.

    Layout: a two-column row (input textbox + language dropdown + button on
    the left, output textbox on the right), followed by clickable examples.
    The button wires ``transliterate`` from input/language to the output box.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🌐 Multilingual Transliteration Model")
        gr.Markdown("Transliterate English text to Hindi, Bengali, or Tamil.")

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input Text (English/Roman)",
                    placeholder="e.g. Namaste",
                    lines=3,
                )
                target_lang = gr.Dropdown(
                    choices=["Hindi", "Bengali", "Tamil"],
                    value="Hindi",
                    label="Target Language",
                )
                # Fixed mojibake: label previously contained a mis-encoded
                # rocket emoji ("πŸš€" — UTF-8 bytes of 🚀 read as Latin-1).
                btn = gr.Button("🚀 Transliterate", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Transliterated Output", lines=5)

        gr.Examples(
            examples=[
                ["Namaste", "Hindi"],
                ["Kemon achen", "Bengali"],
                ["Vanakkam", "Tamil"],
            ],
            inputs=[input_text, target_lang],
        )

        btn.click(fn=transliterate, inputs=[input_text, target_lang], outputs=output_text)
    return demo