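"""Gradio demo for English-to-Indic transliteration with mBART-50.

Prefers a locally converted CTranslate2 model and falls back to the original
transformers checkpoint from the Hugging Face Hub when the local copy is absent.
"""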
import os

import ctranslate2
import gradio as gr
import transformers

MODEL_DIR = "models"
TOKENIZER_DIR = "models" # Relative path for HF Space compatibility

# Warn early if the optimized model is missing; load_model() will fall back to the Hub.
if not os.path.exists(os.path.join(MODEL_DIR, "model.bin")):
    print("Warning: CT2 model not found. Run src/optimize.py to generate it, or the HF Hub fallback will be used.")

# Global resources, populated by load_model()
translator = None
tokenizer = None

def load_model():
    """Load the local CTranslate2 model if present, else fall back to the HF Hub."""
    global translator, tokenizer
    try:
        # 1. Try to load CTranslate2 model (Optimized Local)
        if os.path.exists(os.path.join(MODEL_DIR, "model.bin")):
            print("Loading CTranslate2 model from local storage...")
            translator = ctranslate2.Translator(MODEL_DIR)
            tokenizer = transformers.MBart50TokenizerFast.from_pretrained(TOKENIZER_DIR)
        
        # 2. Fallback: Load from Hugging Face Hub
        else:
            print("Local weights not found. Downloading fallback model from HF Hub (facebook/mbart-large-50)...")
            from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
            base_model_id = "facebook/mbart-large-50-many-to-many-mmt"
            tokenizer = MBart50TokenizerFast.from_pretrained(base_model_id)
            hf_model = MBartForConditionalGeneration.from_pretrained(base_model_id)
            
            # Create a simple wrapper to make hf_model act like a CT2 translator for the existing code
            class TransformersWrapper:
                def __init__(self, model, tokenizer):
                    self.model = model
                    self.tokenizer = tokenizer
                def translate_batch(self, source_tokens, target_prefix):
                    # Convert tokens back to text for transformers, skipping special
                    # tokens (src_lang marker, </s>) so re-encoding does not duplicate them.
                    text = [self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s), skip_special_tokens=True) for s in source_tokens]
                    encoded = self.tokenizer(text, return_tensors="pt", padding=True)
                    # Resolve the target language code (e.g. "hi_IN") to its token id;
                    # convert_tokens_to_ids avoids depending on the tokenizer's lang_code_to_id mapping.
                    forced_bos_token_id = self.tokenizer.convert_tokens_to_ids(target_prefix[0][0])
                    generated_tokens = self.model.generate(
                        **encoded, 
                        forced_bos_token_id=forced_bos_token_id
                    )
                    # Wrap in a result object that mimics CT2 output
                    class Result:
                        def __init__(self, tokens): self.hypotheses = [tokens]
                    
                    return [Result(self.tokenizer.convert_ids_to_tokens(g)) for g in generated_tokens]
            
            translator = TransformersWrapper(hf_model, tokenizer)
            print("Fallback model loaded successfully.")

    except Exception as e:
        print(f"Error loading model: {e}")
        translator = None
        tokenizer = None

load_model()
if tokenizer:
    tokenizer.src_lang = "en_XX"

LANG_CODES = {
    "Hindi": "hi_IN",
    "Bengali": "bn_IN",
    "Tamil": "ta_IN"
}

def transliterate(text, target_language):
    # Guard against a failed model load or empty input.
    if not translator or not tokenizer or not text:
        return "Model not loaded or empty input."
    
    target_code = LANG_CODES.get(target_language)
    if not target_code:
        return "Invalid Language"

    # Tokenize
    source = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
    
    # Translate
    results = translator.translate_batch(
        [source],
        target_prefix=[[target_code]]
    )
    
    # Decode
    target = results[0].hypotheses[0]
    return tokenizer.decode(tokenizer.convert_tokens_to_ids(target), skip_special_tokens=True)

def create_demo():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# ๐ŸŒ Multilingual Transliteration Model")
        
        gr.Markdown("Transliterate English text to Hindi, Bengali, or Tamil.")
        
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input Text (English/Roman)", placeholder="e.g. Namaste", lines=3)
                target_lang = gr.Dropdown(choices=["Hindi", "Bengali", "Tamil"], value="Hindi", label="Target Language")
                btn = gr.Button("🚀 Transliterate", variant="primary")
            
            with gr.Column():
                output_text = gr.Textbox(label="Transliterated Output", lines=5)
        
        gr.Examples(
            examples=[
                ["Namaste", "Hindi"],
                ["Kemon achen", "Bengali"],
                ["Vanakkam", "Tamil"]
            ],
            inputs=[input_text, target_lang]
        )
        
        btn.click(fn=transliterate, inputs=[input_text, target_lang], outputs=output_text)
        
    return demo
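
# Entry point: assuming this file is the Space's app.py, build the UI and launch it.
if __name__ == "__main__":
    demo = create_demo()
    demo.launch()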