Navya-Sree committed on
Commit
0597381
·
verified ·
1 Parent(s): 0e17eea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -6
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import spaces
2
  import gradio as gr
3
  from sacremoses import MosesPunctNormalizer
4
- from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo
5
  from transformers import pipeline
6
  from cultural_model import CulturalM2M100
7
  from cultural_tokenizer import CulturalTokenizer
@@ -11,7 +10,9 @@ import nltk
11
  from functools import lru_cache
12
  from config import LANGUAGE_MAPPING, ENDANGERED_LANGS, MODEL_NAME
13
 
 
14
  nltk.download("punkt_tab")
 
15
 
16
  # Device configuration
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -27,16 +28,24 @@ punct_normalizer = MosesPunctNormalizer(lang="en")
27
 
28
  @lru_cache(maxsize=202)
29
  def get_language_specific_sentence_splitter(language_code):
30
- splitter = get_split_algo(language_code[:3], "default")
31
- return splitter
 
 
 
 
32
 
33
  @spaces.GPU
34
  def translate(text: str, src_lang: str, tgt_lang: str):
35
  if not text.strip():
36
  return ""
37
 
38
- src_code = LANGUAGE_MAPPING[src_lang]["code"]
39
- tgt_code = LANGUAGE_MAPPING[tgt_lang]["code"]
 
 
 
 
40
 
41
  # Enable cultural preservation for endangered languages
42
  cultural_preservation = tgt_lang in ENDANGERED_LANGS
@@ -48,8 +57,12 @@ def translate(text: str, src_lang: str, tgt_lang: str):
48
  translated_paragraphs = []
49
 
50
  for paragraph in paragraphs:
 
 
 
 
51
  splitter = get_language_specific_sentence_splitter(src_code)
52
- sentences = list(splitter(paragraph))
53
  translated_sentences = []
54
 
55
  for sentence in sentences:
 
1
  import spaces
2
  import gradio as gr
3
  from sacremoses import MosesPunctNormalizer
 
4
  from transformers import pipeline
5
  from cultural_model import CulturalM2M100
6
  from cultural_tokenizer import CulturalTokenizer
 
10
  from functools import lru_cache
11
  from config import LANGUAGE_MAPPING, ENDANGERED_LANGS, MODEL_NAME
12
 
13
+ # Download required NLTK data
14
  nltk.download("punkt_tab")
15
+ nltk.download("punkt")
16
 
17
  # Device configuration
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
28
 
29
  @lru_cache(maxsize=202)
30
  def get_language_specific_sentence_splitter(language_code):
31
+ """Return a sentence splitter function for the given language"""
32
+ # For endangered languages, use NLTK with language-specific tokenizer
33
+ if language_code in ["qu", "ay", "chr"]: # Endangered language codes
34
+ return lambda text: nltk.sent_tokenize(text, language="english")
35
+ # For other languages, use NLTK with default tokenizer
36
+ return nltk.sent_tokenize
37
 
38
  @spaces.GPU
39
  def translate(text: str, src_lang: str, tgt_lang: str):
40
  if not text.strip():
41
  return ""
42
 
43
+ src_info = LANGUAGE_MAPPING.get(src_lang)
44
+ tgt_info = LANGUAGE_MAPPING.get(tgt_lang)
45
+ if not src_info or not tgt_info:
46
+ raise gr.Error("Invalid language selection")
47
+ src_code = src_info["code"]
48
+ tgt_code = tgt_info["code"]
49
 
50
  # Enable cultural preservation for endangered languages
51
  cultural_preservation = tgt_lang in ENDANGERED_LANGS
 
57
  translated_paragraphs = []
58
 
59
  for paragraph in paragraphs:
60
+ if not paragraph.strip():
61
+ translated_paragraphs.append("")
62
+ continue
63
+
64
  splitter = get_language_specific_sentence_splitter(src_code)
65
+ sentences = splitter(paragraph)
66
  translated_sentences = []
67
 
68
  for sentence in sentences: