fereen5 committed on
Commit
f20a943
·
verified ·
1 Parent(s): e3d1d34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -25
app.py CHANGED
@@ -6,14 +6,14 @@ from functools import lru_cache
6
  import torch
7
  import requests
8
 
9
- # Language Codes
10
  LANGUAGE_CODES = {
11
  "English": "eng_Latn", "Korean": "kor_Hang", "Japanese": "jpn_Jpan", "Chinese": "zho_Hans",
12
  "Spanish": "spa_Latn", "French": "fra_Latn", "German": "deu_Latn", "Russian": "rus_Cyrl",
13
  "Portuguese": "por_Latn", "Italian": "ita_Latn", "Burmese": "mya_Mymr", "Thai": "tha_Thai"
14
  }
15
 
16
- # Translation History
17
  class TranslationHistory:
18
  def __init__(self):
19
  self.history = []
@@ -25,21 +25,19 @@ class TranslationHistory:
25
  })
26
  if len(self.history) > 100:
27
  self.history.pop()
28
- def get(self):
29
- return self.history
30
- def clear(self):
31
- self.history = []
32
 
33
  history = TranslationHistory()
34
 
35
- # Load Translation Model
36
  model_name = "facebook/nllb-200-distilled-600M"
37
  tokenizer = AutoTokenizer.from_pretrained(model_name)
38
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
  model.to(device)
41
 
42
- # Translation Function
43
  @lru_cache(maxsize=512)
44
  def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
45
  if not text.strip(): return ""
@@ -48,7 +46,8 @@ def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
48
  input_tokens = tokenizer(text, return_tensors="pt", padding=True)
49
  input_tokens = {k: v.to(device) for k, v in input_tokens.items()}
50
  forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
51
- output = model.generate(**input_tokens,
 
52
  forced_bos_token_id=forced_bos_token_id,
53
  max_length=max_length, temperature=temperature,
54
  num_beams=5, early_stopping=True
@@ -65,9 +64,9 @@ def translate_file(file, src_lang, tgt_lang, max_length, temperature):
65
  except Exception as e:
66
  return f"File translation error: {e}"
67
 
68
- # Summarization
69
  API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
70
- HF_API_KEY = os.environ.get("HF_API_KEY", "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")
71
  headers = {"Authorization": f"Bearer {HF_API_KEY}"}
72
 
73
  def summarize_text(text, max_length):
@@ -80,13 +79,14 @@ def summarize_text(text, max_length):
80
  result = response.json()
81
  return result[0]["summary_text"] if isinstance(result, list) else "Error: " + str(result)
82
 
83
- # UI Styling
84
  gradio_style = """
85
  .gr-button { border-radius: 12px !important; padding: 10px 20px !important; font-weight: bold; }
86
  textarea, input[type=text] { border: 2px solid #00ADB5 !important; border-radius: 10px; transition: 0.2s; }
87
  textarea:focus, input[type=text]:focus { border-color: #FF5722 !important; box-shadow: 0 0 8px #FF5722 !important; }
88
  """
89
 
 
90
  with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
91
  gr.Markdown("## 🤖 AI Toolbox: Translate & Summarize")
92
 
@@ -95,30 +95,35 @@ with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
95
  src_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="🌐 From", value="English")
96
  swap = gr.Button("⇄")
97
  tgt_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="🎯 To", value="Korean")
 
98
  with gr.Row():
99
  input_text = gr.Textbox(lines=3, label="✍️ Input Text")
100
  output_text = gr.Textbox(lines=3, label="📤 Translated Output", interactive=False)
 
101
  with gr.Row():
102
  translate = gr.Button("🚀 Translate", variant="primary")
103
  clear = gr.Button("🧽 Clear")
 
104
  with gr.Accordion("⚙️ Advanced Settings", open=False):
105
  max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
106
  temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
 
107
  with gr.Accordion("📜 Translation History", open=False):
108
  history_json = gr.JSON(label="Recent Translations")
109
  with gr.Row():
110
  refresh = gr.Button("🔄 Refresh")
111
  clear_history = gr.Button("🧹 Clear History")
112
 
113
- with gr.Tab("📁 File Translator"):
114
- file_input = gr.File(label="📂 Upload .txt File", file_types=[".txt"])
115
- file_src = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="📌 From", value="English")
116
- file_tgt = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="📌 To", value="Korean")
117
- file_translate = gr.Button("📄 Translate File", variant="primary")
118
- file_result = gr.Textbox(label="📑 File Output", lines=10, interactive=False)
119
- with gr.Accordion("⚙️ Advanced Settings", open=False):
120
- f_max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
121
- f_temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
 
122
 
123
  with gr.Tab("📝 Text Summarizer"):
124
  summary_input = gr.Textbox(lines=5, label="📚 Enter text to summarize")
@@ -126,7 +131,7 @@ with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
126
  summary_output = gr.Textbox(label="🧾 Summary", lines=5, interactive=False)
127
  summary_btn = gr.Button("🧠 Summarize")
128
 
129
- # Button events
130
  translate.click(cached_translate, [input_text, src_lang, tgt_lang, max_length, temperature], output_text)
131
  clear.click(lambda: ("", ""), None, [input_text, output_text])
132
  swap.click(lambda s, t: (t, s), [src_lang, tgt_lang], [src_lang, tgt_lang])
@@ -138,9 +143,9 @@ with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
138
 
139
  gr.Markdown(f"""
140
  ### 🔍 Info
141
- - Translator Model: `{model_name}` on `{device}`
142
- - Summarizer Model: `facebook/bart-large-cnn`
143
- - HF API Key: {'Loaded ✅' if HF_API_KEY else 'Missing ❌'}
144
  """)
145
 
146
  if __name__ == "__main__":
 
6
  import torch
7
  import requests
8
 
9
+ # Language codes
10
  LANGUAGE_CODES = {
11
  "English": "eng_Latn", "Korean": "kor_Hang", "Japanese": "jpn_Jpan", "Chinese": "zho_Hans",
12
  "Spanish": "spa_Latn", "French": "fra_Latn", "German": "deu_Latn", "Russian": "rus_Cyrl",
13
  "Portuguese": "por_Latn", "Italian": "ita_Latn", "Burmese": "mya_Mymr", "Thai": "tha_Thai"
14
  }
15
 
16
+ # Translation History class
17
  class TranslationHistory:
18
  def __init__(self):
19
  self.history = []
 
25
  })
26
  if len(self.history) > 100:
27
  self.history.pop()
28
+ def get(self): return self.history
29
+ def clear(self): self.history = []
 
 
30
 
31
  history = TranslationHistory()
32
 
33
+ # Load translation model
34
  model_name = "facebook/nllb-200-distilled-600M"
35
  tokenizer = AutoTokenizer.from_pretrained(model_name)
36
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
37
  device = "cuda" if torch.cuda.is_available() else "cpu"
38
  model.to(device)
39
 
40
+ # Translate function
41
  @lru_cache(maxsize=512)
42
  def cached_translate(text, src_lang, tgt_lang, max_length=128, temperature=0.7):
43
  if not text.strip(): return ""
 
46
  input_tokens = tokenizer(text, return_tensors="pt", padding=True)
47
  input_tokens = {k: v.to(device) for k, v in input_tokens.items()}
48
  forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
49
+ output = model.generate(
50
+ **input_tokens,
51
  forced_bos_token_id=forced_bos_token_id,
52
  max_length=max_length, temperature=temperature,
53
  num_beams=5, early_stopping=True
 
64
  except Exception as e:
65
  return f"File translation error: {e}"
66
 
67
+ # Hugging Face Summarization API
68
  API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
69
+ HF_API_KEY = os.environ.get("HF_API_KEY")
70
  headers = {"Authorization": f"Bearer {HF_API_KEY}"}
71
 
72
  def summarize_text(text, max_length):
 
79
  result = response.json()
80
  return result[0]["summary_text"] if isinstance(result, list) else "Error: " + str(result)
81
 
82
+ # UI Style
83
  gradio_style = """
84
  .gr-button { border-radius: 12px !important; padding: 10px 20px !important; font-weight: bold; }
85
  textarea, input[type=text] { border: 2px solid #00ADB5 !important; border-radius: 10px; transition: 0.2s; }
86
  textarea:focus, input[type=text]:focus { border-color: #FF5722 !important; box-shadow: 0 0 8px #FF5722 !important; }
87
  """
88
 
89
+ # Gradio App
90
  with gr.Blocks(css=gradio_style, theme=gr.themes.Soft()) as demo:
91
  gr.Markdown("## 🤖 AI Toolbox: Translate & Summarize")
92
 
 
95
  src_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="🌐 From", value="English")
96
  swap = gr.Button("⇄")
97
  tgt_lang = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="🎯 To", value="Korean")
98
+
99
  with gr.Row():
100
  input_text = gr.Textbox(lines=3, label="✍️ Input Text")
101
  output_text = gr.Textbox(lines=3, label="📤 Translated Output", interactive=False)
102
+
103
  with gr.Row():
104
  translate = gr.Button("🚀 Translate", variant="primary")
105
  clear = gr.Button("🧽 Clear")
106
+
107
  with gr.Accordion("⚙️ Advanced Settings", open=False):
108
  max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
109
  temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
110
+
111
  with gr.Accordion("📜 Translation History", open=False):
112
  history_json = gr.JSON(label="Recent Translations")
113
  with gr.Row():
114
  refresh = gr.Button("🔄 Refresh")
115
  clear_history = gr.Button("🧹 Clear History")
116
 
117
+ with gr.Tab("📁 File Translator"):
118
+ file_input = gr.File(label="📂 Upload .txt File", file_types=[".txt"])
119
+ file_src = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="📌 From", value="English")
120
+ file_tgt = gr.Dropdown(list(LANGUAGE_CODES.keys()), label="📌 To", value="Korean")
121
+ file_translate = gr.Button("📄 Translate File", variant="primary")
122
+ file_result = gr.Textbox(label="📑 File Output", lines=10, interactive=False)
123
+
124
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
125
+ f_max_length = gr.Slider(10, 512, value=128, step=1, label="Max Length")
126
+ f_temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
127
 
128
  with gr.Tab("📝 Text Summarizer"):
129
  summary_input = gr.Textbox(lines=5, label="📚 Enter text to summarize")
 
131
  summary_output = gr.Textbox(label="🧾 Summary", lines=5, interactive=False)
132
  summary_btn = gr.Button("🧠 Summarize")
133
 
134
+ # Event Hooks
135
  translate.click(cached_translate, [input_text, src_lang, tgt_lang, max_length, temperature], output_text)
136
  clear.click(lambda: ("", ""), None, [input_text, output_text])
137
  swap.click(lambda s, t: (t, s), [src_lang, tgt_lang], [src_lang, tgt_lang])
 
143
 
144
  gr.Markdown(f"""
145
  ### 🔍 Info
146
+ - Translator: `{model_name}` on `{device}`
147
+ - Summarizer: `facebook/bart-large-cnn`
148
+ - API Token Status: {'✅ Loaded' if HF_API_KEY else '❌ Not Found'}
149
  """)
150
 
151
  if __name__ == "__main__":