Sabbir772 committed on
Commit
5216f5c
·
verified ·
1 Parent(s): 27d2c8c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import gradio as gr
4
+
5
# Model setup: the checkpoint directory lives inside the Space repository.
model_path = "./banglat5_bn_sy"  # path inside Space

# Prefer CUDA when PyTorch reports it as available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# use_fast=False explicitly requests the slow (Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Load the seq2seq weights and move them onto the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model.to(device)
12
+
13
+ # Translation function
14
def translate(text, source_lang):
    """Translate ``text`` using the module-level seq2seq model.

    The selected source language determines a tag ("<BN>" for Bangla,
    "<SY>" for Sylheti) that is prepended to the text before tokenization.

    Args:
        text: The text to translate. ``None``, empty, or whitespace-only
            input returns an empty string without invoking the model.
        source_lang: Either "Bangla" or "Sylheti".

    Returns:
        The decoded first generated sequence with special tokens removed;
        the message "Invalid language selected." when ``source_lang`` is
        not one of the two supported values; or "" when there is nothing
        to translate.
    """
    language_tags = {"Bangla": "<BN>", "Sylheti": "<SY>"}

    # Guard on str so an unhashable or missing selection is reported as
    # invalid instead of raising from the dict lookup.
    prefix = language_tags.get(source_lang) if isinstance(source_lang, str) else None
    if prefix is None:
        return "Invalid language selected."

    # Without this check the model would be asked to "translate" the bare
    # language tag (or the literal string "None") and produce noise.
    if text is None or not str(text).strip():
        return ""

    input_text = f"{prefix} {text}"
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    ).to(device)

    # Inference only: gradient tracking is disabled during generation.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=128,
            num_beams=4,
            early_stopping=True,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
35
+
36
# Web UI: a text box and a source-language selector feed translate(),
# and the result is shown in a second text box.
iface = gr.Interface(
    title="Bangla ↔ Sylheti Dialect Translator (Fine-tuned T5)",
    description="Translate between Bangla and Sylheti using a LoRA-finetuned Flan-T5 model.",
    fn=translate,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Radio(["Bangla", "Sylheti"], label="Source Language"),
    ],
    outputs=gr.Textbox(label="Translated Text"),
)

# Start serving the interface as soon as the script is executed.
iface.launch()