kleervoyans committed on
Commit
768e15d
Β·
verified Β·
1 Parent(s): 7fc686c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -31
app.py CHANGED
@@ -1,8 +1,11 @@
 
 
1
  import streamlit as st
2
  import logging
3
  import pandas as pd
4
  import plotly.express as px
5
- from models.translation_loader import TranslationLoader
 
6
  from evaluators.evaluator import TranslationEvaluator
7
 
8
  # ────────── Logging ──────────
@@ -13,38 +16,39 @@ logging.basicConfig(
13
  )
14
  logger = logging.getLogger(__name__)
15
 
16
- # ────────── Cached Loader/Evaluator ──────────
17
  @st.cache_resource
18
  def load_resources():
19
- loader = TranslationLoader(
20
- model_name="facebook/nllb-200-distilled-600M",
21
- quantize=True
22
- )
23
  evaluator = TranslationEvaluator()
24
- return loader, evaluator
25
 
26
  # ────────── Sidebar Model Info ──────────
27
- def display_model_info(info):
28
  st.sidebar.markdown("### Model Info")
29
- st.sidebar.write(f"**Model:** {info['model_name']}")
30
- st.sidebar.write(f"**8-bit Quantized:** {info['quantized']}")
31
- st.sidebar.write(f"**Device:** {info['device']}")
 
32
 
33
  # ────────── Single‐text Processing ──────────
34
- def process_text(src, ref, loader, evaluator, metrics):
35
- # 1) Translate
36
- out = loader.translate(src, tgt_lang="tur_Latn")
37
  hyp = out[0]["translation_text"] if isinstance(out, list) else out["translation_text"]
38
  # 2) Evaluate
39
  scores = evaluator.evaluate([src], [ref or ""], [hyp])
40
  return {
41
- "source": src,
42
- "reference": ref,
43
  "hypothesis": hyp,
44
  **{m: scores[m] for m in metrics}
45
  }
46
 
47
- def _show_single_results(res):
48
  left, right = st.columns(2)
49
  with left:
50
  st.markdown("**Source:**")
@@ -56,13 +60,19 @@ def _show_single_results(res):
56
  st.write(res["reference"])
57
  with right:
58
  st.markdown("### Scores")
59
- df = pd.DataFrame({k: [v] for k, v in res.items() if k in ["BLEU","BERTScore","BERTurk","COMET"]})
60
  st.table(df)
61
 
62
  # ────────── Batch‐CSV Processing ──────────
63
- def process_file(uploaded, loader, evaluator, metrics, batch_size):
 
 
 
 
 
 
64
  df = pd.read_csv(uploaded)
65
- if not {"src","ref_tr"}.issubset(df.columns):
66
  raise ValueError("CSV must have `src` and `ref_tr` columns")
67
  prog = st.progress(0)
68
  results = []
@@ -72,9 +82,9 @@ def process_file(uploaded, loader, evaluator, metrics, batch_size):
72
  srcs = batch["src"].tolist()
73
  refs = batch["ref_tr"].tolist()
74
  # translate batch
75
- outs = loader.translate(srcs, tgt_lang="tur_Latn")
76
  hyps = [o["translation_text"] for o in outs]
77
- # evaluate each item individually
78
  for s, r, h in zip(srcs, refs, hyps):
79
  sc = evaluator.evaluate([s], [r], [h])
80
  entry = {"src": s, "ref_tr": r, "hyp_tr": h}
@@ -83,7 +93,7 @@ def process_file(uploaded, loader, evaluator, metrics, batch_size):
83
  prog.progress(min(i + batch_size, total) / total)
84
  return pd.DataFrame(results)
85
 
86
- def _show_batch_viz(df, metrics):
87
  for m in metrics:
88
  st.markdown(f"#### {m} Distribution")
89
  fig = px.histogram(df, x=m)
@@ -93,7 +103,7 @@ def _show_batch_viz(df, metrics):
93
  def main():
94
  st.set_page_config(page_title="πŸ”€ Translationβ†’Turkish Quality", layout="wide")
95
  st.title("πŸ”€ Translation β†’ TR Quality & COMET")
96
- st.markdown("Enter text or upload a CSV to translate into Turkish and evaluate with BLEU, BERTScore, BERTurk & COMET.")
97
 
98
  # Sidebar
99
  with st.sidebar:
@@ -101,11 +111,11 @@ def main():
101
  metrics = st.multiselect(
102
  "Select metrics",
103
  ["BLEU", "BERTScore", "BERTurk", "COMET"],
104
- default=["BLEU","BERTScore","COMET"]
105
  )
106
  batch_size = st.slider("Batch size", 1, 32, 8)
107
- loader, evaluator = load_resources()
108
- display_model_info(loader.get_info())
109
 
110
  # Tabs
111
  tab1, tab2 = st.tabs(["Single Sentence", "Batch CSV"])
@@ -115,22 +125,22 @@ def main():
115
  ref = st.text_area("Turkish reference (optional):", height=100)
116
  if st.button("Evaluate"):
117
  with st.spinner("Translating & evaluating…"):
118
- res = process_text(src, ref, loader, evaluator, metrics)
119
  _show_single_results(res)
120
 
121
  with tab2:
122
  uploaded = st.file_uploader("Upload CSV with `src` & `ref_tr` columns", type=["csv"])
123
  if uploaded:
124
  with st.spinner("Processing file…"):
125
- df_res = process_file(uploaded, loader, evaluator, metrics, batch_size)
126
  st.markdown("### Batch Results")
127
  st.dataframe(df_res, use_container_width=True)
128
  _show_batch_viz(df_res, metrics)
129
- st.download_button("Download CSV", df_res.to_csv(index=False), "results.csv")
130
 
131
  if __name__ == "__main__":
132
  try:
133
  main()
134
  except Exception as e:
135
  st.error(f"Unexpected error: {e}")
136
- logger.exception("Unhandled exception")
 
1
+ # app.py
2
+
3
  import streamlit as st
4
  import logging
5
  import pandas as pd
6
  import plotly.express as px
7
+
8
+ from models.model_manager import ModelManager
9
  from evaluators.evaluator import TranslationEvaluator
10
 
11
  # ────────── Logging ──────────
 
16
  )
17
  logger = logging.getLogger(__name__)
18
 
19
# ────────── Cached Resources ──────────
@st.cache_resource
def load_resources():
    """Construct the heavyweight objects exactly once per session.

    Returns:
        tuple: ``(ModelManager, TranslationEvaluator)`` — the 8-bit-quantized
        translation manager and the metric evaluator, memoized by Streamlit's
        resource cache so reruns reuse the same instances.
    """
    # Streamlit calls this on first use only; subsequent reruns hit the cache.
    return ModelManager(quantize=True), TranslationEvaluator()
28
 
29
# ────────── Sidebar Model Info ──────────
def display_model_info(info: dict):
    """Show translation-model metadata in the sidebar.

    Args:
        info: mapping produced by ``ModelManager.get_info()``; expected keys
              are ``model``, ``quantized``, ``device`` and ``default_tgt``
              (missing keys render as ``None`` via ``dict.get``).
    """
    st.sidebar.markdown("### Model Info")
    # Label/key pairs rendered in a fixed order.
    for label, key in (
        ("Model", "model"),
        ("8-bit Quantized", "quantized"),
        ("Device", "device"),
        ("Default target", "default_tgt"),
    ):
        st.sidebar.write(f"**{label}:** {info.get(key)}")
36
 
37
# ────────── Single‐text Processing ──────────
def process_text(src: str, ref: str, manager: ModelManager, evaluator: TranslationEvaluator, metrics: list):
    """Translate a single sentence into Turkish and score it.

    Args:
        src: source-language text (language auto-detected by the manager).
        ref: optional Turkish reference; an empty/None reference is scored
             against "".
        manager: cached translation backend.
        evaluator: cached metric backend.
        metrics: metric names to copy from the evaluator's score dict.

    Returns:
        dict with ``source``/``reference``/``hypothesis`` plus one entry per
        selected metric.
    """
    # Translate (auto-detect source, default target Turkish). The backend may
    # return either a bare dict or a one-element list of dicts.
    raw = manager.translate(src)
    if isinstance(raw, list):
        hyp = raw[0]["translation_text"]
    else:
        hyp = raw["translation_text"]
    # Evaluate the hypothesis against the (possibly empty) reference.
    scores = evaluator.evaluate([src], [ref or ""], [hyp])
    result = {"source": src, "reference": ref, "hypothesis": hyp}
    result.update({m: scores[m] for m in metrics})
    return result
50
 
51
+ def _show_single_results(res: dict):
52
  left, right = st.columns(2)
53
  with left:
54
  st.markdown("**Source:**")
 
60
  st.write(res["reference"])
61
  with right:
62
  st.markdown("### Scores")
63
+ df = pd.DataFrame([{k: v for k, v in res.items() if k in ["BLEU","BERTScore","BERTurk","COMET"]}])
64
  st.table(df)
65
 
66
  # ────────── Batch‐CSV Processing ──────────
67
+ def process_file(
68
+ uploaded,
69
+ manager: ModelManager,
70
+ evaluator: TranslationEvaluator,
71
+ metrics: list,
72
+ batch_size: int
73
+ ):
74
  df = pd.read_csv(uploaded)
75
+ if not {"src", "ref_tr"}.issubset(df.columns):
76
  raise ValueError("CSV must have `src` and `ref_tr` columns")
77
  prog = st.progress(0)
78
  results = []
 
82
  srcs = batch["src"].tolist()
83
  refs = batch["ref_tr"].tolist()
84
  # translate batch
85
+ outs = manager.translate(srcs) # list of dicts
86
  hyps = [o["translation_text"] for o in outs]
87
+ # evaluate each row
88
  for s, r, h in zip(srcs, refs, hyps):
89
  sc = evaluator.evaluate([s], [r], [h])
90
  entry = {"src": s, "ref_tr": r, "hyp_tr": h}
 
93
  prog.progress(min(i + batch_size, total) / total)
94
  return pd.DataFrame(results)
95
 
96
+ def _show_batch_viz(df: pd.DataFrame, metrics: list):
97
  for m in metrics:
98
  st.markdown(f"#### {m} Distribution")
99
  fig = px.histogram(df, x=m)
 
103
  def main():
104
  st.set_page_config(page_title="πŸ”€ Translationβ†’Turkish Quality", layout="wide")
105
  st.title("πŸ”€ Translation β†’ TR Quality & COMET")
106
+ st.markdown("Translate any language into Turkish and evaluate with BLEU, BERTScore, BERTurk & COMET.")
107
 
108
  # Sidebar
109
  with st.sidebar:
 
111
  metrics = st.multiselect(
112
  "Select metrics",
113
  ["BLEU", "BERTScore", "BERTurk", "COMET"],
114
+ default=["BLEU", "BERTScore", "COMET"]
115
  )
116
  batch_size = st.slider("Batch size", 1, 32, 8)
117
+ manager, evaluator = load_resources()
118
+ display_model_info(manager.get_info())
119
 
120
  # Tabs
121
  tab1, tab2 = st.tabs(["Single Sentence", "Batch CSV"])
 
125
  ref = st.text_area("Turkish reference (optional):", height=100)
126
  if st.button("Evaluate"):
127
  with st.spinner("Translating & evaluating…"):
128
+ res = process_text(src, ref, manager, evaluator, metrics)
129
  _show_single_results(res)
130
 
131
  with tab2:
132
  uploaded = st.file_uploader("Upload CSV with `src` & `ref_tr` columns", type=["csv"])
133
  if uploaded:
134
  with st.spinner("Processing file…"):
135
+ df_res = process_file(uploaded, manager, evaluator, metrics, batch_size)
136
  st.markdown("### Batch Results")
137
  st.dataframe(df_res, use_container_width=True)
138
  _show_batch_viz(df_res, metrics)
139
+ st.download_button("Download results as CSV", df_res.to_csv(index=False), "results.csv")
140
 
141
# Script entry point: run the app and treat this as the top-level error
# boundary — surface the failure in the UI and keep the full traceback in
# the log.
if __name__ == "__main__":
    try:
        main()
    except Exception as exc:
        st.error(f"Unexpected error: {exc}")
        logger.exception("Unhandled exception in main()")