Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import nltk | |
| nltk.download('punkt_tab') | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from IndicTransToolkit import IndicProcessor | |
| import torch | |
# --- Model setup (runs once at import time) ---------------------------------

# Choose the compute device up front: GPU when CUDA is available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub id of the IndicTrans2 checkpoint used for translation.
model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"

# trust_remote_code=True allows transformers to execute the custom
# tokenizer/model code that ships with the checkpoint repository.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
model.to(DEVICE)

# Pre-/post-processing helper from IndicTransToolkit; used by the translation
# function to prepare batches for the model and clean up its decoded output.
ip = IndicProcessor(inference=True)
def split_text_into_batches(text, max_tokens_per_batch):
    """Split *text* into sentence-aligned batches of bounded size.

    The text is cut into sentences with ``nltk.sent_tokenize`` and the
    sentences are packed greedily, in order, into batches.

    Note: despite the parameter name, the budget is measured in characters
    (``len`` of the strings), not in model tokens. Each sentence is charged
    its length plus one separator character.

    Args:
        text: The text to split.
        max_tokens_per_batch: Upper bound on the size of one batch, in
            characters (including one separator per sentence).

    Returns:
        A list of batch strings, each made of whole sentences joined by a
        single space. A sentence that is longer than the budget on its own is
        not split; it is returned as a batch by itself. No empty batch is
        ever produced.
    """
    batches = []
    pending = []  # sentences collected for the batch currently being built
    pending_size = 0  # characters used so far, one separator per sentence

    for sentence in nltk.sent_tokenize(text):
        sentence_size = len(sentence) + 1  # +1 for the separating space
        # Flush only when something is pending: an oversized sentence that
        # arrives on an empty batch must not emit an empty batch first.
        if pending and pending_size + sentence_size > max_tokens_per_batch:
            batches.append(" ".join(pending).strip())
            pending = []
            pending_size = 0
        pending.append(sentence)
        pending_size += sentence_size

    if pending:
        batches.append(" ".join(pending).strip())
    return batches
def run_translation(file_uploader, input_text, source_language, target_language):
    """Translate text (typed or uploaded) and return it with a result file.

    When a file is supplied its UTF-8 contents replace ``input_text``. The
    text is split into batches with ``split_text_into_batches``; each batch is
    preprocessed, tokenized, run through ``model.generate`` with beam search,
    decoded and postprocessed. The per-batch translations are joined with
    single spaces.

    Args:
        file_uploader: ``None``, a path string, or an object whose ``name``
            attribute is the path of the uploaded text file.
        input_text: Text to translate; ignored when a file is supplied.
        source_language: Display name of the source language; must be a key
            of the language map below.
        target_language: Display name of the target language; must be a key
            of the language map below.

    Returns:
        A tuple ``(translated_text, output_path)`` where ``output_path`` is
        the name of the file ("result.txt") the translation was written to.

    Raises:
        KeyError: If a language name is not present in the language map.
        OSError: If the uploaded file cannot be read or the result file
            cannot be written.
    """
    if file_uploader is not None:
        # Accept both an object exposing ``.name`` and a plain path string,
        # so the function works whichever form the caller passes.
        upload_path = getattr(file_uploader, "name", file_uploader)
        with open(upload_path, "r", encoding="utf-8") as file:
            input_text = file.read()

    # Defensive: a missing text value is treated as empty input rather than
    # crashing inside the sentence splitter.
    input_text = input_text or ""

    # Language mapping: UI display name -> language tag passed to the processor.
    lang_code_map = {
        "Hindi": "hin_Deva",
        "Punjabi": "pan_Guru",
        "English": "eng_Latn",
    }
    src_lang = lang_code_map[source_language]
    tgt_lang = lang_code_map[target_language]

    max_tokens_per_batch = 256
    batches = split_text_into_batches(input_text, max_tokens_per_batch)

    translated_parts = []
    for batch in batches:
        batch_preprocessed = ip.preprocess_batch(
            [batch], src_lang=src_lang, tgt_lang=tgt_lang
        )
        inputs = tokenizer(
            batch_preprocessed,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # NOTE(review): as_target_tokenizer() is deprecated in recent
        # transformers releases; confirm it is still required by the
        # checkpoint's tokenizer before upgrading.
        with tokenizer.as_target_tokenizer():
            decoded_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translations = ip.postprocess_batch(decoded_tokens, lang=tgt_lang)
        translated_parts.append(" ".join(translations))

    # Join once at the end instead of growing a string inside the loop.
    output = " ".join(translated_parts).strip()

    _output_name = "result.txt"
    with open(_output_name, "w", encoding="utf-8") as out_file:
        out_file.write(output)
    return output, _output_name
# --- Gradio UI ---------------------------------------------------------------
# One row with two columns: the first holds the inputs and the button, the
# second holds the translated text and the downloadable result file.

# Display names offered in both language dropdowns; they must match the keys
# of the language map used by run_translation.
# NOTE(review): the checkpoint loaded by this file is named "indic-indic";
# confirm that offering "English" (and defaulting the target to it) is
# actually supported by that model.
_LANGUAGE_CHOICES = ("Hindi", "Punjabi", "English")

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            file_uploader = gr.File(label="Upload a text file (Optional)")
            input_text = gr.Textbox(
                label="Input text",
                lines=5,
                placeholder="Enter text here...",
            )
            source_language = gr.Dropdown(
                choices=list(_LANGUAGE_CHOICES),
                value="Hindi",
                label="Source language",
            )
            target_language = gr.Dropdown(
                choices=list(_LANGUAGE_CHOICES),
                value="English",
                label="Target language",
            )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text", lines=5)
            output_file = gr.File(label="Translated text file")

    # Wire the button: component order here must match the parameter order
    # and the return tuple of run_translation.
    _click_inputs = [file_uploader, input_text, source_language, target_language]
    _click_outputs = [output_text, output_file]
    btn.click(fn=run_translation, inputs=_click_inputs, outputs=_click_outputs)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)