File size: 6,431 Bytes
60ffa1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import torch
import os
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor

# Hugging Face auth token (None when the env var is unset); forwarded to
# from_pretrained() so gated/private model downloads can authenticate.
token = os.getenv("HUGGINGFACE_HUB_TOKEN")

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Translation direction: English (Latin script) -> Kannada (Kannada script),
# using FLORES-style language tags as expected by IndicTrans2.
src_lang, tgt_lang = "eng_Latn", "kan_Knda"
model_name = "ai4bharat/indictrans2-en-indic-dist-200M"

# Populated once by load_model(); kept as module globals so translate_text()
# can reuse the loaded objects across requests.
model = None
tokenizer = None
ip = None

def load_model():
    """Load the IndicTrans2 model, tokenizer and processor into module globals.

    Returns:
        bool: True when everything loaded successfully, False otherwise.
        Errors are printed rather than raised so the UI can degrade
        gracefully to an error page.
    """
    global model, tokenizer, ip

    try:
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            token=token
        )

        # fp16 halves GPU memory, but many CPU ops have no half-precision
        # kernels, so use full precision when CUDA is unavailable.
        dtype = torch.float16 if DEVICE == "cuda" else torch.float32

        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            dtype=dtype,
            token=token
        ).to(DEVICE)

        ip = IndicProcessor(inference=True)
        print(f"Model loaded successfully on {DEVICE}")
        return True

    except Exception as e:
        # Deliberate swallow-and-report: the caller shows an error UI
        # instead of crashing at startup.
        print(f"Error loading model: {str(e)}")
        return False

def translate_text(input_text):
    """Translate a single English sentence to Kannada.

    Args:
        input_text: Sentence to translate (may be None or empty when coming
            from the UI).

    Returns:
        The translated sentence, or a human-readable error/status message.
    """
    # Guard against use before load_model() has populated the globals.
    if model is None or tokenizer is None or ip is None:
        return "❌ Model not loaded. Please check the model configuration."

    # Handle None as well as empty / whitespace-only input.
    if not input_text or not input_text.strip():
        return "Please enter some text to translate."

    try:
        # The model is batch-oriented; wrap the single sentence in a list.
        input_sentences = [input_text.strip()]

        # Add language tags / normalise script as IndicTrans2 expects.
        batch = ip.preprocess_batch(
            input_sentences,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )

        # Tokenize the preprocessed batch and move tensors to the model device.
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Beam-search generation; no gradients needed at inference time.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=False,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode token ids back to text.
        decoded = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # Undo the preprocessing (detokenise / restore target script).
        translations = ip.postprocess_batch(decoded, lang=tgt_lang)

        return translations[0] if translations else "Translation failed."

    except Exception as e:
        # Surface the failure in the output box rather than crashing the UI.
        return f"❌ Translation error: {str(e)}"

def create_interface():
    """Build and return the Gradio Blocks app.

    Loads the model first; when that fails, returns a static diagnostic
    page instead of the translation UI.
    """
    if not load_model():
        # Fallback UI: a few static markdown notes explaining likely causes.
        with gr.Blocks(title="Translation App - Error") as demo:
            for md in (
                "## ❌ Model Loading Error",
                "Failed to load the translation model. Please check:",
                "- Your Hugging Face token is set correctly",
                "- You have access to the gated model",
                "- Your internet connection is working",
            ):
                gr.Markdown(md)
        return demo

    # Main translation interface.
    with gr.Blocks(
        title="AI4Bharat IndicTrans2 Translation",
        theme=gr.themes.Soft(),
    ) as demo:

        # Header showing the active language pair, model and device.
        gr.Markdown(
            f"""
            # 🌍 AI4Bharat IndicTrans2 Translation
            
            **Current Configuration:**
            - **Source Language:** {src_lang} (English)
            - **Target Language:** {tgt_lang} (Kannada)
            - **Model:** {model_name}
            - **Device:** {DEVICE}
            
            Enter text below to translate from English to Kannada.
            """)

        # Two-column layout: input + buttons on the left, output on the right.
        with gr.Row():
            with gr.Column():
                source_box = gr.Textbox(
                    label=f"Input Text ({src_lang})",
                    placeholder="Enter English text to translate...",
                    lines=5,
                    max_lines=10
                )

                with gr.Row():
                    btn_translate = gr.Button("πŸ”„ Translate", variant="primary")
                    btn_clear = gr.Button("πŸ—‘οΈ Clear")

            with gr.Column():
                target_box = gr.Textbox(
                    label=f"Translation ({tgt_lang})",
                    lines=5,
                    max_lines=10,
                    interactive=False
                )

        # Clickable example sentences (results cached at startup).
        gr.Markdown("### πŸ“ Example Inputs:")
        gr.Examples(
            examples=[
                ["Hello, how are you?"],
                ["I am going to the market today."],
                ["This is a very beautiful place."],
                ["Can you help me?"],
            ],
            inputs=[source_box],
            outputs=[target_box],
            fn=translate_text,
            cache_examples=True
        )

        # Wire the buttons to their callbacks.
        btn_translate.click(
            fn=translate_text,
            inputs=[source_box],
            outputs=[target_box]
        )

        btn_clear.click(
            fn=lambda: ("", ""),
            outputs=[source_box, target_box]
        )

        # Footer rule.
        gr.Markdown("---")

    return demo

if __name__ == "__main__":
    # Build the UI (this also loads the model) and serve it immediately.
    create_interface().launch(
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7860,       # default Gradio port
        share=False,            # flip to True for a public gradio.live link
        debug=True,
        show_error=True
    )