fereen5 commited on
Commit
652f4e8
·
verified ·
1 Parent(s): 6444b22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -37
app.py CHANGED
@@ -1,23 +1,22 @@
1
  import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
- import torch
4
  from datetime import datetime
5
  from functools import lru_cache
 
 
6
 
7
- # ---------------------------
8
- # Translator Tool Components
9
- # ---------------------------
10
-
11
  LANGUAGE_CODES = {
12
  "English": "eng_Latn", "Korean": "kor_Hang", "Japanese": "jpn_Jpan", "Chinese": "zho_Hans",
13
  "Spanish": "spa_Latn", "French": "fra_Latn", "German": "deu_Latn", "Russian": "rus_Cyrl",
14
  "Portuguese": "por_Latn", "Italian": "ita_Latn", "Burmese": "mya_Mymr", "Thai": "tha_Thai"
15
  }
16
 
 
17
  class TranslationHistory:
18
  def __init__(self):
19
  self.history = []
20
-
21
  def add(self, src, translated, src_lang, tgt_lang):
22
  self.history.insert(0, {
23
  "source": src, "translated": translated,
@@ -26,65 +25,123 @@ class TranslationHistory:
26
  })
27
  if len(self.history) > 100:
28
  self.history.pop()
29
-
30
  def get(self):
31
  return self.history
32
-
33
  def clear(self):
34
  self.history = []
35
 
36
  history = TranslationHistory()
37
 
 
38
  model_name = "facebook/nllb-200-distilled-600M"
39
  tokenizer = AutoTokenizer.from_pretrained(model_name)
40
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
  model.to(device)
43
 
 
44
  @lru_cache(maxsize=512)
45
  def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
46
- if not text.strip():
47
- return ""
48
  src_code = LANGUAGE_CODES.get(src_lang, src_lang)
49
  tgt_code = LANGUAGE_CODES.get(tgt_lang, tgt_lang)
50
  input_tokens = tokenizer(text, return_tensors="pt", padding=True)
51
  input_tokens = {k: v.to(device) for k, v in input_tokens.items()}
52
  forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
53
- output = model.generate(
54
- **input_tokens,
55
  forced_bos_token_id=forced_bos_token_id,
56
- max_length=max_length,
57
- temperature=temperature,
58
- num_beams=5,
59
- early_stopping=True
60
  )
61
  result = tokenizer.decode(output[0], skip_special_tokens=True)
62
  history.add(text, result, src_lang, tgt_lang)
63
  return result
64
 
65
- # ---------------------------
66
- # Gradio App with Tabs
67
- # ---------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- def create_demo():
70
- with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
71
- gr.Markdown("## 🤖 Smart AI Tools – Translate, Summarize, and More")
 
 
 
 
 
 
72
 
73
- with gr.Tab("🌐 Translator"):
74
- src_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="From", value="English")
75
- tgt_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="To", value="Korean")
76
- input_text = gr.Textbox(label="Input Text", lines=4)
77
- output_text = gr.Textbox(label="Translated Output", lines=4, interactive=False)
78
- translate_btn = gr.Button("Translate")
79
- clear_btn = gr.Button("Clear")
80
- translate_btn.click(cached_translate, [input_text, src_lang, tgt_lang], output_text)
81
- clear_btn.click(lambda: ("", ""), None, [input_text, output_text])
82
 
83
- with gr.Tab("📝 Summarizer"):
84
- gr.Markdown("### This tool summarizes long content using `facebook/bart-large-cnn`")
85
- gr.load("models/facebook/bart-large-cnn", provider="huggingface", label="Summarizer")
 
 
 
 
 
 
86
 
87
- return demo
 
 
 
 
 
88
 
89
- demo = create_demo()
90
- demo.launch(share=False)
 
1
  import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ import os
4
  from datetime import datetime
5
  from functools import lru_cache
6
+ import torch
7
+ import requests
8
 
9
# Language Codes — maps UI display names to NLLB (FLORES-200) language codes.
# Insertion order matters: the dropdowns list languages in this order.
LANGUAGE_CODES = {
    "English": "eng_Latn",
    "Korean": "kor_Hang",
    "Japanese": "jpn_Jpan",
    "Chinese": "zho_Hans",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Russian": "rus_Cyrl",
    "Portuguese": "por_Latn",
    "Italian": "ita_Latn",
    "Burmese": "mya_Mymr",
    "Thai": "tha_Thai",
}
15
 
16
+ # Translation History
17
  class TranslationHistory:
18
  def __init__(self):
19
  self.history = []
 
20
  def add(self, src, translated, src_lang, tgt_lang):
21
  self.history.insert(0, {
22
  "source": src, "translated": translated,
 
25
  })
26
  if len(self.history) > 100:
27
  self.history.pop()
 
28
  def get(self):
29
  return self.history
 
30
  def clear(self):
31
  self.history = []
32
 
33
  history = TranslationHistory()
34
 
35
# Load Translation Model — NLLB-200 distilled 600M checkpoint.
# NOTE: runs at import time; first start downloads ~2.5 GB of weights.
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Prefer CUDA when available; otherwise inference runs on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
41
 
42
# Translation Function
@lru_cache(maxsize=512)
def _generate_translation(text, src_code, tgt_code, max_length, temperature):
    """Pure, memoizable NLLB translation of *text* (src_code -> tgt_code).

    Deliberately free of side effects so lru_cache is safe. The original
    code put @lru_cache on the outer function, so the history.add() side
    effect only ran on cache misses — repeated translations were silently
    dropped from the history.
    """
    # truncation added so inputs longer than max_length don't error out
    input_tokens = tokenizer(text, return_tensors="pt", padding=True,
                             truncation=True, max_length=max_length)
    input_tokens = {k: v.to(device) for k, v in input_tokens.items()}
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
    # NOTE(review): temperature has no effect under beam search unless
    # do_sample=True; kept for interface compatibility with the UI slider.
    output = model.generate(
        **input_tokens,
        forced_bos_token_id=forced_bos_token_id,
        max_length=max_length,
        temperature=temperature,
        num_beams=5,
        early_stopping=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
    """Translate *text* between UI language names (or raw NLLB codes).

    Returns "" for blank input. Records every successful translation in the
    shared history — including cache hits.
    """
    if not text.strip():
        return ""
    src_code = LANGUAGE_CODES.get(src_lang, src_lang)
    tgt_code = LANGUAGE_CODES.get(tgt_lang, tgt_lang)
    # Sliders may deliver floats; normalize so cache keys stay stable and
    # generate() receives the expected types.
    result = _generate_translation(text, src_code, tgt_code,
                                   int(max_length), float(temperature))
    history.add(text, result, src_lang, tgt_lang)
    return result
59
 
60
def translate_file(file, src_lang, tgt_lang, max_length, temperature):
    """Translate the non-blank lines of an uploaded text file.

    *file* may be raw bytes or an already-decoded str (the original only
    accepted bytes and raised on anything else). Returns the translated
    lines joined by newlines, "" for a missing/empty file, or an error
    message string on failure.
    """
    if file is None:
        # No upload: report nothing rather than crashing on .decode()
        return ""
    try:
        text = file.decode("utf-8") if isinstance(file, (bytes, bytearray)) else str(file)
        translated = [
            cached_translate(line, src_lang, tgt_lang, max_length, temperature)
            for line in text.splitlines()
            if line.strip()
        ]
        return "\n".join(translated)
    except Exception as e:
        # Best-effort: surface the problem in the output box, keep the UI alive.
        return f"File translation error: {e}"
67
+
68
# Summarization — served remotely via the Hugging Face Inference API.
API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
# SECURITY FIX: the original shipped a real hf_... token as the env fallback,
# leaking a live credential in source control. Never hard-code API tokens;
# read from the environment and treat "unset" as anonymous access.
HF_API_KEY = os.environ.get("HF_API_KEY", "")
# Only send an Authorization header when a token is actually configured —
# "Bearer <empty>" would be rejected, while anonymous requests are allowed.
headers = {"Authorization": f"Bearer {HF_API_KEY}"} if HF_API_KEY else {}
72
+
73
def summarize_text(text, max_length):
    """Summarize *text* with bart-large-cnn via the HF Inference API.

    Returns "" for blank input, the summary string on success, or an
    "Error: ..." string on any network or API failure (so the UI never
    crashes on a bad response).
    """
    if not text.strip():
        return ""
    min_length = max(10, int(max_length) // 4)
    payload = {
        "inputs": text,
        "parameters": {"min_length": min_length, "max_length": int(max_length)},
    }
    try:
        # timeout added: the original could hang the event handler forever
        # on a stalled request; .json() could also raise on a non-JSON body.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
        result = response.json()
    except Exception as e:
        return f"Error: {e}"
    # The API returns [{"summary_text": ...}] on success and {"error": ...}
    # payloads otherwise; the original indexed blindly and could KeyError.
    if isinstance(result, list) and result and "summary_text" in result[0]:
        return result[0]["summary_text"]
    return "Error: " + str(result)
82
+
83
# UI Styling
# Custom CSS injected into gr.Blocks: rounded bold buttons, teal borders on
# text inputs, and an orange glow on focus.
gradio_style = """
.gr-button { border-radius: 12px !important; padding: 10px 20px !important; font-weight: bold; }
textarea, input[type=text] { border: 2px solid #00ADB5 !important; border-radius: 10px; transition: 0.2s; }
textarea:focus, input[type=text]:focus { border-color: #FF5722 !important; box-shadow: 0 0 8px #FF5722 !important; }
"""
89
+
90
# UI assembly — three tabs (text translation, file translation, summarization)
# wired to the model, history, and API helpers defined above.

def _read_upload(file):
    """Return the raw bytes of a gr.File value.

    Depending on Gradio version/config the handler receives either a
    tempfile-like object (with .name) or a plain filepath string; handle both.
    """
    path = getattr(file, "name", file)
    with open(path, "rb") as fh:
        return fh.read()


def _translate_upload(file, src, tgt, ml, temp):
    """File-tab click handler: guard a missing upload, then delegate."""
    if file is None:
        return "Please upload a .txt file first."
    try:
        data = _read_upload(file)
    except OSError as e:
        return f"File translation error: {e}"
    return translate_file(data, src, tgt, ml, temp)


with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 🤖 AI Toolbox: Translate & Summarize")

    with gr.Tab("🌐 Text Translator"):
        with gr.Row():
            src_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="🌐 From", value="English")
            swap = gr.Button("⇄")
            tgt_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="🎯 To", value="Korean")
        with gr.Row():
            input_text = gr.Textbox(lines=3, label="✍️ Input Text")
            output_text = gr.Textbox(lines=3, label="📤 Translated Output", interactive=False)
        with gr.Row():
            translate = gr.Button("🚀 Translate", variant="primary")
            clear = gr.Button("🧽 Clear")
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        with gr.Accordion("📜 Translation History", open=False):
            history_json = gr.JSON(label="Recent Translations")
            with gr.Row():
                refresh = gr.Button("🔄 Refresh")
                clear_history = gr.Button("🧹 Clear History")

    with gr.Tab("📁 File Translator"):
        file_input = gr.File(label="📂 Upload .txt File", file_types=[".txt"])
        file_src = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="📌 From", value="English")
        file_tgt = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="📌 To", value="Korean")
        file_translate = gr.Button("📄 Translate File", variant="primary")
        file_result = gr.Textbox(label="📑 File Output", lines=10, interactive=False)
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            f_max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
            f_temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")

    with gr.Tab("📝 Text Summarizer"):
        summary_input = gr.Textbox(lines=5, label="📚 Enter text to summarize")
        summary_length = gr.Slider(32, 512, value=128, step=8, label="📏 Max Length")
        summary_output = gr.Textbox(label="🧾 Summary", lines=5, interactive=False)
        summary_btn = gr.Button("🧠 Summarize")

    # Button events
    translate.click(cached_translate, [input_text, src_lang, tgt_lang, max_length, temperature], output_text)
    clear.click(lambda: ("", ""), None, [input_text, output_text])
    swap.click(lambda s, t: (t, s), [src_lang, tgt_lang], [src_lang, tgt_lang])
    refresh.click(lambda: history.get(), None, history_json)
    clear_history.click(lambda: history.clear() or [], None, history_json)
    # BUG FIX: the original handler called file.read() on the gr.File value,
    # which fails when Gradio supplies a filepath string (no .read) and
    # crashed outright when no file was uploaded.
    file_translate.click(_translate_upload,
                         [file_input, file_src, file_tgt, f_max_length, f_temp], file_result)
    summary_btn.click(summarize_text, [summary_input, summary_length], summary_output)

    gr.Markdown(f"""
### 🔍 Info
- Translator Model: `{model_name}` on `{device}`
- Summarizer Model: `facebook/bart-large-cnn`
- HF API Key: {'Loaded ✅' if HF_API_KEY else 'Missing ❌'}
""")
145
 
146
# Script entry point: launch the Gradio server. share=True requests a public
# gradio.live tunnel when run locally (ignored when hosted on HF Spaces).
if __name__ == "__main__":
    demo.launch(share=True)