"""Gradio demo for the ai4bharat/IndicBART multilingual seq2seq model.

The module loads the tokenizer and model once at import time (on CPU),
exposes ``generate_response`` as the inference entry point, and builds a
``gr.Blocks`` UI bound to the module-level name ``demo``.
"""

# NOTE(review): `spaces` is deliberately kept as the first import; the ZeroGPU
# package is expected to be imported before anything initialises CUDA - confirm
# against the Hugging Face ZeroGPU documentation before reordering.
import spaces
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Model configuration
model_name = "ai4bharat/IndicBART"

# Sentence terminator used in the IndicBART input format "Sentence </s> <2xx>".
EOS_TOKEN = "</s>"
# Upper bound on the number of input tokens passed to the model.
MAX_INPUT_TOKENS = 512

# Load tokenizer and model on CPU; the model is moved to the GPU per request.
print("Loading IndicBART tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    do_lower_case=False,
    use_fast=False,
    keep_accents=True,
)

print("Loading IndicBART model on CPU...")
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="cpu",
)

# Language mapping: UI label -> IndicBART language tag token.
LANGUAGE_CODES = {
    "Assamese": "<2as>",
    "Bengali": "<2bn>",
    "English": "<2en>",
    "Gujarati": "<2gu>",
    "Hindi": "<2hi>",
    "Kannada": "<2kn>",
    "Malayalam": "<2ml>",
    "Marathi": "<2mr>",
    "Oriya": "<2or>",
    "Punjabi": "<2pa>",
    "Tamil": "<2ta>",
    "Telugu": "<2te>",
}


@spaces.GPU(duration=60)
def generate_response(input_text, source_lang, target_lang, task_type, max_length):
    """Generate text with IndicBART for the given input and language pair.

    Args:
        input_text: Text typed by the user.
        source_lang: Key of ``LANGUAGE_CODES`` naming the input language.
        target_lang: Key of ``LANGUAGE_CODES`` naming the output language.
        task_type: One of "Translation", "Text Completion" or
            "Text Generation". For "Text Completion" the input is tagged with
            the target language; otherwise with the source language.
        max_length: Maximum generated sequence length (cast to ``int``).

    Returns:
        The decoded model output, or an empty string for blank input.

    Raises:
        gr.Error: If either language is not a key of ``LANGUAGE_CODES``.
    """
    # Nothing to generate from; avoid feeding a bare language tag to the model.
    if not input_text or not input_text.strip():
        return ""

    try:
        src_code = LANGUAGE_CODES[source_lang]
        tgt_code = LANGUAGE_CODES[target_lang]
    except KeyError as err:
        # Happens e.g. when a dropdown has been cleared in the UI.
        raise gr.Error(f"Unsupported language: {err.args[0]}") from err

    # The encoder input carries a language tag; the decoder is always started
    # with the target-language tag.
    input_tag = tgt_code if task_type == "Text Completion" else src_code

    # IndicBART's documented input format is "Sentence </s> <2xx>", tokenized
    # WITHOUT the tokenizer's own special tokens.
    # NOTE(review): the model card also says non-Devanagari scripts should be
    # transliterated to Devanagari first (Indic NLP Library); that step is not
    # performed here - confirm whether it is needed for this demo.
    formatted_input = f"{input_text.strip()} {EOS_TOKEN} {input_tag}"

    encoded = tokenizer(
        formatted_input,
        add_special_tokens=False,
        # The seq2seq model does not take token_type_ids; do not produce them.
        return_token_type_ids=False,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    try:
        with torch.no_grad():
            outputs = model.generate(
                # Pass only the tensors the model accepts.
                input_ids=encoded["input_ids"].to(device),
                attention_mask=encoded["attention_mask"].to(device),
                decoder_start_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                # The slider may deliver a float; generate() expects an int.
                max_length=int(max_length),
                num_beams=4,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id,
                # Resolve "</s>" explicitly, as the model card does, rather
                # than relying on the tokenizer's configured eos token.
                eos_token_id=tokenizer.convert_tokens_to_ids(EOS_TOKEN),
                use_cache=True,
            )
        generated_text = tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
    finally:
        # Always move the model back to CPU and release cached GPU memory,
        # even when generation fails.
        model.cpu()
        if device == "cuda":
            torch.cuda.empty_cache()

    return generated_text


# Create Gradio interface
with gr.Blocks(title="IndicBART Multilingual Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🇮🇳 IndicBART Multilingual Assistant

    Experience IndicBART - trained on **11 Indian languages**! Perfect for translation, text completion, and multilingual generation.

    **Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English
    """)

    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text in any supported language...",
                lines=3,
            )
            output_text = gr.Textbox(
                label="Generated Output",
                lines=5,
                interactive=False,
            )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")

        with gr.Column(scale=1):
            task_type = gr.Dropdown(
                choices=["Translation", "Text Completion", "Text Generation"],
                value="Translation",
                label="Task Type",
            )
            source_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES),
                value="English",
                label="Source Language",
            )
            target_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES),
                value="Hindi",
                label="Target Language",
            )
            max_length = gr.Slider(
                minimum=50,
                maximum=300,
                value=100,
                step=10,
                label="Max Length",
            )

    # Examples
    gr.Markdown("### 💡 Try these examples:")
    examples = [
        ["Hello, how are you?", "English", "Hindi", "Translation", 100],
        ["मैं एक छात्र हूं", "Hindi", "English", "Translation", 100],
        ["আমি ভাত খাই", "Bengali", "English", "Translation", 100],
        ["भारत एक", "Hindi", "Hindi", "Text Completion", 150],
        ["The capital of India", "English", "English", "Text Completion", 100],
    ]

    gr.Examples(
        examples=examples,
        inputs=[input_text, source_lang, target_lang, task_type, max_length],
        outputs=output_text,
        fn=generate_response,
    )

    # Connect generate button
    generate_btn.click(
        generate_response,
        inputs=[input_text, source_lang, target_lang, task_type, max_length],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()