kleervoyans committed on
Commit
9b88b5f
·
verified ·
1 Parent(s): 768e15d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +252 -46
app.py CHANGED
@@ -1,53 +1,251 @@
1
- # app.py
2
-
3
  import streamlit as st
4
  import logging
5
  import pandas as pd
6
  import plotly.express as px
 
7
 
8
- from models.model_manager import ModelManager
9
- from evaluators.evaluator import TranslationEvaluator
 
 
 
 
 
 
10
 
11
  # ────────── Logging ──────────
12
  logging.basicConfig(
13
  format="%(asctime)s %(levelname)s %(name)s: %(message)s",
14
  datefmt="%Y-%m-%d %H:%M:%S",
15
- level=logging.INFO
16
  )
17
  logger = logging.getLogger(__name__)
18
 
19
- # ────────── Cached Resources ──────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  @st.cache_resource
21
  def load_resources():
22
  """
23
- Load and cache the model manager and evaluator on first run.
24
  """
25
- manager = ModelManager(quantize=True)
26
- evaluator = TranslationEvaluator()
27
- return manager, evaluator
 
28
 
29
- # ────────── Sidebar Model Info ──────────
30
  def display_model_info(info: dict):
31
  st.sidebar.markdown("### Model Info")
32
- st.sidebar.write(f"**Model:** {info.get('model')}")
33
- st.sidebar.write(f"**8-bit Quantized:** {info.get('quantized')}")
34
- st.sidebar.write(f"**Device:** {info.get('device')}")
35
- st.sidebar.write(f"**Default target:** {info.get('default_tgt')}")
36
-
37
- # ────────── Single‐text Processing ──────────
38
- def process_text(src: str, ref: str, manager: ModelManager, evaluator: TranslationEvaluator, metrics: list):
39
- # 1) Translate (auto-detect source, default target Turkish)
40
- out = manager.translate(src) # returns list of dicts
41
- hyp = out[0]["translation_text"] if isinstance(out, list) else out["translation_text"]
42
- # 2) Evaluate
43
- scores = evaluator.evaluate([src], [ref or ""], [hyp])
 
 
 
 
44
  return {
45
  "source": src,
46
  "reference": ref,
47
  "hypothesis": hyp,
48
- **{m: scores[m] for m in metrics}
49
  }
50
 
 
51
  def _show_single_results(res: dict):
52
  left, right = st.columns(2)
53
  with left:
@@ -60,16 +258,16 @@ def _show_single_results(res: dict):
60
  st.write(res["reference"])
61
  with right:
62
  st.markdown("### Scores")
63
- df = pd.DataFrame([{k: v for k, v in res.items() if k in ["BLEU","BERTScore","BERTurk","COMET"]}])
64
  st.table(df)
65
 
66
- # ────────── Batch‐CSV Processing ──────────
67
  def process_file(
68
  uploaded,
69
- manager: ModelManager,
70
- evaluator: TranslationEvaluator,
71
- metrics: list,
72
- batch_size: int
73
  ):
74
  df = pd.read_csv(uploaded)
75
  if not {"src", "ref_tr"}.issubset(df.columns):
@@ -81,29 +279,32 @@ def process_file(
81
  batch = df.iloc[i : i + batch_size]
82
  srcs = batch["src"].tolist()
83
  refs = batch["ref_tr"].tolist()
84
- # translate batch
85
- outs = manager.translate(srcs) # list of dicts
86
  hyps = [o["translation_text"] for o in outs]
87
- # evaluate each row
88
  for s, r, h in zip(srcs, refs, hyps):
89
- sc = evaluator.evaluate([s], [r], [h])
90
  entry = {"src": s, "ref_tr": r, "hyp_tr": h}
91
  entry.update({m: sc[m] for m in metrics})
92
  results.append(entry)
93
  prog.progress(min(i + batch_size, total) / total)
94
  return pd.DataFrame(results)
95
 
96
- def _show_batch_viz(df: pd.DataFrame, metrics: list):
 
97
  for m in metrics:
98
  st.markdown(f"#### {m} Distribution")
99
  fig = px.histogram(df, x=m)
100
  st.plotly_chart(fig, use_container_width=True)
101
 
102
- # ────────── Main ──────────
103
  def main():
104
- st.set_page_config(page_title="πŸ”€ Translationβ†’Turkish Quality", layout="wide")
 
 
105
  st.title("πŸ”€ Translation β†’ TR Quality & COMET")
106
- st.markdown("Translate any language into Turkish and evaluate with BLEU, BERTScore, BERTurk & COMET.")
 
 
107
 
108
  # Sidebar
109
  with st.sidebar:
@@ -111,11 +312,11 @@ def main():
111
  metrics = st.multiselect(
112
  "Select metrics",
113
  ["BLEU", "BERTScore", "BERTurk", "COMET"],
114
- default=["BLEU", "BERTScore", "COMET"]
115
  )
116
  batch_size = st.slider("Batch size", 1, 32, 8)
117
- manager, evaluator = load_resources()
118
- display_model_info(manager.get_info())
119
 
120
  # Tabs
121
  tab1, tab2 = st.tabs(["Single Sentence", "Batch CSV"])
@@ -125,22 +326,27 @@ def main():
125
  ref = st.text_area("Turkish reference (optional):", height=100)
126
  if st.button("Evaluate"):
127
  with st.spinner("Translating & evaluating…"):
128
- res = process_text(src, ref, manager, evaluator, metrics)
129
  _show_single_results(res)
130
 
131
  with tab2:
132
- uploaded = st.file_uploader("Upload CSV with `src` & `ref_tr` columns", type=["csv"])
 
 
133
  if uploaded:
134
  with st.spinner("Processing file…"):
135
- df_res = process_file(uploaded, manager, evaluator, metrics, batch_size)
136
  st.markdown("### Batch Results")
137
  st.dataframe(df_res, use_container_width=True)
138
  _show_batch_viz(df_res, metrics)
139
- st.download_button("Download results as CSV", df_res.to_csv(index=False), "results.csv")
 
 
 
140
 
141
  if __name__ == "__main__":
142
  try:
143
  main()
144
  except Exception as e:
145
  st.error(f"Unexpected error: {e}")
146
- logger.exception("Unhandled exception in main()")
 
 
 
1
  import streamlit as st
2
  import logging
3
  import pandas as pd
4
  import plotly.express as px
5
+ from typing import Union, List
6
 
7
+ from langdetect import detect, LangDetectException
8
+ from transformers import (
9
+ AutoTokenizer,
10
+ AutoModelForSeq2SeqLM,
11
+ pipeline,
12
+ BitsAndBytesConfig,
13
+ )
14
+ import evaluate
15
 
16
  # ────────── Logging ──────────
17
# Configure root logging once at import time: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
23
 
24
+
25
+ # ────────── Model Management ──────────
26
class ModelManager:
    """
    Select, load, and wrap a seq2seq translation model.

    Candidate model ids are tried in order. A candidate is accepted when its
    tokenizer exposes ``lang_code_to_id``, a target language can be resolved
    (either the caller's ``default_tgt`` or an auto-picked Turkish code), and
    both the model and a ``translation`` pipeline load without error.

    With ``quantize=True`` the model is loaded with a bitsandbytes 8-bit
    quantization config; with ``quantize=False`` no quantization config is
    passed at all. There is no automatic retry at a different precision: a
    candidate that fails to load is skipped and the next one is tried.
    """

    def __init__(
        self,
        candidates: List[str] = None,
        quantize: bool = True,
        default_tgt: str = None,
    ):
        """
        Args:
            candidates: model ids to try in order. When falsy, defaults to
                ``facebook/nllb-200-distilled-600M`` then
                ``facebook/m2m100_418M``.
            quantize: request 8-bit loading via bitsandbytes.
            default_tgt: target language code used when ``translate`` is
                called without ``tgt_lang``. When falsy, a Turkish code is
                auto-picked from the accepted tokenizer's language codes.

        Raises:
            RuntimeError: if none of the candidates could be loaded.
        """
        self.candidates = candidates or [
            "facebook/nllb-200-distilled-600M",
            "facebook/m2m100_418M",
        ]
        self.quantize = quantize
        self.default_tgt = default_tgt  # if None → auto-pick Turkish
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.lang_codes: List[str] = []
        self._select_and_load()

    @staticmethod
    def _codes_with_prefix(codes: List[str], prefix: str) -> List[str]:
        """Return the codes whose lowercase form starts with `prefix` (order kept)."""
        prefix = prefix.lower()
        return [code for code in codes if code.lower().startswith(prefix)]

    def _select_and_load(self):
        """
        Try each candidate in order and keep the first one that loads.

        Instance state (tokenizer, model, pipeline, language codes, default
        target) is only assigned once a candidate has passed every step, so a
        rejected candidate leaves no partial state behind.

        Raises:
            RuntimeError: if every candidate failed; the last error is chained.
        """
        last_err = None
        for model_name in self.candidates:
            try:
                # 1) Load tokenizer
                logger.info(f"Loading tokenizer for {model_name}")
                tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
                if not hasattr(tok, "lang_code_to_id"):
                    raise AttributeError(
                        f"Tokenizer for {model_name} missing lang_code_to_id"
                    )
                lang_codes = list(tok.lang_code_to_id.keys())
                logger.info(f"Available language codes: {lang_codes[:5]}…")

                # 2) Resolve the target language BEFORE the expensive model
                #    load, so a candidate without a usable Turkish code is
                #    rejected without downloading/loading its weights.
                tgt = self.default_tgt
                auto_picked = not tgt
                if auto_picked:
                    # NOTE(review): this prefix test fits two-letter codes
                    # such as "tr"; three-letter schemes (presumably e.g.
                    # "tur_Latn" for NLLB) would not match — confirm.
                    tur = self._codes_with_prefix(lang_codes, "tr")
                    if not tur:
                        raise ValueError(f"No Turkish code in {model_name}")
                    tgt = tur[0]

                # 3) Load model; only pass a quantization config when 8-bit
                #    loading was actually requested.
                logger.info(
                    f"Loading model {model_name} "
                    f"(8-bit={'on' if self.quantize else 'off'})"
                )
                load_kwargs = {"device_map": "auto"}
                if self.quantize:
                    load_kwargs["quantization_config"] = BitsAndBytesConfig(
                        load_in_8bit=True
                    )
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_name, **load_kwargs
                )
                logger.info(f"Model {model_name} loaded successfully")

                # 4) Build a translation pipeline around it
                pipe = pipeline(
                    "translation",
                    model=model,
                    tokenizer=tok,
                )
            except Exception as e:
                # Best-effort: any failure just moves on to the next candidate.
                logger.warning(f"Failed to load {model_name}: {e}")
                last_err = e
                continue

            # 5) Success: commit everything at once.
            self.tokenizer = tok
            self.model = model
            self.pipeline = pipe
            self.lang_codes = lang_codes
            self.default_tgt = tgt
            if auto_picked:
                logger.info(f"Default target language: {self.default_tgt}")
            return

        raise RuntimeError(
            f"Could not load any model from candidates {self.candidates}: {last_err}"
        ) from last_err

    def _resolve_src_lang(self, sample: str) -> str:
        """
        Map a text sample to one of the tokenizer's language codes.

        Uses langdetect on `sample`, then picks a tokenizer code starting
        with the detected ISO code (an exact match wins). On any failure the
        first code starting with "en" is used, or the first known code if
        there is none.
        """
        try:
            iso = detect(sample).lower()
            candidates = self._codes_with_prefix(self.lang_codes, iso)
            if not candidates:
                # LookupError rather than LangDetectException: the latter's
                # constructor takes (code, message), so raising it with a
                # single string would itself fail with a TypeError.
                raise LookupError(f"No code for ISO '{iso}'")
            # prefer exact match
            exact = [c for c in candidates if c.lower() == iso]
            src = exact[0] if exact else candidates[0]
            logger.info(f"Auto-detected src_lang={src}")
            return src
        except Exception as e:
            # Deliberately broad: detection is best-effort and must not
            # prevent a translation attempt.
            logger.warning(f"langdetect failed ({e}); defaulting to English")
            eng = self._codes_with_prefix(self.lang_codes, "en")
            return eng[0] if eng else self.lang_codes[0]

    def translate(
        self,
        text: Union[str, List[str]],
        src_lang: str = None,
        tgt_lang: str = None,
    ):
        """
        Translate `text` from src_lang → tgt_lang.

        Args:
            text: a single string or a list of strings.
            src_lang: source language code. When falsy it is auto-detected
                via langdetect; for a list, only the FIRST element is
                inspected and its language is applied to the whole batch.
            tgt_lang: target language code. When falsy, ``default_tgt`` is
                used.

        Returns:
            The pipeline output (callers in this file read it as a list of
            dicts with a 'translation_text' key). An empty input list
            returns ``[]`` without calling the pipeline.
        """
        # Nothing to translate (and no sample to detect a language from).
        if isinstance(text, list) and not text:
            return []

        tgt = tgt_lang or self.default_tgt

        if src_lang:
            src = src_lang
        else:
            sample = text[0] if isinstance(text, list) else text
            src = self._resolve_src_lang(sample)

        # Call the pipeline with both src_lang and tgt_lang
        return self.pipeline(text, src_lang=src, tgt_lang=tgt)

    def get_info(self):
        """
        Return metadata for sidebar display.

        Returns:
            dict with keys ``model`` (the model's ``name_or_path``),
            ``quantized`` (the model's ``is_loaded_in_8bit`` flag, False if
            absent), ``device`` ("cuda:<index>" when the model's device has
            an index, otherwise "cpu") and ``default_tgt``.
        """
        quantized = getattr(self.model, "is_loaded_in_8bit", False)
        index = getattr(self.model.device, "index", None)
        device = f"cuda:{index}" if index is not None else "cpu"
        return {
            "model": self.model.name_or_path,
            "quantized": quantized,
            "device": device,
            "default_tgt": self.default_tgt,
        }
161
+
162
+
163
+ # ────────── Evaluation ──────────
164
class TranslationEvaluator:
    """
    Score translations with BLEU, BERTScore (two language settings) and COMET.

    The metric objects are loaded once in ``__init__`` through the
    ``evaluate`` library and reused for every ``evaluate`` call.
    """

    def __init__(self):
        """Load the BLEU, BERTScore and COMET metric objects."""
        self.bleu = evaluate.load("bleu")
        self.bertscore = evaluate.load("bertscore")
        # NOTE(review): confirm that the COMET loader honours `model_id` and
        # that this checkpoint name exists.
        self.comet = evaluate.load("comet", model_id="unbabel/comet-mqm-qe-da")
        logger.info("Loaded BLEU, BERTScore, COMET")

    @staticmethod
    def _mean(values) -> float:
        """Return the arithmetic mean of `values`, or 0.0 when it is empty."""
        return sum(values) / len(values) if values else 0.0

    def evaluate(
        self,
        sources: List[str],
        references: List[str],
        predictions: List[str],
    ):
        """
        Compute all supported metrics for parallel lists of texts.

        Args:
            sources: source-language texts (used by COMET only).
            references: reference translations, one per prediction.
            predictions: system translations.

        Returns:
            dict with the keys ``"BLEU"``, ``"BERTScore"``, ``"BERTurk"`` and
            ``"COMET"``. BERTScore/BERTurk are the mean F1 over all items;
            COMET is the mean of the per-item scores when a list is returned,
            otherwise the scalar score (0.0 when missing).
        """
        results = {}

        # BLEU: each prediction gets a single-element list of references.
        results["BLEU"] = self.bleu.compute(
            predictions=predictions,
            references=[[r] for r in references],
        )["bleu"]

        # BERTScore (general)
        bs = self.bertscore.compute(
            predictions=predictions, references=references, lang="xx"
        )
        results["BERTScore"] = self._mean(bs["f1"])

        # BERTurk (Turkish)
        bs_tr = self.bertscore.compute(
            predictions=predictions, references=references, lang="tr"
        )
        results["BERTurk"] = self._mean(bs_tr["f1"])

        # COMET: the metric's inputs are named sources/predictions/references.
        co = self.comet.compute(
            sources=sources, predictions=predictions, references=references
        )
        # `scores` may be a float or list; average a list so that a batch is
        # summarised the same way as the BERTScore values above.
        score = co.get("scores", None)
        if isinstance(score, list):
            results["COMET"] = self._mean(score)
        else:
            results["COMET"] = score or 0.0

        return results
209
+
210
+
211
+ # ────────── Streamlit App ──────────
212
+
213
@st.cache_resource
def load_resources():
    """
    Build the ModelManager (8-bit requested) and the TranslationEvaluator.

    Wrapped in ``st.cache_resource``; returns the pair ``(manager, evaluator)``.
    """
    return ModelManager(quantize=True), TranslationEvaluator()
221
+
222
 
 
223
def display_model_info(info: dict):
    """
    Render the model metadata in the Streamlit sidebar.

    `info` must provide the keys 'model', 'quantized', 'device' and
    'default_tgt'; a missing key raises KeyError.
    """
    panel = st.sidebar
    panel.markdown("### Model Info")
    panel.write(f"**Model:** {info['model']}")
    panel.write(f"**8-bit Quantized:** {info['quantized']}")
    panel.write(f"**Device:** {info['device']}")
    panel.write(f"**Default target:** {info['default_tgt']}")
229
+
230
+
231
def process_text(
    src: str,
    ref: str,
    mgr: ModelManager,
    ev: TranslationEvaluator,
    metrics: List[str],
):
    """
    Translate a single text and score the result.

    The first pipeline result's 'translation_text' is taken as the
    hypothesis. An empty/None reference is scored as the empty string.

    Returns:
        dict with 'source', 'reference', 'hypothesis', followed by one entry
        per name in `metrics` (KeyError if the evaluator did not produce it).
    """
    translations = mgr.translate(src)  # list of dicts
    hypothesis = translations[0]["translation_text"]
    all_scores = ev.evaluate([src], [ref or ""], [hypothesis])

    result = {
        "source": src,
        "reference": ref,
        "hypothesis": hypothesis,
    }
    for metric in metrics:
        result[metric] = all_scores[metric]
    return result
247
 
248
+
249
  def _show_single_results(res: dict):
250
  left, right = st.columns(2)
251
  with left:
 
258
  st.write(res["reference"])
259
  with right:
260
  st.markdown("### Scores")
261
+ df = pd.DataFrame([{k: v for k, v in res.items() if k in res.keys() and k in ["BLEU","BERTScore","BERTurk","COMET"]}])
262
  st.table(df)
263
 
264
+
265
  def process_file(
266
  uploaded,
267
+ mgr: ModelManager,
268
+ ev: TranslationEvaluator,
269
+ metrics: List[str],
270
+ batch_size: int,
271
  ):
272
  df = pd.read_csv(uploaded)
273
  if not {"src", "ref_tr"}.issubset(df.columns):
 
279
  batch = df.iloc[i : i + batch_size]
280
  srcs = batch["src"].tolist()
281
  refs = batch["ref_tr"].tolist()
282
+ outs = mgr.translate(srcs) # batch translation
 
283
  hyps = [o["translation_text"] for o in outs]
 
284
  for s, r, h in zip(srcs, refs, hyps):
285
+ sc = ev.evaluate([s], [r], [h])
286
  entry = {"src": s, "ref_tr": r, "hyp_tr": h}
287
  entry.update({m: sc[m] for m in metrics})
288
  results.append(entry)
289
  prog.progress(min(i + batch_size, total) / total)
290
  return pd.DataFrame(results)
291
 
292
+
293
def _show_batch_viz(df: pd.DataFrame, metrics: List[str]):
    """Draw one histogram per selected metric, using the column of that name in `df`."""
    for metric in metrics:
        st.markdown(f"#### {metric} Distribution")
        histogram = px.histogram(df, x=metric)
        st.plotly_chart(histogram, use_container_width=True)
298
 
299
+
300
  def main():
301
+ st.set_page_config(
302
+ page_title="πŸ”€ Translationβ†’Turkish Quality", layout="wide"
303
+ )
304
  st.title("πŸ”€ Translation β†’ TR Quality & COMET")
305
+ st.markdown(
306
+ "Translate any language into Turkish and evaluate with BLEU, BERTScore, BERTurk & COMET."
307
+ )
308
 
309
  # Sidebar
310
  with st.sidebar:
 
312
  metrics = st.multiselect(
313
  "Select metrics",
314
  ["BLEU", "BERTScore", "BERTurk", "COMET"],
315
+ default=["BLEU", "BERTScore", "COMET"],
316
  )
317
  batch_size = st.slider("Batch size", 1, 32, 8)
318
+ mgr, ev = load_resources()
319
+ display_model_info(mgr.get_info())
320
 
321
  # Tabs
322
  tab1, tab2 = st.tabs(["Single Sentence", "Batch CSV"])
 
326
  ref = st.text_area("Turkish reference (optional):", height=100)
327
  if st.button("Evaluate"):
328
  with st.spinner("Translating & evaluating…"):
329
+ res = process_text(src, ref, mgr, ev, metrics)
330
  _show_single_results(res)
331
 
332
  with tab2:
333
+ uploaded = st.file_uploader(
334
+ "Upload CSV with `src` & `ref_tr` columns", type=["csv"]
335
+ )
336
  if uploaded:
337
  with st.spinner("Processing file…"):
338
+ df_res = process_file(uploaded, mgr, ev, metrics, batch_size)
339
  st.markdown("### Batch Results")
340
  st.dataframe(df_res, use_container_width=True)
341
  _show_batch_viz(df_res, metrics)
342
+ st.download_button(
343
+ "Download CSV", df_res.to_csv(index=False), "results.csv"
344
+ )
345
+
346
 
347
if __name__ == "__main__":
    # Top-level boundary: show any failure in the UI and log the traceback.
    try:
        main()
    except Exception as exc:
        st.error(f"Unexpected error: {exc}")
        logger.exception("Unhandled exception")