Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,23 +1,22 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 3 |
-
import
|
| 4 |
from datetime import datetime
|
| 5 |
from functools import lru_cache
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
#
|
| 8 |
-
# Translator Tool Components
|
| 9 |
-
# ---------------------------
|
| 10 |
-
|
| 11 |
LANGUAGE_CODES = {
|
| 12 |
"English": "eng_Latn", "Korean": "kor_Hang", "Japanese": "jpn_Jpan", "Chinese": "zho_Hans",
|
| 13 |
"Spanish": "spa_Latn", "French": "fra_Latn", "German": "deu_Latn", "Russian": "rus_Cyrl",
|
| 14 |
"Portuguese": "por_Latn", "Italian": "ita_Latn", "Burmese": "mya_Mymr", "Thai": "tha_Thai"
|
| 15 |
}
|
| 16 |
|
|
|
|
| 17 |
class TranslationHistory:
|
| 18 |
def __init__(self):
|
| 19 |
self.history = []
|
| 20 |
-
|
| 21 |
def add(self, src, translated, src_lang, tgt_lang):
|
| 22 |
self.history.insert(0, {
|
| 23 |
"source": src, "translated": translated,
|
|
@@ -26,65 +25,123 @@ class TranslationHistory:
|
|
| 26 |
})
|
| 27 |
if len(self.history) > 100:
|
| 28 |
self.history.pop()
|
| 29 |
-
|
| 30 |
def get(self):
|
| 31 |
return self.history
|
| 32 |
-
|
| 33 |
def clear(self):
|
| 34 |
self.history = []
|
| 35 |
|
| 36 |
history = TranslationHistory()
|
| 37 |
|
|
|
|
| 38 |
model_name = "facebook/nllb-200-distilled-600M"
|
| 39 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 40 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 41 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 42 |
model.to(device)
|
| 43 |
|
|
|
|
| 44 |
@lru_cache(maxsize=512)
|
| 45 |
def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
|
| 46 |
-
if not text.strip():
|
| 47 |
-
return ""
|
| 48 |
src_code = LANGUAGE_CODES.get(src_lang, src_lang)
|
| 49 |
tgt_code = LANGUAGE_CODES.get(tgt_lang, tgt_lang)
|
| 50 |
input_tokens = tokenizer(text, return_tensors="pt", padding=True)
|
| 51 |
input_tokens = {k: v.to(device) for k, v in input_tokens.items()}
|
| 52 |
forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
|
| 53 |
-
output = model.generate(
|
| 54 |
-
**input_tokens,
|
| 55 |
forced_bos_token_id=forced_bos_token_id,
|
| 56 |
-
max_length=max_length,
|
| 57 |
-
|
| 58 |
-
num_beams=5,
|
| 59 |
-
early_stopping=True
|
| 60 |
)
|
| 61 |
result = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 62 |
history.add(text, result, src_lang, tgt_lang)
|
| 63 |
return result
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
translate_btn = gr.Button("Translate")
|
| 79 |
-
clear_btn = gr.Button("Clear")
|
| 80 |
-
translate_btn.click(cached_translate, [input_text, src_lang, tgt_lang], output_text)
|
| 81 |
-
clear_btn.click(lambda: ("", ""), None, [input_text, output_text])
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
-
demo.launch(share=
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 3 |
+
import os
|
| 4 |
from datetime import datetime
|
| 5 |
from functools import lru_cache
|
| 6 |
+
import torch
|
| 7 |
+
import requests
|
| 8 |
|
| 9 |
+
# Language Codes
# Display names of supported languages mapped to their NLLB-200 FLORES-200
# codes (used both for the tokenizer source language and the forced BOS token).
LANGUAGE_CODES = {
    "English": "eng_Latn",
    "Korean": "kor_Hang",
    "Japanese": "jpn_Jpan",
    "Chinese": "zho_Hans",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Russian": "rus_Cyrl",
    "Portuguese": "por_Latn",
    "Italian": "ita_Latn",
    "Burmese": "mya_Mymr",
    "Thai": "tha_Thai",
}
|
| 15 |
|
| 16 |
+
# Translation History
|
| 17 |
class TranslationHistory:
|
| 18 |
def __init__(self):
|
| 19 |
self.history = []
|
|
|
|
| 20 |
def add(self, src, translated, src_lang, tgt_lang):
|
| 21 |
self.history.insert(0, {
|
| 22 |
"source": src, "translated": translated,
|
|
|
|
| 25 |
})
|
| 26 |
if len(self.history) > 100:
|
| 27 |
self.history.pop()
|
|
|
|
| 28 |
def get(self):
|
| 29 |
return self.history
|
|
|
|
| 30 |
def clear(self):
|
| 31 |
self.history = []
|
| 32 |
|
| 33 |
history = TranslationHistory()
|
| 34 |
|
| 35 |
+
# Load Translation Model
# NLLB-200 distilled 600M checkpoint; placed on GPU when one is available.
# Note: .to(device) returns the model itself, so the chained form below is
# equivalent to a separate model.to(device) statement.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
|
| 41 |
|
| 42 |
+
# Translation Function

@lru_cache(maxsize=512)
def _nllb_translate(text, src_code, tgt_code, max_length):
    """Run the NLLB model on *text* (FLORES codes) — pure and cacheable.

    Kept free of side effects so lru_cache can safely short-circuit it.
    """
    # BUGFIX: the original computed src_code but never applied it, so every
    # input was tokenized with the tokenizer's default source language.
    tokenizer.src_lang = src_code
    input_tokens = tokenizer(text, return_tensors="pt", padding=True)
    input_tokens = {k: v.to(device) for k, v in input_tokens.items()}
    output = model.generate(
        **input_tokens,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
        max_length=max_length,
        num_beams=5,
        early_stopping=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
    """Translate *text* from src_lang to tgt_lang and record it in history.

    Args:
        text: source text; blank/whitespace-only input returns "".
        src_lang / tgt_lang: display names (keys of LANGUAGE_CODES) or raw
            FLORES codes — unknown names pass through unchanged.
        max_length: generation length cap.
        temperature: accepted for interface compatibility only; generation
            uses beam search (num_beams=5), which ignores sampling temperature.

    Returns:
        The translated string.
    """
    if not text.strip():
        return ""
    src_code = LANGUAGE_CODES.get(src_lang, src_lang)
    tgt_code = LANGUAGE_CODES.get(tgt_lang, tgt_lang)
    result = _nllb_translate(text, src_code, tgt_code, max_length)
    # BUGFIX: history is recorded OUTSIDE the cache. The original put
    # history.add inside the lru_cache'd function, so repeated identical
    # requests silently never appeared in the translation history.
    history.add(text, result, src_lang, tgt_lang)
    return result
|
| 59 |
|
| 60 |
+
def translate_file(file, src_lang, tgt_lang, max_length, temperature):
    """Translate an uploaded text file line by line.

    Args:
        file: the file content — raw ``bytes`` (decoded as UTF-8) or an
            already-decoded ``str``. (Generalized: the original only accepted
            bytes and sent str input into the error path.)
        src_lang / tgt_lang: language names as for cached_translate.
        max_length / temperature: forwarded to cached_translate.

    Returns:
        The translated lines joined with newlines, or a best-effort
        "File translation error: ..." string on any failure (UI-facing,
        deliberately never raises).
    """
    try:
        text = file.decode("utf-8") if isinstance(file, bytes) else str(file)
        translated = [
            cached_translate(line, src_lang, tgt_lang, max_length, temperature)
            for line in text.splitlines()
            if line.strip()  # skip blank lines rather than translating them
        ]
        return "\n".join(translated)
    except Exception as e:
        return f"File translation error: {e}"
|
| 67 |
+
|
| 68 |
+
# Summarization
# Hosted HF Inference API endpoint for BART-large-CNN summarization.
API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
# SECURITY FIX: the original shipped a real HF token as the env-var fallback.
# A committed token must be treated as leaked and revoked; the key now comes
# ONLY from the environment, and the UI's info panel reports it as missing
# when unset (empty string is falsy).
HF_API_KEY = os.environ.get("HF_API_KEY", "")
headers = {"Authorization": f"Bearer {HF_API_KEY}"}
|
| 72 |
+
|
| 73 |
+
def summarize_text(text, max_length):
    """Summarize *text* via the hosted BART-large-CNN Inference API.

    Args:
        text: input text; blank/whitespace-only input returns "".
        max_length: summary length cap; min_length is derived as
            max(10, max_length // 4).

    Returns:
        The summary string on success, or an "Error: ..." string on any
        network/API failure (UI-facing, deliberately never raises).
    """
    if not text.strip():
        return ""
    min_length = max(10, max_length // 4)
    try:
        response = requests.post(
            API_URL,
            headers=headers,
            json={
                "inputs": text,
                "parameters": {"min_length": min_length, "max_length": max_length},
            },
            # FIX: the original had no timeout and could hang the UI forever
            # on a stalled connection.
            timeout=60,
        )
        result = response.json()
    except Exception as e:
        # FIX: network errors / non-JSON responses previously escaped as
        # unhandled exceptions; surface them like other API errors instead.
        return "Error: " + str(e)
    return result[0]["summary_text"] if isinstance(result, list) else "Error: " + str(result)
|
| 82 |
+
|
| 83 |
+
# UI Styling
# Custom CSS injected into the Gradio Blocks app: rounded bold buttons,
# teal borders on text inputs, and an orange glow on focus.
gradio_style = """
.gr-button { border-radius: 12px !important; padding: 10px 20px !important; font-weight: bold; }
textarea, input[type=text] { border: 2px solid #00ADB5 !important; border-radius: 10px; transition: 0.2s; }
textarea:focus, input[type=text]:focus { border-color: #FF5722 !important; box-shadow: 0 0 8px #FF5722 !important; }
"""
|
| 89 |
+
|
| 90 |
+
# Gradio UI: three tabs (text translator, file translator, summarizer)
# plus an info footer. Widget variables are module-level so the event
# wiring below can reference them.
with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🤖 AI Toolbox: Translate & Summarize")

    with gr.Tab("🌐 Text Translator"):
        with gr.Row():
            src_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="🌐 From", value="English")
            swap = gr.Button("⇄")
            tgt_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="🎯 To", value="Korean")
        with gr.Row():
            input_text = gr.Textbox(lines=3, label="✍️ Input Text")
            output_text = gr.Textbox(lines=3, label="📤 Translated Output", interactive=False)
        with gr.Row():
            translate = gr.Button("🚀 Translate", variant="primary")
            clear = gr.Button("🧽 Clear")
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        with gr.Accordion("📜 Translation History", open=False):
            history_json = gr.JSON(label="Recent Translations")
            with gr.Row():
                refresh = gr.Button("🔄 Refresh")
                clear_history = gr.Button("🧹 Clear History")

    with gr.Tab("📁 File Translator"):
        file_input = gr.File(label="📂 Upload .txt File", file_types=[".txt"])
        file_src = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="📌 From", value="English")
        file_tgt = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="📌 To", value="Korean")
        file_translate = gr.Button("📄 Translate File", variant="primary")
        file_result = gr.Textbox(label="📑 File Output", lines=10, interactive=False)
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            f_max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
            f_temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")

    with gr.Tab("📝 Text Summarizer"):
        summary_input = gr.Textbox(lines=5, label="📚 Enter text to summarize")
        summary_length = gr.Slider(32, 512, value=128, step=8, label="📏 Max Length")
        summary_output = gr.Textbox(label="🧾 Summary", lines=5, interactive=False)
        summary_btn = gr.Button("🧠 Summarize")

    def _read_upload(f):
        # BUGFIX: gr.File yields a filepath string (modern Gradio default)
        # or a tempfile-like object exposing .name — the original handler
        # called file.read(), which fails for a plain path. Read the raw
        # bytes here; translate_file decodes them as UTF-8.
        path = getattr(f, "name", f)
        with open(path, "rb") as fh:
            return fh.read()

    # Button events
    translate.click(cached_translate, [input_text, src_lang, tgt_lang, max_length, temperature], output_text)
    clear.click(lambda: ("", ""), None, [input_text, output_text])
    swap.click(lambda s, t: (t, s), [src_lang, tgt_lang], [src_lang, tgt_lang])
    refresh.click(lambda: history.get(), None, history_json)
    # history.clear() returns None, so `or []` yields an empty JSON payload.
    clear_history.click(lambda: history.clear() or [], None, history_json)
    file_translate.click(lambda file, src, tgt, ml, t: translate_file(_read_upload(file), src, tgt, ml, t),
                         [file_input, file_src, file_tgt, f_max_length, f_temp], file_result)
    summary_btn.click(summarize_text, [summary_input, summary_length], summary_output)

    gr.Markdown(f"""
    ### 🔍 Info
    - Translator Model: `{model_name}` on `{device}`
    - Summarizer Model: `facebook/bart-large-cnn`
    - HF API Key: {'Loaded ✅' if HF_API_KEY else 'Missing ❌'}
    """)
|
| 145 |
|
| 146 |
+
# Launch only when executed as a script (not on import).
if __name__ == "__main__":
    # share=True requests a public gradio.live tunnel URL.
    demo.launch(share=True)
|