Spaces: Sleeping
from functools import lru_cache

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
# Language label (as shown in the UI) -> Hugging Face model repo id.
model_map = dict(
    Tamil="Harisanth/mbart-chatbot-tamil",
    Sinhala="Harisanth/mbart-chatbot-sinhala",
    English="Harisanth/mbart-chatbot-english",
    Tanglish="Harisanth/mbart-chatbot-tanglish",
)

# Run on GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# UI language label -> mBART language code used as src_lang / forced BOS token.
_LANG_CODES = {
    "Tamil": "ta_IN",
    "Sinhala": "si_LK",
    "English": "en_XX",
    "Tanglish": "en_XX",
}


@lru_cache(maxsize=None)
def _load_model(repo):
    """Load and cache the (tokenizer, model) pair for *repo*.

    The original code re-downloaded and re-loaded the checkpoint on every
    chat turn; lru_cache makes each repo load exactly once per process.
    The key space is bounded by the four entries of model_map, so an
    unbounded cache cannot grow without limit.
    """
    tok = AutoTokenizer.from_pretrained(repo)
    # These "mbart-chatbot-*" checkpoints follow the mBART seq2seq pattern
    # (src_lang / lang_code_to_id / forced_bos_token_id below), so they must
    # be loaded as encoder-decoder models. AutoModelForCausalLM would load
    # only a decoder and break generation — use AutoModelForSeq2SeqLM.
    mod = AutoModelForSeq2SeqLM.from_pretrained(repo).to(device)
    mod.eval()  # inference only; disable dropout etc.
    return tok, mod


def chat_fn(text, lang):
    """Generate a chatbot reply to *text* in the selected *lang*.

    Parameters
    ----------
    text : str
        The user's message.
    lang : str
        One of the keys of ``model_map`` ("Tamil", "Sinhala", "English",
        "Tanglish"); a KeyError propagates for anything else.

    Returns
    -------
    str
        The decoded model reply (empty string for blank input).
    """
    if not text or not text.strip():
        return ""  # nothing to answer; avoids generating from empty input
    repo = model_map[lang]
    tok, mod = _load_model(repo)
    # mBART needs the source language set before tokenizing.
    tok.src_lang = _LANG_CODES[lang]
    inputs = tok(text, return_tensors="pt").to(device)
    with torch.no_grad():  # inference: skip autograd bookkeeping
        out = mod.generate(
            **inputs,
            max_length=100,
            # Force decoding to start in the same language as the input.
            forced_bos_token_id=tok.lang_code_to_id[tok.src_lang],
        )
    return tok.decode(out[0], skip_special_tokens=True)
# Supported UI language choices; must match the keys of model_map.
LANGUAGE_CHOICES = ["Tamil", "Sinhala", "English", "Tanglish"]

# Gradio UI: free-text input plus a language selector, text output.
iface = gr.Interface(
    fn=chat_fn,
    inputs=[
        "text",
        gr.Radio(LANGUAGE_CHOICES, label="Language"),
    ],
    outputs="text",
    title="Multilingual Chatbot",
    description="A fine-tuned mBART chatbot by Harisanth",
)

if __name__ == "__main__":
    iface.launch()