NetraVerse committed on
Commit
60ffa1d
Β·
verified Β·
1 Parent(s): 4aafe0f

Create the app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -0
app.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import gradio as gr
4
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5
+ from IndicTransToolkit.processor import IndicProcessor
6
+
7
# Hugging Face auth token, read from the environment so the secret never
# lives in source control (may be None for non-gated models).
token = os.getenv("HUGGINGFACE_HUB_TOKEN")

# Device configuration: prefer GPU when available, fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Translation direction: English (Latin script) -> Kannada (Kannada script),
# using IndicTrans2's FLORES-style language codes.
src_lang, tgt_lang = "eng_Latn", "kan_Knda"
model_name = "ai4bharat/indictrans2-en-indic-dist-200M"

# Globals populated by load_model(); all remain None until loading succeeds.
model = None      # seq2seq translation model
tokenizer = None  # matching tokenizer
ip = None         # IndicProcessor for pre-/post-processing
21
+
22
def load_model():
    """Load the IndicTrans2 model, tokenizer and Indic processor into globals.

    Populates the module-level ``model``, ``tokenizer`` and ``ip`` globals.

    Returns:
        bool: True when everything loaded successfully, False otherwise.
        Errors are printed rather than raised so the caller can fall back
        to an error UI instead of crashing the app.
    """
    global model, tokenizer, ip

    try:
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            token=token,
        )

        # float16 halves memory on GPU, but many CPU kernels do not support
        # half precision — keep full precision when running on CPU.
        load_dtype = torch.float16 if DEVICE == "cuda" else torch.float32
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=load_dtype,
            token=token,
        ).to(DEVICE)
        model.eval()  # inference only: disable dropout etc.

        ip = IndicProcessor(inference=True)
        print(f"Model loaded successfully on {DEVICE}")
        return True

    except Exception as e:
        # Broad catch is deliberate: any load failure (auth, network,
        # missing package) should degrade to the error UI, not crash.
        print(f"Error loading model: {str(e)}")
        return False
48
+
49
def translate_text(input_text):
    """
    Translate a single English sentence to Kannada.

    Args:
        input_text: Sentence to translate (str).

    Returns:
        str: The translated sentence, or a human-readable error message
        when the model is not loaded, the input is empty, or translation
        fails.
    """
    # Explicit None checks: object truthiness is not a reliable
    # "model is loaded" signal.
    if model is None or tokenizer is None or ip is None:
        return "❌ Model not loaded. Please check the model configuration."

    if not input_text.strip():
        return "Please enter some text to translate."

    try:
        # Single sentence translation (batch of one).
        input_sentences = [input_text.strip()]

        # Normalize and language-tag the sentence for IndicTrans2.
        batch = ip.preprocess_batch(
            input_sentences,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )

        # Tokenize onto the same device as the model.
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Beam-search generation; no gradients needed at inference time.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=False,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode token ids back to text.
        decoded = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # Map model output back to native-script Kannada text.
        translations = ip.postprocess_batch(decoded, lang=tgt_lang)

        return translations[0] if translations else "Translation failed."

    except Exception as e:
        # Surface the failure in the UI instead of raising into Gradio.
        return f"❌ Translation error: {str(e)}"
114
+
115
def create_interface():
    """Create and configure the Gradio interface.

    Loads the translation model first; if loading fails, returns a minimal
    error page instead of the translation UI so the app still starts.

    Returns:
        gr.Blocks: The assembled (but not yet launched) Gradio app.
    """

    # Load model on startup; this can take a while (downloads weights).
    model_loaded = load_model()

    if not model_loaded:
        # Fallback: a static error page listing the likely causes.
        with gr.Blocks(title="Translation App - Error") as demo:
            gr.Markdown("## ❌ Model Loading Error")
            gr.Markdown("Failed to load the translation model. Please check:")
            gr.Markdown("- Your Hugging Face token is set correctly")
            gr.Markdown("- You have access to the gated model")
            gr.Markdown("- Your internet connection is working")
        return demo

    # Main interface. NOTE: widget creation order inside the context
    # managers defines the on-screen layout — do not reorder casually.
    with gr.Blocks(
        title="AI4Bharat IndicTrans2 Translation",
        theme=gr.themes.Soft(),
    ) as demo:

        # Header with the active configuration baked in at build time.
        gr.Markdown(
            f"""
            # 🌍 AI4Bharat IndicTrans2 Translation

            **Current Configuration:**
            - **Source Language:** {src_lang} (English)
            - **Target Language:** {tgt_lang} (Kannada)
            - **Model:** {model_name}
            - **Device:** {DEVICE}

            Enter text below to translate from English to Kannada.
            """)

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label=f"Input Text ({src_lang})",
                    placeholder="Enter English text to translate...",
                    lines=5,
                    max_lines=10
                )

                with gr.Row():
                    translate_btn = gr.Button("πŸ”„ Translate", variant="primary")
                    clear_btn = gr.Button("πŸ—‘οΈ Clear")

            with gr.Column():
                # Read-only: populated exclusively by translate_text.
                output_text = gr.Textbox(
                    label=f"Translation ({tgt_lang})",
                    lines=5,
                    max_lines=10,
                    interactive=False
                )

        # Clickable example inputs.
        gr.Markdown("### πŸ“ Example Inputs:")
        examples = [
            ["Hello, how are you?"],
            ["I am going to the market today."],
            ["This is a very beautiful place."],
            ["Can you help me?"],
        ]

        # cache_examples=True runs translate_text on each example at build
        # time and stores the results, so example clicks are instant.
        gr.Examples(
            examples=examples,
            inputs=[input_text],
            outputs=[output_text],
            fn=translate_text,
            cache_examples=True
        )

        # Event handlers: wire the buttons to callbacks.
        translate_btn.click(
            fn=translate_text,
            inputs=[input_text],
            outputs=[output_text]
        )

        # The lambda returns one value per output component (both cleared).
        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[input_text, output_text]
        )

        # Footer separator.
        gr.Markdown("---")

    return demo
205
+
206
+ if __name__ == "__main__":
207
+ # Create and launch the interface
208
+ demo = create_interface()
209
+
210
+ # Launch the app
211
+ demo.launch(
212
+ server_name="0.0.0.0", # Allow external connections
213
+ server_port=7860, # Default Gradio port
214
+ share=False, # Set to True if you want a public link
215
+ debug=True,
216
+ show_error=True
217
+ )