entropy25 committed on
Commit
72aed53
·
verified ·
1 Parent(s): 40e753d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -37
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
  from peft import PeftModel
 
 
5
 
6
  base_model_name = "facebook/nllb-200-distilled-600M"
7
  adapter_en_to_no = "entropy25/mt_en_no_oil"
@@ -9,58 +11,122 @@ adapter_no_to_en = "entropy25/mt_no_en_oil"
9
 
10
  tokenizer = AutoTokenizer.from_pretrained(base_model_name)
11
 
12
- print("Loading shared base model...")
 
 
13
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
14
  base_model_name,
15
- torch_dtype=torch.float16,
16
- low_cpu_mem_usage=True,
17
- device_map="auto"
18
  )
19
 
20
  print("Loading adapters...")
21
  model = PeftModel.from_pretrained(base_model, adapter_en_to_no, adapter_name="en_to_no")
22
  model.load_adapter(adapter_no_to_en, adapter_name="no_to_en")
 
 
 
 
 
 
 
23
 
24
- def translate(text, source_lang, target_lang):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  if not text.strip() or source_lang == target_lang:
26
  return text
27
 
28
  if source_lang == "English" and target_lang == "Norwegian":
29
  model.set_adapter("en_to_no")
30
- src_code, tgt_code = "eng_Latn", "nob_Latn"
31
  elif source_lang == "Norwegian" and target_lang == "English":
32
  model.set_adapter("no_to_en")
33
- src_code, tgt_code = "nob_Latn", "eng_Latn"
34
  else:
35
  return "Unsupported language pair"
36
 
 
 
37
  lines = text.split('\n')
38
  non_empty_lines = [line for line in lines if line.strip()]
39
 
40
  if not non_empty_lines:
41
  return text
42
 
43
- inputs = tokenizer(
44
- non_empty_lines,
45
- return_tensors="pt",
46
- padding=True,
47
- truncation=True,
48
- max_length=512
49
- )
50
-
51
- if hasattr(model, 'device'):
52
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
53
-
54
- outputs = model.generate(
55
- **inputs,
56
- forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
57
- max_length=512,
58
- num_beams=3
59
- )
60
 
61
- results = tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- result_iter = iter(results)
64
  final_lines = []
65
  for line in lines:
66
  if line.strip():
@@ -70,19 +136,84 @@ def translate(text, source_lang, target_lang):
70
 
71
  return '\n'.join(final_lines)
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def swap_languages(src, tgt, input_txt, output_txt):
74
  return tgt, src, output_txt, input_txt
75
 
76
  def load_file(file):
77
  if file is None:
78
  return ""
 
79
  try:
 
 
 
80
  with open(file.name, 'r', encoding='utf-8') as f:
81
- return f.read()
 
 
 
82
  except:
83
  try:
84
  with open(file.name, 'r', encoding='latin-1') as f:
85
- return f.read()
 
 
 
86
  except Exception as e:
87
  return f"Error reading file: {str(e)}"
88
 
@@ -181,17 +312,44 @@ custom_css = """
181
  background: #f8f9fa !important;
182
  border-color: #0f6fff !important;
183
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  .footer-info {
185
  text-align: center !important;
186
  color: #999 !important;
187
  font-size: 13px !important;
188
  padding: 20px !important;
189
  }
 
 
 
 
 
190
  """
191
 
192
  with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
193
  gr.HTML("<div style='height: 20px'></div>")
194
 
 
 
 
 
 
 
 
 
 
195
  with gr.Row():
196
  with gr.Column(scale=1):
197
  with gr.Group(elem_classes="translate-box"):
@@ -238,7 +396,10 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
238
  interactive=False
239
  )
240
 
241
- gr.HTML("<div class='footer-info'>Oil & Gas Translation • English ↔ Norwegian • Bidirectional Model</div>")
 
 
 
242
 
243
  with gr.Accordion("Example Sentences", open=True):
244
  with gr.Row():
@@ -249,7 +410,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
249
  max_lines=5,
250
  show_copy_button=True
251
  )
252
- use_example_btn = gr.Button("Use This Example", variant="primary", size="sm")
253
 
254
  with gr.Row():
255
  btn1 = gr.Button("Drilling (Short)", size="sm")
@@ -280,16 +441,35 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
280
 
281
  with gr.Accordion("Upload Text File", open=False):
282
  file_input = gr.File(
283
- label="Upload a .txt file to translate",
284
  file_types=[".txt"],
285
  type="filepath"
286
  )
287
 
 
 
 
 
 
 
 
 
 
 
288
  source_lang.change(fn=update_example_buttons, inputs=[source_lang], outputs=[example_text])
289
- input_text.change(fn=translate, inputs=[input_text, source_lang, target_lang], outputs=output_text)
290
- source_lang.change(fn=translate, inputs=[input_text, source_lang, target_lang], outputs=output_text)
291
- target_lang.change(fn=translate, inputs=[input_text, source_lang, target_lang], outputs=output_text)
292
- swap_btn.click(fn=swap_languages, inputs=[source_lang, target_lang, input_text, output_text], outputs=[source_lang, target_lang, input_text, output_text])
 
 
 
 
 
 
 
 
 
293
  file_input.change(fn=load_file, inputs=file_input, outputs=input_text)
294
 
295
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
4
  from peft import PeftModel
5
+ from functools import lru_cache
6
+ import os
7
 
8
  base_model_name = "facebook/nllb-200-distilled-600M"
9
  adapter_en_to_no = "entropy25/mt_en_no_oil"
 
11
 
12
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

print("Loading shared base model with 8-bit quantization...")
# 8-bit weights roughly halve memory vs fp16 — presumably sized for a free
# HF Space CPU/GPU tier; TODO confirm bitsandbytes is available at runtime.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

base_model = AutoModelForSeq2SeqLM.from_pretrained(
    base_model_name,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True
)

print("Loading adapters...")
# One shared base model with two LoRA adapters, one per translation
# direction; translate_core switches between them via set_adapter().
model = PeftModel.from_pretrained(base_model, adapter_en_to_no, adapter_name="en_to_no")
model.load_adapter(adapter_no_to_en, adapter_name="no_to_en")
model.eval()  # inference only — disables dropout etc.
28
+
29
# Generation settings selectable from the UI. Fewer beams and a shorter
# max_length trade quality for speed; batch_size bounds per-step memory.
QUALITY_PRESETS = {
    "Professional (Best Quality)": {
        "num_beams": 3,
        "max_length": 256,
        "batch_size": 4,
    },
    "Balanced (Faster)": {
        "num_beams": 2,
        "max_length": 256,
        "batch_size": 5,
    },
    "Draft (Fastest)": {
        "num_beams": 2,
        "max_length": 128,
        "batch_size": 5,
    },
}

# Regression fixtures for run_quality_tests(): each case pairs an input with
# a reference translation and a list of keywords that must appear
# (case-insensitively) in the model output for the case to count as PASS.
QUALITY_TEST_CASES = {
    "en_to_no": [
        {
            "input": "Mud weight adjusted to 1.82 specific gravity at 3,247 meters depth.",
            "expected": "Slamvekt justert til 1,82 spesifikk tyngde ved 3 247 meters dybde.",
            "check": ["slamvekt", "1,82", "3 247"],
        },
        {
            "input": "Christmas tree rated for 10,000 psi working pressure.",
            "expected": "Juletre dimensjonert for 10 000 psi arbeidstrykk.",
            "check": ["juletre", "10 000", "psi"],
        },
        {
            "input": "H2S training required before site access.",
            "expected": "H2S-opplæring påkrevd før tilgang til området.",
            "check": ["H2S", "opplæring", "påkrevd"],
        },
        {
            "input": "Permeability is 250 millidarcy with 22 percent porosity.",
            "expected": "Permeabilitet er 250 millidarcy med 22 prosent porøsitet.",
            "check": ["permeabilitet", "250", "22"],
        },
    ],
    "no_to_en": [
        {
            "input": "Permeabilitet er 250 millidarcy med 22 prosent porøsitet.",
            "expected": "Permeability is 250 millidarcy with 22 percent porosity.",
            "check": ["permeability", "250", "22"],
        },
        {
            "input": "Subsea produksjonssystemet består av et vertikalt juletre.",
            "expected": "The subsea production system consists of a vertical Christmas tree.",
            "check": ["subsea", "Christmas tree", "vertical"],
        },
        {
            "input": "Slamvekt justert til 1,82 spesifikk tyngde ved 3 247 meters dybde.",
            "expected": "Mud weight adjusted to 1.82 specific gravity at 3,247 meters depth.",
            "check": ["mud weight", "1.82", "3,247"],
        },
    ],
}

# Hard cap for uploaded files, enforced in load_file().
MAX_FILE_SIZE = 1024 * 1024  # 1 MiB
78
+
79
def translate_core(text, source_lang, target_lang, quality_preset):
    """Translate *text* line by line using the shared NLLB model + adapters.

    Args:
        text: Input text; blank lines are preserved in the output.
        source_lang / target_lang: "English" or "Norwegian".
        quality_preset: A key of QUALITY_PRESETS selecting beams/length/batch.

    Returns:
        The translated text, the input unchanged when there is nothing to do,
        or "Unsupported language pair" for any other language combination.
    """
    if not text.strip() or source_lang == target_lang:
        return text

    if source_lang == "English" and target_lang == "Norwegian":
        model.set_adapter("en_to_no")
        src_code, tgt_code = "eng_Latn", "nob_Latn"
    elif source_lang == "Norwegian" and target_lang == "English":
        model.set_adapter("no_to_en")
        src_code, tgt_code = "nob_Latn", "eng_Latn"
    else:
        return "Unsupported language pair"

    # Fix: NLLB tokenizers prepend a source-language token taken from
    # tokenizer.src_lang (default "eng_Latn"). Without setting it, Norwegian
    # input was tokenized with an English source tag.
    tokenizer.src_lang = src_code

    preset = QUALITY_PRESETS[quality_preset]

    lines = text.split('\n')
    non_empty_lines = [line for line in lines if line.strip()]

    if not non_empty_lines:
        return text

    batch_size = preset["batch_size"]
    all_results = []

    # Translate in small batches to bound memory on the quantized model.
    for i in range(0, len(non_empty_lines), batch_size):
        batch = non_empty_lines[i:i + batch_size]

        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=preset["max_length"]
        )

        if hasattr(model, 'device'):
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                # Force decoding to start with the target-language token.
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
                max_length=preset["max_length"],
                num_beams=preset["num_beams"],
                early_stopping=True
            )

        all_results.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))

    # Reassemble, keeping blank lines in their original positions.
    # NOTE(review): this tail was elided in the diff context — reconstructed
    # from the surrounding iterator/loop scaffolding; confirm against the repo.
    result_iter = iter(all_results)
    final_lines = []
    for line in lines:
        if line.strip():
            final_lines.append(next(result_iter))
        else:
            final_lines.append(line)

    return '\n'.join(final_lines)
138
 
139
@lru_cache(maxsize=512)
def translate_cached(text, source_lang, target_lang, quality_preset):
    """Memoized wrapper: repeat requests with identical args skip the model."""
    return translate_core(text, source_lang, target_lang, quality_preset)


def translate(text, source_lang, target_lang, quality_preset):
    """UI entry point: enforce the input-length cap, then serve from cache."""
    if len(text) > 10000:
        return "Error: Text too long (max 10,000 characters)"
    return translate_cached(text, source_lang, target_lang, quality_preset)
147
+
148
def run_quality_tests():
    """Run the keyword regression suite in both directions and return a report.

    A case counts as PASS when every keyword in its "check" list appears
    (case-insensitively) in the model translation; otherwise it is flagged
    CHECK for human review.
    """
    results = []
    results.append("=== QUALITY REGRESSION TEST ===\n")

    # Fix: count outcomes directly instead of substring-scanning the report
    # lines afterwards — a translation containing "PASS"/"CHECK" (or the
    # "=== TEST COMPLETE ===" footer matching) would have skewed the score.
    pass_count = 0
    check_count = 0
    total = 0

    for direction, test_cases in QUALITY_TEST_CASES.items():
        if direction == "en_to_no":
            src_lang, tgt_lang = "English", "Norwegian"
        else:
            src_lang, tgt_lang = "Norwegian", "English"

        results.append(f"\n{src_lang} to {tgt_lang}\n")

        for i, case in enumerate(test_cases, 1):
            total += 1
            translation = translate_core(case["input"], src_lang, tgt_lang, "Professional (Best Quality)")

            lowered = translation.lower()
            passed_checks = [kw for kw in case["check"] if kw.lower() in lowered]
            failed_checks = [kw for kw in case["check"] if kw.lower() not in lowered]

            if failed_checks:
                status = "CHECK"
                check_count += 1
            else:
                status = "PASS"
                pass_count += 1

            results.append(f"\nTest {i}: {status}")
            results.append(f"Input: {case['input']}")
            results.append(f"Expected: {case['expected']}")
            results.append(f"Got: {translation}")

            if passed_checks:
                results.append(f"Found: {', '.join(passed_checks)}")
            if failed_checks:
                results.append(f"Missing: {', '.join(failed_checks)}")

    results.append("\n=== TEST COMPLETE ===")

    results.insert(1, f"\nScore: {pass_count}/{total} passed, {check_count}/{total} need review\n")

    return '\n'.join(results)
193
+
194
def swap_languages(src, tgt, input_txt, output_txt):
    """Flip translation direction: exchange the languages and the two texts."""
    swapped = (tgt, src, output_txt, input_txt)
    return swapped
196
 
197
def load_file(file):
    """Read an uploaded text file, enforcing size and length limits.

    Tries UTF-8 first, then Latin-1. Returns the file content on success,
    "" when no file was given, or an "Error: ..." string on any failure.
    """
    if file is None:
        return ""

    # gr.File(type="filepath") passes a plain path string in recent Gradio;
    # older versions passed a tempfile wrapper with a .name attribute.
    # NOTE(review): confirm which Gradio version this Space pins.
    path = file if isinstance(file, str) else file.name

    try:
        if os.path.getsize(path) > MAX_FILE_SIZE:
            return "Error: File too large (max 1MB)"
    except OSError as e:
        return f"Error reading file: {str(e)}"

    # Fix: the original used a bare `except:` (swallowing even
    # KeyboardInterrupt) and duplicated the read/length-check logic.
    for encoding in ('utf-8', 'latin-1'):
        try:
            with open(path, 'r', encoding=encoding) as f:
                content = f.read()
        except UnicodeDecodeError:
            continue  # try the next encoding
        except Exception as e:
            return f"Error reading file: {str(e)}"

        if len(content) > 10000:
            return "Error: File content too long (max 10,000 characters)"
        return content

    # Unreachable in practice (latin-1 decodes any byte sequence), kept as a
    # defensive fallback.
    return "Error reading file: unsupported encoding"
219
 
 
312
  background: #f8f9fa !important;
313
  border-color: #0f6fff !important;
314
  }
315
+ .translate-btn {
316
+ background: #0f6fff !important;
317
+ color: white !important;
318
+ border: none !important;
319
+ padding: 12px 24px !important;
320
+ font-size: 15px !important;
321
+ font-weight: 500 !important;
322
+ border-radius: 4px !important;
323
+ cursor: pointer !important;
324
+ }
325
+ .translate-btn:hover {
326
+ background: #0d5dd9 !important;
327
+ }
328
  .footer-info {
329
  text-align: center !important;
330
  color: #999 !important;
331
  font-size: 13px !important;
332
  padding: 20px !important;
333
  }
334
+ .quality-selector {
335
+ background: #f0f7ff !important;
336
+ border: 1px solid #0f6fff !important;
337
+ border-radius: 4px !important;
338
+ }
339
  """
340
 
341
  with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
342
  gr.HTML("<div style='height: 20px'></div>")
343
 
344
+ with gr.Row():
345
+ quality_preset = gr.Radio(
346
+ choices=list(QUALITY_PRESETS.keys()),
347
+ value="Professional (Best Quality)",
348
+ label="Translation Quality",
349
+ info="Professional: beam=3, max=256 | Balanced: beam=2, max=256 | Draft: beam=2, max=128",
350
+ elem_classes="quality-selector"
351
+ )
352
+
353
  with gr.Row():
354
  with gr.Column(scale=1):
355
  with gr.Group(elem_classes="translate-box"):
 
396
  interactive=False
397
  )
398
 
399
+ with gr.Row():
400
+ translate_btn = gr.Button("Translate", variant="primary", elem_classes="translate-btn", size="lg")
401
+
402
+ gr.HTML("<div class='footer-info'>Oil & Gas Translation • English ↔ Norwegian • Optimized for HF Space</div>")
403
 
404
  with gr.Accordion("Example Sentences", open=True):
405
  with gr.Row():
 
410
  max_lines=5,
411
  show_copy_button=True
412
  )
413
+ use_example_btn = gr.Button("Use This Example", variant="primary", size="sm")
414
 
415
  with gr.Row():
416
  btn1 = gr.Button("Drilling (Short)", size="sm")
 
441
 
442
  with gr.Accordion("Upload Text File", open=False):
443
  file_input = gr.File(
444
+ label="Upload a .txt file to translate (max 1MB)",
445
  file_types=[".txt"],
446
  type="filepath"
447
  )
448
 
449
+ with gr.Accordion("Quality Test (Developer)", open=False):
450
+ test_output = gr.Textbox(
451
+ label="Test Results",
452
+ lines=20,
453
+ max_lines=30,
454
+ interactive=False
455
+ )
456
+ run_test_btn = gr.Button("Run Quality Regression Test", variant="secondary")
457
+ run_test_btn.click(fn=run_quality_tests, outputs=test_output)
458
+
459
  source_lang.change(fn=update_example_buttons, inputs=[source_lang], outputs=[example_text])
460
+
461
+ translate_btn.click(
462
+ fn=translate,
463
+ inputs=[input_text, source_lang, target_lang, quality_preset],
464
+ outputs=output_text
465
+ )
466
+
467
+ swap_btn.click(
468
+ fn=swap_languages,
469
+ inputs=[source_lang, target_lang, input_text, output_txt],
470
+ outputs=[source_lang, target_lang, input_text, output_text]
471
+ )
472
+
473
  file_input.change(fn=load_file, inputs=file_input, outputs=input_text)
474
 
475
+ demo.queue(concurrency_count=1, max_size=20).launch()