Nick021402 commited on
Commit
30bc0de
·
verified ·
1 Parent(s): ee67b24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -36
app.py CHANGED
@@ -8,6 +8,8 @@ import subprocess
8
  import logging
9
  from typing import Optional, Tuple
10
  import re
 
 
11
 
12
  # Configure logging
13
  logging.basicConfig(level=logging.INFO)
@@ -30,12 +32,19 @@ class SubtitleTranslator:
30
  if self.translator is None:
31
  logger.info("Loading translation model...")
32
  # Use a lightweight translation model
33
- self.translator = pipeline(
34
- "translation",
35
- model="Helsinki-NLP/opus-mt-mul-en",
36
- device=0 if self.device == "cuda" else -1,
37
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
38
- )
 
 
 
 
 
 
 
39
 
40
  def extract_audio(self, video_path: str) -> str:
41
  """Extract audio from video file"""
@@ -80,32 +89,16 @@ class SubtitleTranslator:
80
  if source_lang == "en":
81
  return text
82
 
83
- # Chunk long text to avoid memory issues
84
- max_length = 500
85
- if len(text) <= max_length:
86
- result = self.translator(text, max_length=512)
87
- return result[0]['translation_text']
88
-
89
- # Process in chunks
90
- sentences = re.split(r'[.!?]+', text)
91
- translated_chunks = []
92
- current_chunk = ""
93
 
94
- for sentence in sentences:
95
- if len(current_chunk + sentence) <= max_length:
96
- current_chunk += sentence + ". "
97
- else:
98
- if current_chunk:
99
- result = self.translator(current_chunk.strip(), max_length=512)
100
- translated_chunks.append(result[0]['translation_text'])
101
- current_chunk = sentence + ". "
102
-
103
- if current_chunk:
104
- result = self.translator(current_chunk.strip(), max_length=512)
105
- translated_chunks.append(result[0]['translation_text'])
106
-
107
- return " ".join(translated_chunks)
108
-
109
  except Exception as e:
110
  logger.error(f"Translation failed: {e}")
111
  return text # Return original if translation fails
@@ -294,8 +287,4 @@ def create_interface():
294
  # Launch the app
295
  if __name__ == "__main__":
296
  demo = create_interface()
297
- demo.launch(
298
- server_name="0.0.0.0",
299
- server_port=7860,
300
- share=true
301
- )
 
8
  import logging
9
  from typing import Optional, Tuple
10
  import re
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
 
14
  # Configure logging
15
  logging.basicConfig(level=logging.INFO)
 
32
  if self.translator is None:
33
  logger.info("Loading translation model...")
34
  # Use a lightweight translation model
35
+ try:
36
+ self.translator = pipeline(
37
+ "translation",
38
+ model="Helsinki-NLP/opus-mt-mul-en",
39
+ device=0 if self.device == "cuda" else -1
40
+ )
41
+ except Exception as e:
42
+ logger.warning(f"Failed to load Helsinki model, using Facebook model: {e}")
43
+ self.translator = pipeline(
44
+ "translation",
45
+ model="facebook/m2m100_418M",
46
+ device=0 if self.device == "cuda" else -1
47
+ )
48
 
49
  def extract_audio(self, video_path: str) -> str:
50
  """Extract audio from video file"""
 
89
  if source_lang == "en":
90
  return text
91
 
92
+ # For Helsinki model, use direct translation
93
+ if "Helsinki" in str(type(self.translator.model)):
94
+ result = self.translator(text)
95
+ return result[0]['translation_text'] if result else text
 
 
 
 
 
 
96
 
97
+ # For M2M100 model, specify target language
98
+ else:
99
+ result = self.translator(text, forced_bos_token_id=self.translator.tokenizer.get_lang_id("en"))
100
+ return result[0]['translation_text'] if result else text
101
+
 
 
 
 
 
 
 
 
 
 
102
  except Exception as e:
103
  logger.error(f"Translation failed: {e}")
104
  return text # Return original if translation fails
 
287
  # Launch the app
288
  if __name__ == "__main__":
289
  demo = create_interface()
290
+ demo.launch(share=True)