kleervoyans committed on
Commit
8ec855b
Β·
verified Β·
1 Parent(s): 9b88b5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -67
app.py CHANGED
@@ -1,7 +1,10 @@
 
 
1
  import streamlit as st
2
  import logging
3
  import pandas as pd
4
  import plotly.express as px
 
5
  from typing import Union, List
6
 
7
  from langdetect import detect, LangDetectException
@@ -22,36 +25,44 @@ logging.basicConfig(
22
  logger = logging.getLogger(__name__)
23
 
24
 
25
- # ────────── Model Management ──────────
26
  class ModelManager:
27
  """
28
- Automatically selects, loads, and wraps a seq2seq translation model
29
- in 8-bit (with FP32 fallback), plus language‐code auto-detection.
 
30
  """
31
-
32
  def __init__(
33
  self,
34
  candidates: List[str] = None,
35
  quantize: bool = True,
36
  default_tgt: str = None,
37
  ):
 
 
 
 
 
 
38
  self.candidates = candidates or [
39
  "facebook/nllb-200-distilled-600M",
40
  "facebook/m2m100_418M",
41
  ]
42
- self.quantize = quantize
43
- self.default_tgt = default_tgt # if None β†’ auto-pick Turkish
 
44
  self.tokenizer = None
45
  self.model = None
46
  self.pipeline = None
47
  self.lang_codes: List[str] = []
 
48
  self._select_and_load()
49
 
50
  def _select_and_load(self):
51
  last_err = None
52
  for model_name in self.candidates:
53
  try:
54
- # 1) Load tokenizer
55
  logger.info(f"Loading tokenizer for {model_name}")
56
  tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
57
  if not hasattr(tok, "lang_code_to_id"):
@@ -59,53 +70,55 @@ class ModelManager:
59
  f"Tokenizer for {model_name} missing lang_code_to_id"
60
  )
61
 
62
- # 2) Load model with bitsandbytes 8-bit quantization
63
  logger.info(
64
- f"Loading model {model_name} "
65
- f"(8-bit={'on' if self.quantize else 'off'})"
66
- )
67
- bnb_cfg = BitsAndBytesConfig(load_in_8bit=self.quantize)
68
- model = AutoModelForSeq2SeqLM.from_pretrained(
69
- model_name,
70
- device_map="auto",
71
- quantization_config=bnb_cfg,
72
  )
 
 
 
 
 
 
 
 
 
 
 
 
73
  logger.info(f"Model {model_name} loaded successfully")
74
 
75
- # 3) Build a translation pipeline around it
76
  pipe = pipeline(
77
  "translation",
78
- model=model,
79
  tokenizer=tok,
80
  )
81
 
82
- # 4) On success, store and break
 
83
  self.tokenizer = tok
84
- self.model = model
85
  self.pipeline = pipe
86
  self.lang_codes = list(tok.lang_code_to_id.keys())
87
- logger.info(f"Available language codes: {self.lang_codes[:5]}…")
88
 
89
- # 5) Auto-pick Turkish target if needed
90
  if not self.default_tgt:
91
- tur = [
92
- code
93
- for code in self.lang_codes
94
- if code.lower().startswith("tr")
95
  ]
96
- if not tur:
97
- raise ValueError(f"No Turkish code in {model_name}")
98
- self.default_tgt = tur[0]
99
  logger.info(f"Default target language: {self.default_tgt}")
100
 
101
  return
102
-
103
  except Exception as e:
104
  logger.warning(f"Failed to load {model_name}: {e}")
105
  last_err = e
106
 
107
  raise RuntimeError(
108
- f"Could not load any model from candidates {self.candidates}: {last_err}"
109
  )
110
 
111
  def translate(
@@ -116,13 +129,11 @@ class ModelManager:
116
  ):
117
  """
118
  Translate `text` from src_lang β†’ tgt_lang.
119
- If src_lang is None: auto-detect via langdetect.
120
- If tgt_lang is None: use default_tgt (Turkish).
121
- Returns the pipeline output (list of dicts with 'translation_text').
122
  """
123
  tgt = tgt_lang or self.default_tgt
124
 
125
- # Auto-detect source
126
  if not src_lang:
127
  sample = text[0] if isinstance(text, list) else text
128
  try:
@@ -132,41 +143,41 @@ class ModelManager:
132
  ]
133
  if not candidates:
134
  raise LangDetectException(f"No code for ISO '{iso}'")
135
- # prefer exact match
136
  exact = [c for c in candidates if c.lower() == iso]
137
  src = exact[0] if exact else candidates[0]
138
  logger.info(f"Auto-detected src_lang={src}")
139
  except Exception as e:
140
  logger.warning(f"langdetect failed ({e}); defaulting to English")
141
- eng = [c for c in self.lang_codes if c.lower().startswith("en")]
142
- src = eng[0] if eng else self.lang_codes[0]
 
 
143
  else:
144
  src = src_lang
145
 
146
- # Call the pipeline with both src_lang and tgt_lang
147
  return self.pipeline(text, src_lang=src, tgt_lang=tgt)
148
 
149
  def get_info(self):
150
- """Return metadata for sidebar display."""
151
- model = getattr(self.model, "config", None)
152
- quantized = getattr(self.model, "is_loaded_in_8bit", False)
153
- device = getattr(self.model.device, "index", None)
154
- device = f"cuda:{device}" if device is not None else "cpu"
155
  return {
156
- "model": self.model.name_or_path,
157
- "quantized": quantized,
158
- "device": device,
159
  "default_tgt": self.default_tgt,
160
  }
161
 
162
 
163
- # ────────── Evaluation ──────────
164
  class TranslationEvaluator:
165
  def __init__(self):
166
  self.bleu = evaluate.load("bleu")
167
  self.bertscore = evaluate.load("bertscore")
168
  self.comet = evaluate.load("comet", model_id="unbabel/comet-mqm-qe-da")
169
- logging.info("Loaded BLEU, BERTScore, COMET")
170
 
171
  def evaluate(
172
  self,
@@ -175,36 +186,27 @@ class TranslationEvaluator:
175
  predictions: List[str],
176
  ):
177
  results = {}
178
-
179
  # BLEU
180
  results["BLEU"] = self.bleu.compute(
181
  predictions=predictions,
182
  references=[[r] for r in references],
183
  )["bleu"]
184
-
185
  # BERTScore (general)
186
  bs = self.bertscore.compute(
187
  predictions=predictions, references=references, lang="xx"
188
  )
189
  results["BERTScore"] = sum(bs["f1"]) / len(bs["f1"]) if bs["f1"] else 0.0
190
-
191
  # BERTurk (Turkish)
192
  bs_tr = self.bertscore.compute(
193
  predictions=predictions, references=references, lang="tr"
194
  )
195
  results["BERTurk"] = sum(bs_tr["f1"]) / len(bs_tr["f1"]) if bs_tr["f1"] else 0.0
196
-
197
  # COMET
198
- co = self.comet.compute(
199
  srcs=sources, hyps=predictions, refs=references
200
  )
201
- # `scores` may be a float or list
202
- score = co.get("scores", None)
203
- if isinstance(score, list):
204
- results["COMET"] = score[0] if score else 0.0
205
- else:
206
- results["COMET"] = score or 0.0
207
-
208
  return results
209
 
210
 
@@ -212,9 +214,6 @@ class TranslationEvaluator:
212
 
213
  @st.cache_resource
214
  def load_resources():
215
- """
216
- Load and cache ModelManager & TranslationEvaluator on first run.
217
- """
218
  mgr = ModelManager(quantize=True)
219
  ev = TranslationEvaluator()
220
  return mgr, ev
@@ -235,7 +234,7 @@ def process_text(
235
  ev: TranslationEvaluator,
236
  metrics: List[str],
237
  ):
238
- out = mgr.translate(src) # list of dicts
239
  hyp = out[0]["translation_text"]
240
  scores = ev.evaluate([src], [ref or ""], [hyp])
241
  return {
@@ -258,7 +257,7 @@ def _show_single_results(res: dict):
258
  st.write(res["reference"])
259
  with right:
260
  st.markdown("### Scores")
261
- df = pd.DataFrame([{k: v for k, v in res.items() if k in res.keys() and k in ["BLEU","BERTScore","BERTurk","COMET"]}])
262
  st.table(df)
263
 
264
 
@@ -279,7 +278,7 @@ def process_file(
279
  batch = df.iloc[i : i + batch_size]
280
  srcs = batch["src"].tolist()
281
  refs = batch["ref_tr"].tolist()
282
- outs = mgr.translate(srcs) # batch translation
283
  hyps = [o["translation_text"] for o in outs]
284
  for s, r, h in zip(srcs, refs, hyps):
285
  sc = ev.evaluate([s], [r], [h])
 
1
+ # app.py
2
+
3
  import streamlit as st
4
  import logging
5
  import pandas as pd
6
  import plotly.express as px
7
+ import torch
8
  from typing import Union, List
9
 
10
  from langdetect import detect, LangDetectException
 
25
  logger = logging.getLogger(__name__)
26
 
27
 
28
+ # ────────── Model Manager ──────────
29
  class ModelManager:
30
  """
31
+ Selects and loads a translation model (NLLB-200 or M2M100),
32
+ using 8-bit quantization only if CUDA is available.
33
+ Auto-detects source language and defaults target to Turkish.
34
  """
 
35
    def __init__(
        self,
        candidates: List[str] = None,
        quantize: bool = True,
        default_tgt: str = None,
    ):
        """
        Prepare the manager and eagerly load the first usable candidate model.

        Args:
            candidates: Ordered list of Hugging Face model ids to try;
                defaults to NLLB-200 (distilled 600M) then M2M100 (418M).
            quantize: Request 8-bit loading. Downgraded (with a warning)
                to a plain full-precision load when CUDA is unavailable.
            default_tgt: Target language code. When None, a Turkish code
                is auto-picked from the loaded tokenizer's codes.
        """
        # If user requested quantization but CUDA isn't available, disable it
        # (8-bit loading via bitsandbytes presumably requires a GPU — TODO confirm)
        if quantize and not torch.cuda.is_available():
            logger.warning("CUDA unavailable; disabling 8-bit quantization")
            quantize = False
        self.quantize = quantize

        self.candidates = candidates or [
            "facebook/nllb-200-distilled-600M",
            "facebook/m2m100_418M",
        ]
        self.default_tgt = default_tgt  # will auto-pick if None

        # Populated by _select_and_load() below.
        self.selected_model_name: str = None
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.lang_codes: List[str] = []

        self._select_and_load()
60
 
61
  def _select_and_load(self):
62
  last_err = None
63
  for model_name in self.candidates:
64
  try:
65
+ # Load tokenizer
66
  logger.info(f"Loading tokenizer for {model_name}")
67
  tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
68
  if not hasattr(tok, "lang_code_to_id"):
 
70
  f"Tokenizer for {model_name} missing lang_code_to_id"
71
  )
72
 
73
+ # Load model (with or without 8-bit)
74
  logger.info(
75
+ f"Loading model {model_name} (8-bit={self.quantize})"
 
 
 
 
 
 
 
76
  )
77
+ if self.quantize:
78
+ bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
79
+ mdl = AutoModelForSeq2SeqLM.from_pretrained(
80
+ model_name,
81
+ device_map="auto",
82
+ quantization_config=bnb_cfg,
83
+ )
84
+ else:
85
+ mdl = AutoModelForSeq2SeqLM.from_pretrained(
86
+ model_name,
87
+ device_map="auto",
88
+ )
89
  logger.info(f"Model {model_name} loaded successfully")
90
 
91
+ # Wrap in a translation pipeline
92
  pipe = pipeline(
93
  "translation",
94
+ model=mdl,
95
  tokenizer=tok,
96
  )
97
 
98
+ # Store and break
99
+ self.selected_model_name = model_name
100
  self.tokenizer = tok
101
+ self.model = mdl
102
  self.pipeline = pipe
103
  self.lang_codes = list(tok.lang_code_to_id.keys())
 
104
 
105
+ # Auto-pick Turkish target code if none specified
106
  if not self.default_tgt:
107
+ tur_codes = [
108
+ c for c in self.lang_codes if c.lower().startswith("tr")
 
 
109
  ]
110
+ if not tur_codes:
111
+ raise ValueError(f"No Turkish code found in {model_name}")
112
+ self.default_tgt = tur_codes[0]
113
  logger.info(f"Default target language: {self.default_tgt}")
114
 
115
  return
 
116
  except Exception as e:
117
  logger.warning(f"Failed to load {model_name}: {e}")
118
  last_err = e
119
 
120
  raise RuntimeError(
121
+ f"Could not load any model from {self.candidates}: {last_err}"
122
  )
123
 
124
  def translate(
 
129
  ):
130
  """
131
  Translate `text` from src_lang β†’ tgt_lang.
132
+ Auto-detects src_lang if not given.
 
 
133
  """
134
  tgt = tgt_lang or self.default_tgt
135
 
136
+ # Auto-detect source language if missing
137
  if not src_lang:
138
  sample = text[0] if isinstance(text, list) else text
139
  try:
 
143
  ]
144
  if not candidates:
145
  raise LangDetectException(f"No code for ISO '{iso}'")
 
146
  exact = [c for c in candidates if c.lower() == iso]
147
  src = exact[0] if exact else candidates[0]
148
  logger.info(f"Auto-detected src_lang={src}")
149
  except Exception as e:
150
  logger.warning(f"langdetect failed ({e}); defaulting to English")
151
+ eng_codes = [
152
+ c for c in self.lang_codes if c.lower().startswith("en")
153
+ ]
154
+ src = eng_codes[0] if eng_codes else self.lang_codes[0]
155
  else:
156
  src = src_lang
157
 
 
158
  return self.pipeline(text, src_lang=src, tgt_lang=tgt)
159
 
160
  def get_info(self):
161
+ """Return metadata for the sidebar display."""
162
+ device = "cpu"
163
+ if torch.cuda.is_available() and hasattr(self.model, "device"):
164
+ idx = self.model.device.index if hasattr(self.model.device, "index") else None
165
+ device = f"cuda:{idx}" if idx is not None else "cuda"
166
  return {
167
+ "model": self.selected_model_name,
168
+ "quantized": self.quantize,
169
+ "device": device,
170
  "default_tgt": self.default_tgt,
171
  }
172
 
173
 
174
+ # ────────── Evaluator ──────────
175
  class TranslationEvaluator:
176
    def __init__(self):
        """Load the BLEU, BERTScore, and COMET metric backends once."""
        self.bleu = evaluate.load("bleu")
        self.bertscore = evaluate.load("bertscore")
        # NOTE(review): pins a COMET QE checkpoint; confirm the loaded metric
        # accepts the srcs/hyps/refs keyword call used by this class's
        # evaluate() method.
        self.comet = evaluate.load("comet", model_id="unbabel/comet-mqm-qe-da")
        logger.info("Loaded BLEU, BERTScore, COMET metrics")
181
 
182
  def evaluate(
183
  self,
 
186
  predictions: List[str],
187
  ):
188
  results = {}
 
189
  # BLEU
190
  results["BLEU"] = self.bleu.compute(
191
  predictions=predictions,
192
  references=[[r] for r in references],
193
  )["bleu"]
 
194
  # BERTScore (general)
195
  bs = self.bertscore.compute(
196
  predictions=predictions, references=references, lang="xx"
197
  )
198
  results["BERTScore"] = sum(bs["f1"]) / len(bs["f1"]) if bs["f1"] else 0.0
 
199
  # BERTurk (Turkish)
200
  bs_tr = self.bertscore.compute(
201
  predictions=predictions, references=references, lang="tr"
202
  )
203
  results["BERTurk"] = sum(bs_tr["f1"]) / len(bs_tr["f1"]) if bs_tr["f1"] else 0.0
 
204
  # COMET
205
+ cm = self.comet.compute(
206
  srcs=sources, hyps=predictions, refs=references
207
  )
208
+ scores = cm.get("scores", None)
209
+ results["COMET"] = float(scores[0] if isinstance(scores, list) else scores) or 0.0
 
 
 
 
 
210
  return results
211
 
212
 
 
214
 
215
@st.cache_resource
def load_resources():
    """Load and cache the ModelManager and TranslationEvaluator.

    Cached via st.cache_resource so the models and metric backends are
    constructed once per Streamlit server process, not on every rerun.

    Returns:
        (ModelManager, TranslationEvaluator) tuple.
    """
    mgr = ModelManager(quantize=True)
    ev = TranslationEvaluator()
    return mgr, ev
 
234
  ev: TranslationEvaluator,
235
  metrics: List[str],
236
  ):
237
+ out = mgr.translate(src)
238
  hyp = out[0]["translation_text"]
239
  scores = ev.evaluate([src], [ref or ""], [hyp])
240
  return {
 
257
  st.write(res["reference"])
258
  with right:
259
  st.markdown("### Scores")
260
+ df = pd.DataFrame([{k: v for k, v in res.items() if k in metrics}])
261
  st.table(df)
262
 
263
 
 
278
  batch = df.iloc[i : i + batch_size]
279
  srcs = batch["src"].tolist()
280
  refs = batch["ref_tr"].tolist()
281
+ outs = mgr.translate(srcs)
282
  hyps = [o["translation_text"] for o in outs]
283
  for s, r, h in zip(srcs, refs, hyps):
284
  sc = ev.evaluate([s], [r], [h])