Harisanth commited on
Commit
58feb81
·
verified ·
1 Parent(s): 2ab5145

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -0
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+
5
+ model_map = {
6
+ "Tamil": "Harisanth/mbart-chatbot-tamil",
7
+ "Sinhala": "Harisanth/mbart-chatbot-sinhala",
8
+ "English": "Harisanth/mbart-chatbot-english",
9
+ "Tanglish": "Harisanth/mbart-chatbot-tanglish"
10
+ }
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ def chat_fn(text, lang):
15
+ repo = model_map[lang]
16
+ tok = AutoTokenizer.from_pretrained(repo)
17
+ mod = AutoModelForCausalLM.from_pretrained(repo).to(device)
18
+ tok.src_lang = {'Tamil':'ta_IN','Sinhala':'si_LK','English':'en_XX','Tanglish':'en_XX'}[lang]
19
+ inp = tok(text, return_tensors="pt").to(device)
20
+ out = mod.generate(**inp, max_length=100, forced_bos_token_id=tok.lang_code_to_id[tok.src_lang])
21
+ return tok.decode(out[0], skip_special_tokens=True)
22
+
23
+ iface = gr.Interface(
24
+ fn=chat_fn,
25
+ inputs=["text", gr.Radio(["Tamil","Sinhala","English","Tanglish"], label="Language")],
26
+ outputs="text",
27
+ title="Multilingual Chatbot",
28
+ description="A fine-tuned mBART chatbot by Harisanth"
29
+ )
30
+
31
+ if __name__ == "__main__":
32
+ iface.launch()