# Project-1 / app.py
import gradio as gr
import ctranslate2
import transformers
import os
MODEL_DIR = "models"
TOKENIZER_DIR = "models" # Relative path for HF Space compatibility
# Warn early if the optimized model directory is missing
if not os.path.exists(MODEL_DIR):
    print("Warning: CT2 model directory not found. Please run src/optimize.py.")
# Global model handles, populated by load_model()
translator = None
tokenizer = None
def load_model():
global translator, tokenizer
try:
# 1. Try to load custom CTranslate2 model (Optimized Local)
if os.path.exists(os.path.join(MODEL_DIR, "model.bin")):
print(f"Loading custom CTranslate2 model from {MODEL_DIR}...")
translator = ctranslate2.Translator(MODEL_DIR)
tokenizer = transformers.MBart50TokenizerFast.from_pretrained(TOKENIZER_DIR)
print("Custom CT2 model loaded successfully.")
# 2. Try to load custom Transformers model (Standard Local)
elif os.path.exists(os.path.join(MODEL_DIR, "pytorch_model.bin")) or \
os.path.exists(os.path.join(MODEL_DIR, "model.safetensors")):
print(f"Loading custom Transformers model from {MODEL_DIR}...")
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
tokenizer = MBart50TokenizerFast.from_pretrained(TOKENIZER_DIR)
hf_model = MBartForConditionalGeneration.from_pretrained(MODEL_DIR)
# Wrapper to make hf_model act like CT2
class TransformersWrapper:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
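                # Mirrors ctranslate2.Translator.translate_batch: source_tokens is a list
                # of token lists, target_prefix a list of [lang_code] prefixes, and the
                # return value exposes .hypotheses like CT2's TranslationResult objects.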
def translate_batch(self, source_tokens, target_prefix):
                    # Rebuild plain text from the CT2-style token lists; skip special
                    # tokens so the language code and </s> are not duplicated on re-encode
                    text = [self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s), skip_special_tokens=True) for s in source_tokens]
                    encoded = self.tokenizer(text, return_tensors="pt", padding=True)
                    # Force the decoder to start with the requested target language code
                    forced_bos_token_id = self.tokenizer.lang_code_to_id[target_prefix[0][0]]
                    generated_tokens = self.model.generate(**encoded, forced_bos_token_id=forced_bos_token_id)
class Result:
def __init__(self, tokens): self.hypotheses = [tokens]
return [Result(self.tokenizer.convert_ids_to_tokens(g)) for g in generated_tokens]
translator = TransformersWrapper(hf_model, tokenizer)
print("Custom Transformers model loaded successfully.")
else:
print(f"❌ Error: Custom model weights not found in '{MODEL_DIR}'.")
print("Please ensure 'model.bin' or 'pytorch_model.bin' is present in the models folder.")
translator = None
tokenizer = None
except Exception as e:
print(f"Error loading model: {e}")
translator = None
tokenizer = None
load_model()
if tokenizer:
tokenizer.src_lang = "en_XX"
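# mBART-50 target-language codes (keys of the tokenizer's lang_code_to_id mapping)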
LANG_CODES = {
"Hindi": "hi_IN",
"Bengali": "bn_IN",
"Tamil": "ta_IN"
}
def transliterate(text, target_language):
    if translator is None or tokenizer is None:
        return "Model not loaded. Check the server logs."
    if not text:
        return "Please enter some text."
target_code = LANG_CODES.get(target_language)
if not target_code:
return "Invalid Language"
# Tokenize
source = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
# Translate
results = translator.translate_batch(
[source],
target_prefix=[[target_code]]
)
# Decode
target = results[0].hypotheses[0]
return tokenizer.decode(tokenizer.convert_tokens_to_ids(target), skip_special_tokens=True)
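# Quick sanity check of the pipeline above (hypothetical call; the actual output
# depends on the fine-tuned weights):
#   transliterate("Namaste", "Hindi")  # -> Devanagari text such as "नमस्ते"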
def create_demo():
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🌐 Multilingual Transliteration Model")
gr.Markdown("Transliterate English text to Hindi, Bengali, or Tamil.")
with gr.Row():
with gr.Column():
input_text = gr.Textbox(label="Input Text (English/Roman)", placeholder="e.g. Namaste", lines=3)
target_lang = gr.Dropdown(choices=["Hindi", "Bengali", "Tamil"], value="Hindi", label="Target Language")
                btn = gr.Button("🚀 Transliterate", variant="primary")
with gr.Column():
output_text = gr.Textbox(label="Transliterated Output", lines=5)
gr.Examples(
examples=[
["Namaste", "Hindi"],
["Kemon achen", "Bengali"],
["Vanakkam", "Tamil"]
],
inputs=[input_text, target_lang]
)
btn.click(fn=transliterate, inputs=[input_text, target_lang], outputs=output_text)
return demo
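# The dump ends with create_demo(); a Gradio Space's app.py normally also builds and
# launches the UI. A minimal sketch of that entry point, assuming the default launch
# settings are acceptable:
demo = create_demo()
if __name__ == "__main__":
    demo.launch()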