fereen5 commited on
Commit
28b803a
·
verified ·
1 Parent(s): c9ce27e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from collections import deque
from datetime import datetime
from functools import lru_cache

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
+
8
# Language codes
# Display name -> NLLB-200 language code (ISO 639-3 + script suffix).
LANGUAGE_CODES = {
    "English": "eng_Latn",
    "Korean": "kor_Hang",
    "Japanese": "jpn_Jpan",
    "Chinese": "zho_Hans",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Russian": "rus_Cyrl",
    "Portuguese": "por_Latn",
    "Italian": "ita_Latn",
    "Burmese": "mya_Mymr",
    "Thai": "tha_Thai",
}
14
+
15
# Translation history class
class TranslationHistory:
    """Bounded, newest-first log of completed translations.

    Keeps at most ``MAX_ENTRIES`` records; the most recent translation is
    always at index 0 of the list returned by :meth:`get`.
    """

    # Cap so the in-memory log cannot grow without bound.
    MAX_ENTRIES = 100

    def __init__(self):
        # deque(maxlen=...) gives O(1) appendleft plus automatic eviction
        # of the oldest entry, replacing the original list.insert(0, ...)
        # + pop() pattern, which cost O(n) per insertion.
        self.history = deque(maxlen=self.MAX_ENTRIES)

    def add(self, src, translated, src_lang, tgt_lang):
        """Record one translation (newest first) with an ISO-8601 timestamp."""
        self.history.appendleft({
            "source": src, "translated": translated,
            "src_lang": src_lang, "tgt_lang": tgt_lang,
            "timestamp": datetime.now().isoformat()
        })

    def get(self):
        """Return the history as a plain list (newest first).

        Converted to ``list`` so JSON-oriented consumers (e.g. gr.JSON)
        receive the same type the original implementation produced.
        """
        return list(self.history)

    def clear(self):
        """Drop all recorded entries. Returns None."""
        self.history.clear()
34
+
35
# Initialize history
# Single module-level history instance shared by all UI callbacks below.
history = TranslationHistory()

# Load model and tokenizer
# NLLB-200 distilled 600M — the multilingual seq2seq model used for every
# translation. Loaded once here as a module-level side effect.
model_name = "facebook/nllb-200-distilled-600M"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # Prefer GPU when available; `device` is reused by cached_translate
    # to move input tensors alongside the model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
except Exception as e:
    # Fail fast at startup: the app cannot function without the model.
    raise RuntimeError(f"Failed to load model: {e}")
47
+
48
# Cache translation
@lru_cache(maxsize=512)
def _generate_translation(text, src_code, tgt_code, max_length, temperature):
    """Run the NLLB model on `text` and return the decoded translation.

    Pure generation step, cached on (text, language codes, settings) so
    repeated identical requests skip the expensive model call. Raises on
    model/tokenizer errors; the caller converts those into a user-facing
    message (and, being outside the cache, errors are never cached).
    """
    # NLLB prepends the source-language token during tokenization; the
    # original code computed src_code but never applied it, so the source
    # language silently fell back to the tokenizer default.
    tokenizer.src_lang = src_code
    inputs = tokenizer(text, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # NLLB selects the output language by forcing its language token as
    # the first generated token.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)

    # Inference only — no_grad avoids building autograd state.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=max_length,
            # NOTE(review): temperature has no effect unless do_sample=True;
            # kept for interface compatibility — confirm the intended
            # decoding mode (beam search vs. sampling).
            temperature=temperature,
            num_beams=5,
            early_stopping=True
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)


def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
    """Translate `text` from `src_lang` to `tgt_lang`.

    Languages may be display names ("Korean") or raw NLLB codes. Returns
    "" for blank input and an "Translation error: ..." string on failure.

    History is recorded on every successful call — including cache hits.
    The original put @lru_cache on this whole function, which skipped
    history.add on repeats and permanently cached error strings; the cache
    now wraps only the pure generation step above.
    """
    if not text.strip():
        return ""
    try:
        src_code = LANGUAGE_CODES.get(src_lang, src_lang)
        tgt_code = LANGUAGE_CODES.get(tgt_lang, tgt_lang)
        result = _generate_translation(text, src_code, tgt_code, max_length, temperature)
        history.add(text, result, src_lang, tgt_lang)
        return result
    except Exception as e:
        return f"Translation error: {e}"
75
+
76
# Swap languages
def swap_langs(src, tgt):
    """Return (tgt, src) — wired to the ⇄ button to swap the dropdowns.

    A def instead of the original named lambda (PEP 8 E731); behavior
    is identical.
    """
    return tgt, src
78
+
79
# Translate file
def translate_file(file, src_lang, tgt_lang, max_length, temperature):
    """Translate a text file line by line via cached_translate.

    Parameters:
        file: raw UTF-8 file contents as bytes/bytearray (the original
            contract), or — generalized, backward-compatibly — an
            already-decoded string of file contents.
        src_lang, tgt_lang: language names or NLLB codes, forwarded as-is.
        max_length, temperature: generation settings, forwarded as-is.

    Returns the translated non-blank lines joined by newlines, or a
    "File translation error: ..." string on failure. Blank lines are
    skipped, so they are not preserved in the output.
    """
    try:
        if isinstance(file, (bytes, bytearray)):
            text = bytes(file).decode("utf-8")
        else:
            text = str(file)
        translated = [
            cached_translate(line, src_lang, tgt_lang, max_length, temperature)
            for line in text.splitlines()
            if line.strip()
        ]
        return "\n".join(translated)
    except Exception as e:
        return f"File translation error: {e}"
87
+
88
# Custom CSS to improve UI
# Injected via gr.Blocks(css=...): rounded/bold buttons, teal borders on
# text inputs, and an orange glow on focus.
gradio_style = """
.gr-button { border-radius: 12px !important; padding: 10px 20px !important; font-weight: bold; }
textarea, input[type=text] { border: 2px solid #00ADB5 !important; border-radius: 10px; transition: 0.2s; }
textarea:focus, input[type=text]:focus { border-color: #FF5722 !important; box-shadow: 0 0 8px #FF5722 !important; }
"""
94
+
95
# Build the Gradio UI: two tabs (text / file translation) sharing the same
# language list and per-tab advanced generation settings.
with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🌍 PolyLinguaAI: Translate Across Worlds
    Translate instantly between 12+ languages using Facebook's NLLB model.
    """)

    with gr.Tab("🌐 Text Translator"):
        # Language pickers with a swap button between them.
        with gr.Row():
            src_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="🌐 From", value="English")
            swap = gr.Button("⇄")
            tgt_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="🎯 To", value="Korean")

        with gr.Row():
            input_text = gr.Textbox(lines=3, label="✍️ Input Text")
            output_text = gr.Textbox(lines=3, label="📤 Translated Output", interactive=False)

        with gr.Row():
            translate = gr.Button("🚀 Translate", variant="primary")
            clear = gr.Button("🧽 Clear")

        # Generation knobs forwarded directly to cached_translate.
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")

        # Read-only view of the module-level TranslationHistory instance.
        with gr.Accordion("📜 Translation History", open=False):
            history_json = gr.JSON(label="Recent Translations")
            with gr.Row():
                refresh = gr.Button("🔄 Refresh")
                clear_history = gr.Button("🧹 Clear History")

    with gr.Tab("📁 File Translator"):
        file_input = gr.File(label="📂 Upload .txt File", file_types=[".txt"])
        file_src = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="📌 From", value="English")
        file_tgt = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="📌 To", value="Korean")
        file_translate = gr.Button("📄 Translate File", variant="primary")
        file_result = gr.Textbox(label="📑 File Output", lines=10, interactive=False)

        # Separate settings so the file tab doesn't share the text tab's sliders.
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            f_max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
            f_temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")

    # Events
    translate.click(cached_translate, [input_text, src_lang, tgt_lang, max_length, temperature], output_text)
    clear.click(lambda: ("", ""), None, [input_text, output_text])
    swap.click(swap_langs, [src_lang, tgt_lang], [src_lang, tgt_lang])
    refresh.click(lambda: history.get(), None, history_json)
    # history.clear() returns None, so `... or []` yields an empty list for the JSON view.
    clear_history.click(lambda: history.clear() or [], None, history_json)
    # NOTE(review): `file.read()` assumes gr.File hands back a file-like object
    # whose read() yields bytes (which translate_file then decodes). Newer
    # Gradio versions return a filepath string with no .read() — confirm the
    # Gradio version this app targets.
    file_translate.click(lambda file, src, tgt, ml, t: translate_file(file.read(), src, tgt, ml, t),
                         [file_input, file_src, file_tgt, f_max_length, f_temp], file_result)

    gr.Markdown(f"""
    ### 🛠 Model Info
    - Model: `{model_name}`
    - Device: `{device}`
    - Cached Translations: 512
    """)
151
+
152
if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio tunnel link in
    # addition to the local server — confirm this is intended for the
    # deployment target (some hosts ignore or disallow it).
    demo.launch(share=True)