xTHExBEASTx commited on
Commit
82da1ff
·
verified ·
1 Parent(s): 2baa60d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -16
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
  import srt
 
4
  import os
5
 
6
  # --- Configuration ---
@@ -8,30 +9,45 @@ MODEL_CHECKPOINT = "facebook/nllb-200-distilled-600M"
8
  SRC_LANG = "eng_Latn"
9
  TGT_LANG = "arb_Arab"
10
 
11
- # --- Load Model ---
12
  print("Loading model...")
13
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
15
-
16
- translator = pipeline(
17
- "translation",
18
- model=model,
19
- tokenizer=tokenizer,
20
- src_lang=SRC_LANG,
21
- tgt_lang=TGT_LANG,
22
- device=-1
23
- )
24
 
25
  def batch_translate(texts, batch_size=8):
 
 
 
26
  results = []
 
 
 
 
27
  for i in range(0, len(texts), batch_size):
28
  batch = texts[i : i + batch_size]
29
- outputs = translator(batch, max_length=400)
30
- results.extend([out['translation_text'] for out in outputs])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  return results
32
 
33
  def process_srt(filepath):
34
- # Gradio 4 passes the file path as a string
35
  if filepath is None:
36
  return None
37
 
@@ -43,12 +59,15 @@ def process_srt(filepath):
43
  except Exception as e:
44
  return f"Error parsing SRT: {str(e)}"
45
 
 
46
  texts_to_translate = [sub.content for sub in subtitles]
47
  translated_texts = batch_translate(texts_to_translate)
48
 
 
49
  for sub, trans_text in zip(subtitles, translated_texts):
50
  sub.content = trans_text
51
 
 
52
  output_path = "translated_subtitles.srt"
53
  with open(output_path, 'w', encoding='utf-8') as f:
54
  f.write(srt.compose(subtitles))
@@ -56,7 +75,7 @@ def process_srt(filepath):
56
  return output_path
57
 
58
  # --- Gradio Interface ---
59
- with gr.Blocks(title="SRT Translator") as demo:
60
  gr.Markdown("# 🇬🇧 English to 🇸🇦 Arabic SRT Translator")
61
 
62
  with gr.Row():
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import srt
4
+ import torch
5
  import os
6
 
7
  # --- Configuration ---
 
9
  SRC_LANG = "eng_Latn"
10
  TGT_LANG = "arb_Arab"
11
 
12
+ # --- Load Model Directly (No Pipeline) ---
13
  print("Loading model...")
 
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
15
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
16
+ print("Model loaded!")
 
 
 
 
 
 
 
17
 
18
  def batch_translate(texts, batch_size=8):
19
+ """
20
+ Directly uses the model to translate without the pipeline abstraction.
21
+ """
22
  results = []
23
+
24
+ # 1. Set the source language for the tokenizer
25
+ tokenizer.src_lang = SRC_LANG
26
+
27
  for i in range(0, len(texts), batch_size):
28
  batch = texts[i : i + batch_size]
29
+
30
+ # 2. Tokenize the batch
31
+ inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
32
+
33
+ # 3. Generate translation (Force the target language ID)
34
+ # NLLB requires forcing the 'bos_token_id' to the target language
35
+ forced_bos_token_id = tokenizer.lang_code_to_id[TGT_LANG]
36
+
37
+ with torch.no_grad():
38
+ generated_tokens = model.generate(
39
+ **inputs,
40
+ forced_bos_token_id=forced_bos_token_id,
41
+ max_length=512
42
+ )
43
+
44
+ # 4. Decode the results
45
+ batch_results = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
46
+ results.extend(batch_results)
47
+
48
  return results
49
 
50
  def process_srt(filepath):
 
51
  if filepath is None:
52
  return None
53
 
 
59
  except Exception as e:
60
  return f"Error parsing SRT: {str(e)}"
61
 
62
+ # Translate content
63
  texts_to_translate = [sub.content for sub in subtitles]
64
  translated_texts = batch_translate(texts_to_translate)
65
 
66
+ # Update subtitles
67
  for sub, trans_text in zip(subtitles, translated_texts):
68
  sub.content = trans_text
69
 
70
+ # Save output
71
  output_path = "translated_subtitles.srt"
72
  with open(output_path, 'w', encoding='utf-8') as f:
73
  f.write(srt.compose(subtitles))
 
75
  return output_path
76
 
77
  # --- Gradio Interface ---
78
+ with gr.Blocks(title="NLLB SRT Translator") as demo:
79
  gr.Markdown("# 🇬🇧 English to 🇸🇦 Arabic SRT Translator")
80
 
81
  with gr.Row():