Maria-tamu committed on
Commit
fcc1fd6
·
verified ·
1 Parent(s): 4000c0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -134
app.py CHANGED
@@ -1,143 +1,91 @@
1
- import tkinter as tk
2
- from tkinter import ttk
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
 
 
 
5
  class TranslationPipeline:
6
- def __init__(self, model_name):
7
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
 
9
 
10
- def __call__(self, text):
11
- inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
12
- outputs = self.model.generate(**inputs)
 
 
 
 
13
  translated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
14
- return [{'translation_text': translated_text}]
15
-
16
- class TranslateApp:
17
- def __init__(self, root):
18
- self.root = root
19
- self.root.title("Short Translation")
20
- self.root.geometry("700x300")
21
- self.root.configure(bg="#f0f0f0")
22
-
23
- # Define Colors
24
- self.maroon = "#800000"
25
- self.white = "#ffffff"
26
-
27
- # Initialize Translation Models (pipeline parameters)
28
- self.models = {
29
- "Spanish": ["translation_en_to_es",
30
- "Helsinki-NLP/opus-mt-en-es"],
31
- "German": ["translation_en_to_de",
32
- "Helsinki-NLP/opus-mt-en-de"],
33
- "Japanese": ["translation_en_to_ja",
34
- "staka/fugumt-en-ja"],
35
- "Ukrainian": ["translation_en_to_uk",
36
- "Helsinki-NLP/opus-mt-en-uk"],
37
- "Russian": ["translation_en_to_ru",
38
- "Helsinki-NLP/opus-mt-en-ru"],
39
- }
40
-
41
- # Cache for loaded pipeline objects
42
- # This is done to speed up translations after the first.
43
- self.cached_pipelines = {}
44
-
45
- # --- Top Section ---
46
- top_frame = tk.Frame(self.root, bg="#f0f0f0")
47
- top_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=10)
48
-
49
- # Left Column: Input and Buttons
50
- left_column = tk.Frame(top_frame, bg="#f0f0f0", borderwidth=2, relief="groove")
51
- left_column.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=(0, 5))
52
-
53
- tk.Label(left_column, text="English Sentence", fg=self.maroon,
54
- bg="#f0f0f0", font=("Arial", 12, "bold")).pack(anchor="w", padx=5)
55
-
56
- self.input_entry = tk.Entry(left_column, font=("Arial", 12))
57
- self.input_entry.pack(fill=tk.X, padx=10, pady=5)
58
-
59
- button_frame = tk.Frame(left_column, bg="#f0f0f0")
60
- button_frame.pack(pady=5)
61
-
62
- self.translate_btn = tk.Button(button_frame, text="Translate",
63
- fg=self.maroon, bg=self.white,
64
- command=self.translate_text, width=15)
65
- self.translate_btn.pack(side=tk.LEFT, padx=5)
66
-
67
- self.clear_btn = tk.Button(button_frame, text="Clear", fg=self.maroon,
68
- bg=self.white,
69
- command=self.clear_fields, width=15)
70
- self.clear_btn.pack(side=tk.LEFT, padx=5)
71
-
72
- # Right Column: Language Selection
73
- right_column = tk.Frame(top_frame, bg="#f0f0f0", borderwidth=2,
74
- relief="groove")
75
- right_column.pack(side=tk.RIGHT, fill=tk.Y, padx=(5, 0))
76
-
77
- tk.Label(right_column, text="Translation Language", fg=self.maroon,
78
- bg="#f0f0f0", font=("Arial", 12, "bold")).pack(anchor="w", padx=5)
79
-
80
- self.lang_var = tk.StringVar(value="Spanish")
81
- languages = [("Spanish", "es"), ("German", "de"), ("Japanese", "ja"),
82
- ("Ukrainian", "uk"), ("Russian", "ru")]
83
-
84
- for lang_text, lang_code in languages:
85
- tk.Radiobutton(right_column, text=lang_text,
86
- font=("Arial", 12, "bold"),
87
- variable=self.lang_var, value=lang_text,
88
- fg=self.maroon, bg="#f0f0f0", highlightthickness=0,
89
- activeforeground=self.maroon).pack(anchor="w", padx=30)
90
-
91
- # --- Bottom Section: Output Area ---
92
- self.output_frame = tk.Frame(self.root, bg=self.maroon, height=120)
93
- self.output_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=(0, 10))
94
- self.output_frame.pack_propagate(False) # Maintain fixed height
95
-
96
- self.output_label = tk.Label(self.output_frame, text="", fg=self.white,
97
- bg=self.maroon,
98
- font=("Times New Roman", 16, "bold"),
99
- wraplength=650, justify="center")
100
- self.output_label.pack(expand=True, fill=tk.BOTH)
101
-
102
- def translate_text(self):
103
- input_text = self.input_entry.get().strip()
104
- target_lang = self.lang_var.get()
105
-
106
- if not input_text:
107
- self.output_label.config(text="Please enter text to translate.")
108
- return
109
-
110
- # Check if we need to load the model
111
- if target_lang not in self.cached_pipelines:
112
- self.output_label.config(text=f"Loading model for {target_lang}...")
113
- self.root.update_idletasks()
114
-
115
- try:
116
- # We ignore the task name (index 0) since we are using the custom pipeline
117
- _, model_name = self.models[target_lang]
118
- self.cached_pipelines[target_lang] = TranslationPipeline(model_name)
119
- except Exception as e:
120
- self.output_label.config(text=f"Error loading model: {str(e)}")
121
- return
122
-
123
- self.output_label.config(text=f"Translating to {target_lang}...")
124
- self.root.update_idletasks()
125
-
126
- try:
127
- translator = self.cached_pipelines[target_lang]
128
- result = translator(input_text)
129
- translated_text = result[0]['translation_text']
130
-
131
- print(f"DEBUG: {target_lang} output -> {translated_text}")
132
- self.output_label.config(text=translated_text)
133
- except Exception as e:
134
- self.output_label.config(text=f"Error: {str(e)}")
135
-
136
- def clear_fields(self):
137
- self.input_entry.delete(0, tk.END)
138
- self.output_label.config(text="")
139
 
140
  if __name__ == "__main__":
141
- root = tk.Tk()
142
- app = TranslateApp(root)
143
- root.mainloop()
 
1
+ import gradio as gr
2
+ import torch
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
# -----------------------
# Translation core
# -----------------------
class TranslationPipeline:
    """Thin wrapper that pairs a seq2seq checkpoint with its tokenizer.

    The tokenizer and model are both loaded with ``from_pretrained`` from
    ``model_name`` and the model is moved to ``device`` once, at
    construction time.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        device: Device string the model and the tokenized inputs are
            moved to (for example ``"cpu"`` or ``"cuda"``).
        max_new_tokens: Upper bound on generated tokens forwarded to
            ``model.generate``. Defaults to 256, the value that was
            previously hard-coded in ``__call__``.
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cpu",
        max_new_tokens: int = 256,
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.device = device
        self.max_new_tokens = max_new_tokens
        self.model.to(self.device)

    @torch.inference_mode()
    def __call__(self, text: str) -> str:
        """Translate ``text`` and return the decoded string.

        The input is tokenized with padding and truncation enabled and
        moved to ``self.device`` before generation. Only the first
        generated sequence (``outputs[0]``) is decoded, so this method
        returns a single translation per call.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)

        outputs = self.model.generate(
            **inputs, max_new_tokens=self.max_new_tokens
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
24
+
25
+
26
# -----------------------
# Models + cache
# -----------------------
# Language name shown in the UI -> checkpoint used to translate into it.
MODELS = {
    "Spanish": "Helsinki-NLP/opus-mt-en-es",
    "German": "Helsinki-NLP/opus-mt-en-de",
    "Japanese": "staka/fugumt-en-ja",
    "Ukrainian": "Helsinki-NLP/opus-mt-en-uk",
    "Russian": "Helsinki-NLP/opus-mt-en-ru",
}

# Cache loaded pipelines so we don't re-download/reload every time
PIPELINE_CACHE = {}

# Use GPU if available (some Spaces have it; many are CPU)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def translate(text: str, target_lang: str) -> str:
    """Translate ``text`` into ``target_lang`` and return a display string.

    The function never raises for the cases below; every outcome is
    returned as text so the UI can show it directly.

    Args:
        text: English input. ``None``, empty, or whitespace-only input
            yields a prompt asking for text.
        target_lang: A key of ``MODELS``. Any other value yields an
            "unsupported" message.

    Returns:
        The translation, or one of: "Please enter text to translate.",
        "Unsupported language selection.", "Error loading model: ...",
        "Error: ...".
    """
    # Function-scope import keeps this block self-contained without
    # touching the file's import section.
    import logging

    log = logging.getLogger(__name__)

    text = (text or "").strip()
    if not text:
        return "Please enter text to translate."

    if target_lang not in MODELS:
        return "Unsupported language selection."

    translator = PIPELINE_CACHE.get(target_lang)
    if translator is None:
        # Loading can take time on first request, and it can fail
        # (network, disk, memory). Report the failure instead of letting
        # it escape the UI handler, and only cache a pipeline that loaded.
        try:
            translator = TranslationPipeline(MODELS[target_lang], device=DEVICE)
        except Exception as exc:  # UI boundary: surface, don't crash
            log.exception("Failed to load model for %s", target_lang)
            return f"Error loading model: {exc}"
        PIPELINE_CACHE[target_lang] = translator

    try:
        return translator(text)
    except Exception as exc:  # UI boundary: surface, don't crash
        log.exception("Translation to %s failed", target_lang)
        return f"Error: {exc}"
59
+
60
+
61
# -----------------------
# Gradio UI
# -----------------------
with gr.Blocks(title="Short Translation") as demo:
    # Page heading and one-line usage hint.
    gr.Markdown(
        "## Short Translation\n"
        "Enter an English sentence and choose a target language."
    )

    with gr.Row():
        # Wider column: the sentence to translate plus the two action buttons.
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="English Sentence",
                lines=3,
                placeholder="Type here...",
            )
            translate_btn = gr.Button("Translate")
            clear_btn = gr.Button("Clear")

        # Narrower column: one radio option per entry in MODELS.
        with gr.Column(scale=1):
            target_lang = gr.Radio(
                choices=list(MODELS),
                value="Spanish",
                label="Translation Language",
            )

    output_text = gr.Textbox(label="Translation", lines=4)

    # "Translate" sends the sentence and chosen language to translate()
    # and shows the returned string in the output box.
    translate_btn.click(
        fn=translate,
        inputs=[input_text, target_lang],
        outputs=output_text,
    )

    def clear():
        """Return empty strings for the input box and the output box."""
        blank = ""
        return blank, blank

    # "Clear" takes no inputs and resets both text boxes.
    clear_btn.click(
        fn=clear,
        inputs=None,
        outputs=[input_text, output_text],
    )


if __name__ == "__main__":
    demo.launch()
91
+