# Hugging Face Space: IndicBART multilingual assistant (Gradio app)
import spaces
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Model configuration
model_name = "ai4bharat/IndicBART"

# Load tokenizer on CPU.
# NOTE: use_fast=False and keep_accents=True are required for IndicBART's
# Albert-style tokenizer so Indic diacritics survive tokenization.
print("Loading IndicBART tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name, do_lower_case=False, use_fast=False, keep_accents=True
)

print("Loading IndicBART model on CPU...")
# BUGFIX: the original loaded the model with torch_dtype=torch.float16 while
# pinning it to CPU. Many CPU kernels have no half-precision implementation
# (e.g. "addmm_impl_cpu_" not implemented for 'Half'), so generate() raised a
# RuntimeError on CPU-only Space hardware. Load in float32 instead;
# from_pretrained places the model on CPU by default, so device_map="cpu"
# (which also pulls in the `accelerate` dependency) is unnecessary.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float32)
# Display name -> IndicBART language tag. IndicBART selects its output
# language with special "<2xx>" tokens built from ISO 639-1 codes.
LANGUAGE_CODES = {
    name: f"<2{iso}>"
    for name, iso in [
        ("Assamese", "as"),
        ("Bengali", "bn"),
        ("English", "en"),
        ("Gujarati", "gu"),
        ("Hindi", "hi"),
        ("Kannada", "kn"),
        ("Malayalam", "ml"),
        ("Marathi", "mr"),
        ("Oriya", "or"),
        ("Punjabi", "pa"),
        ("Tamil", "ta"),
        ("Telugu", "te"),
    ]
}
def generate_response(input_text, source_lang, target_lang, task_type, max_length):
    """Generate text with IndicBART for translation, completion, or generation.

    Args:
        input_text: Raw user text in the source language.
        source_lang: Key of LANGUAGE_CODES for the input's language.
        target_lang: Key of LANGUAGE_CODES for the desired output language.
        task_type: One of "Translation", "Text Completion", "Text Generation".
        max_length: Maximum length of the generated token sequence.

    Returns:
        The generated text with special tokens stripped ("" for empty input).
    """
    # Robustness: nothing to generate from — avoid running beam search on
    # an empty prompt.
    if not input_text or not input_text.strip():
        return ""

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE: .to() moves the module-level model in place; model_gpu is the
    # same object, kept so the cleanup below reads naturally.
    model_gpu = model.to(device)

    src_code = LANGUAGE_CODES[source_lang]
    tgt_code = LANGUAGE_CODES[target_lang]

    # IndicBART expects "<text> </s> <2xx>" as encoder input and the target
    # language tag as the decoder start token.
    if task_type == "Translation":
        formatted_input = f"{input_text} </s> {src_code}"
    elif task_type == "Text Completion":
        # For completion, condition on the target language tag directly.
        formatted_input = f"{input_text} </s> {tgt_code}"
    else:  # Text Generation
        formatted_input = f"{input_text} </s> {src_code}"
    # All three tasks decode starting from the target-language tag.
    decoder_start_token = tgt_code

    # Tokenize and move tensors to the generation device.
    inputs = tokenizer(
        formatted_input, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Resolve the language-tag token id. Keep the private helper the original
    # relied on (it handles added-vocabulary tokens), but fall back to the
    # public API so a tokenizer-version change doesn't crash the app.
    try:
        decoder_start_token_id = tokenizer._convert_token_to_id_with_added_voc(decoder_start_token)
    except AttributeError:
        decoder_start_token_id = tokenizer.convert_tokens_to_ids(decoder_start_token)

    try:
        with torch.no_grad():
            outputs = model_gpu.generate(
                **inputs,
                decoder_start_token_id=decoder_start_token_id,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )
        generated_text = tokenizer.decode(
            outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
    finally:
        # BUGFIX: always move the model back to CPU and release cached GPU
        # memory, even when generation raises — the original leaked the model
        # on the GPU on any exception.
        model_gpu.cpu()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return generated_text
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="IndicBART Multilingual Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🇮🇳 IndicBART Multilingual Assistant
    Experience IndicBART - trained on **11 Indian languages**! Perfect for translation, text completion, and multilingual generation.
    **Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English
    """)

    # Both language dropdowns share the same set of choices.
    language_choices = list(LANGUAGE_CODES.keys())

    with gr.Row():
        # Left column: text in, text out, trigger button.
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text in any supported language...",
                lines=3,
            )
            output_text = gr.Textbox(label="Generated Output", lines=5, interactive=False)
            generate_btn = gr.Button("Generate", variant="primary", size="lg")

        # Right column: task and generation settings.
        with gr.Column(scale=1):
            task_type = gr.Dropdown(
                choices=["Translation", "Text Completion", "Text Generation"],
                value="Translation",
                label="Task Type",
            )
            source_lang = gr.Dropdown(
                choices=language_choices, value="English", label="Source Language"
            )
            target_lang = gr.Dropdown(
                choices=language_choices, value="Hindi", label="Target Language"
            )
            max_length = gr.Slider(
                minimum=50, maximum=300, value=100, step=10, label="Max Length"
            )

    # Clickable example rows: (text, source, target, task, max_length).
    gr.Markdown("### 💡 Try these examples:")
    examples = [
        ["Hello, how are you?", "English", "Hindi", "Translation", 100],
        ["मैं एक छात्र हूं", "Hindi", "English", "Translation", 100],
        ["আমি ভাত খাই", "Bengali", "English", "Translation", 100],
        ["भारत एक", "Hindi", "Hindi", "Text Completion", 150],
        ["The capital of India", "English", "English", "Text Completion", 100],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_text, source_lang, target_lang, task_type, max_length],
        outputs=output_text,
        fn=generate_response,
    )

    # Wire the button to the model call.
    generate_btn.click(
        generate_response,
        inputs=[input_text, source_lang, target_lang, task_type, max_length],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()