ClaudBarbara committed
Commit 8071884 · verified · 1 Parent(s): 9cc9e17

Update app.py

Files changed (1): app.py +26 -13
app.py CHANGED
@@ -9,7 +9,6 @@ from sacremoses import MosesPunctNormalizer
 
 app = Flask(__name__)
 
-# Preprocessing NLLB
 mpn = MosesPunctNormalizer(lang="en")
 mpn.substitutions = [(re.compile(r), sub) for r, sub in mpn.substitutions]
 
@@ -29,15 +28,15 @@ def preprocess_text(text: str) -> str:
     clean = unicodedata.normalize("NFKC", clean)
     return clean
 
-# Load model
 print("Loading model...")
 MODEL_ID = "ClaudBarbara/Open_Access_Khmer"
 model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
 tokenizer = NllbTokenizerFast.from_pretrained(MODEL_ID)
 print("Model loaded!")
 
+CONFIDENCE_THRESHOLD = 70  # Below this = human review needed
+
 def segment_text(text, src_lang):
-    """Segment text into sentences"""
     if src_lang == "khm_Khmr":
         sentences = re.split(r'(?<=[។៖])\s*', text)
     else:
@@ -45,9 +44,8 @@ def segment_text(text, src_lang):
     return [s.strip() for s in sentences if s.strip()]
 
 def translate_batch(texts, src_lang, tgt_lang):
-    """Translate a batch of sentences"""
     if not texts:
-        return []
+        return [], []
 
     tokenizer.src_lang = src_lang
     inputs = tokenizer(
@@ -64,13 +62,22 @@ def translate_batch(texts, src_lang, tgt_lang):
         forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
         max_new_tokens=int(32 + 3 * inputs.input_ids.shape[1]),
         num_beams=4,
-        early_stopping=True
+        early_stopping=True,
+        return_dict_in_generate=True,
+        output_scores=True
     )
 
-    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    translations = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
+
+    # Extract confidence scores
+    if hasattr(outputs, 'sequences_scores') and outputs.sequences_scores is not None:
+        scores = torch.sigmoid(outputs.sequences_scores).tolist()
+    else:
+        scores = [0.85] * len(texts)
+
+    return translations, scores
 
 def translate_long(text, src_lang, tgt_lang, batch_size=8):
-    """Translate long document in batches"""
     start_time = time.time()
 
     clean_text = preprocess_text(text)
@@ -80,20 +87,26 @@ def translate_long(text, src_lang, tgt_lang, batch_size=8):
         return "", {}
 
     translated_parts = []
+    all_scores = []
 
     for i in range(0, len(sentences), batch_size):
         batch = sentences[i:i + batch_size]
-        translations = translate_batch(batch, src_lang, tgt_lang)
+        translations, scores = translate_batch(batch, src_lang, tgt_lang)
         translated_parts.extend(translations)
+        all_scores.extend(scores)
 
     result = " ".join(translated_parts)
     elapsed = time.time() - start_time
 
+    avg_confidence = (sum(all_scores) / len(all_scores) * 100) if all_scores else 0
+    min_confidence = (min(all_scores) * 100) if all_scores else 0
+
     metrics = {
-        "sentences": len(sentences),
-        "source_chars": len(text),
-        "target_chars": len(result),
-        "time_seconds": round(elapsed, 2)
+        "confidence": round(avg_confidence, 1),
+        "min_confidence": round(min_confidence, 1),
+        "needs_review": avg_confidence < CONFIDENCE_THRESHOLD,
+        "time_seconds": round(elapsed, 2),
+        "sentences": len(sentences)
     }
 
     return result, metrics
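
For context, a minimal sketch of how a caller might consume the new (translations, metrics) contract introduced by this commit. The /translate route, its JSON field names, and the default language codes below are assumptions made for illustration; they are not part of this diff, which only defines translate_long and CONFIDENCE_THRESHOLD in app.py.

# Hypothetical usage sketch (not part of this commit). Assumes it lives in the
# same app.py, so app, translate_long and CONFIDENCE_THRESHOLD are in scope.
from flask import request, jsonify

@app.route("/translate", methods=["POST"])       # route name is an assumption
def translate_endpoint():
    payload = request.get_json(force=True)
    text = payload.get("text", "")
    src = payload.get("src_lang", "eng_Latn")    # NLLB-style language codes
    tgt = payload.get("tgt_lang", "khm_Khmr")

    result, metrics = translate_long(text, src, tgt)

    # Surface the confidence metrics so the client can route low-confidence
    # output (average confidence below CONFIDENCE_THRESHOLD) to human review.
    return jsonify({
        "translation": result,
        "confidence": metrics.get("confidence"),
        "min_confidence": metrics.get("min_confidence"),
        "needs_review": metrics.get("needs_review", False),
        "time_seconds": metrics.get("time_seconds"),
    })

Returning the needs_review flag alongside the translated text is what makes the CONFIDENCE_THRESHOLD constant actionable: a client can accept high-confidence output directly and queue the rest for post-editing.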