kleervoyans commited on
Commit
c631abc
Β·
verified Β·
1 Parent(s): c4d24a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -72
app.py CHANGED
@@ -1,86 +1,136 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
3
- import evaluate
 
 
 
4
 
5
- # Page configuration
6
- st.set_page_config(
7
- page_title="Turkish Translation Evaluator",
8
- layout="wide",
9
- initial_sidebar_state="expanded"
10
  )
 
11
 
 
12
  @st.cache_resource
13
- def load_translation_pipeline():
14
- model_name = "facebook/m2m100_418M"
15
- tokenizer = AutoTokenizer.from_pretrained(model_name)
16
- model = AutoModelForSeq2SeqLM.from_pretrained(
17
- model_name,
18
- device_map="auto",
19
- load_in_8bit=True,
20
- torch_dtype="auto"
21
  )
22
- translator = pipeline(
23
- "translation",
24
- model=model,
25
- tokenizer=tokenizer,
26
- src_lang="auto",
27
- tgt_lang="tr",
28
- device_map="auto"
29
- )
30
- return translator
31
 
32
- @st.cache_resource
33
- def load_metrics():
 
 
 
 
 
 
 
 
 
 
 
 
34
  return {
35
- "bleu": evaluate.load("bleu"),
36
- "bertscore": evaluate.load("bertscore"),
37
- "bertturk": evaluate.load("bertscore"),
38
- "comet": evaluate.load("comet", module_type="metric")
39
  }
40
 
41
- translator = load_translation_pipeline()
42
- metrics = load_metrics()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- st.title("πŸ”€ Turkish Translation & Evaluation")
 
45
 
46
- with st.form("translate_form"):
47
- input_text = st.text_area("Input text (any language)", height=150)
48
- ref_text = st.text_area("Reference Turkish translation (optional)", height=150)
49
- submit = st.form_submit_button("Translate & Evaluate")
 
 
 
50
 
51
- if submit:
52
- if not input_text.strip():
53
- st.error("Please provide input text to translate.")
54
- else:
55
- with st.spinner("Translating..."):
56
- out = translator(input_text, max_length=512)
57
- translation = out[0]["translation_text"]
58
- st.subheader("Model Translation (Turkish)")
59
- st.markdown(f"> {translation}")
60
 
61
- if ref_text.strip():
62
- preds = [translation]
63
- refs = [[ref_text]]
64
- # BLEU
65
- bleu_score = metrics["bleu"].compute(predictions=preds, references=refs)["bleu"] * 100
66
- st.metric("BLEU-4", f"{bleu_score:.2f}")
67
- # BERTScore (multilingual)
68
- bs = metrics["bertscore"].compute(predictions=preds, references=[ref_text], lang="tr")
69
- st.metric("BERTScore (f1)", f"{bs['f1'][0]*100:.2f}")
70
- # BERTurk (Turkish BERTScore)
71
- bt = metrics["bertturk"].compute(
72
- predictions=preds,
73
- references=[ref_text],
74
- model_type="dbmdz/bert-base-turkish-cased"
75
- )
76
- st.metric("BERTurk (f1)", f"{bt['f1'][0]*100:.2f}")
77
- # COMET
78
- comet_out = metrics["comet"].compute(
79
- model="Unbabel/wmt22-comet-da",
80
- src=[input_text],
81
- mt=preds,
82
- ref=[ref_text]
83
- )
84
- st.metric("COMET", f"{comet_out['score'][0]:.2f}")
85
- else:
86
- st.info("No reference provided; skipping evaluation metrics.")
 
1
  import streamlit as st
2
+ import logging
3
+ import pandas as pd
4
+ import plotly.express as px
5
+ from models.translation_loader import TranslationLoader
6
+ from evaluators.evaluator import TranslationEvaluator
7
 
8
+ # ────────── Logging ──────────
9
+ logging.basicConfig(
10
+ format="%(asctime)s %(levelname)s %(name)s: %(message)s",
11
+ datefmt="%Y-%m-%d %H:%M:%S",
12
+ level=logging.INFO
13
  )
14
+ logger = logging.getLogger(__name__)
15
 
16
+ # ────────── Cached Loader/Evaluator ──────────
17
  @st.cache_resource
18
+ def load_resources():
19
+ loader = TranslationLoader(
20
+ model_name="facebook/nllb-200-distilled-600M",
21
+ quantize=True
 
 
 
 
22
  )
23
+ evaluator = TranslationEvaluator()
24
+ return loader, evaluator
 
 
 
 
 
 
 
25
 
26
+ # ────────── Sidebar Model Info ──────────
27
+ def display_model_info(info):
28
+ st.sidebar.markdown("### Model Info")
29
+ st.sidebar.write(f"**Model:** {info['model_name']}")
30
+ st.sidebar.write(f"**8-bit Quantized:** {info['quantized']}")
31
+ st.sidebar.write(f"**Device:** {info['device']}")
32
+
33
+ # ────────── Single‐text Processing ──────────
34
+ def process_text(src, ref, loader, evaluator, metrics):
35
+ # 1) Translate
36
+ out = loader.translate(src, tgt_lang="tur_Latn")
37
+ hyp = out[0]["translation_text"] if isinstance(out, list) else out["translation_text"]
38
+ # 2) Evaluate
39
+ scores = evaluator.evaluate([src], [ref or ""], [hyp])
40
  return {
41
+ "source": src,
42
+ "reference": ref,
43
+ "hypothesis": hyp,
44
+ **{m: scores[m] for m in metrics}
45
  }
46
 
47
+ def _show_single_results(res):
48
+ left, right = st.columns(2)
49
+ with left:
50
+ st.markdown("**Source:**")
51
+ st.write(res["source"])
52
+ st.markdown("**Hypothesis (TR):**")
53
+ st.write(res["hypothesis"])
54
+ if res["reference"]:
55
+ st.markdown("**Reference (TR):**")
56
+ st.write(res["reference"])
57
+ with right:
58
+ st.markdown("### Scores")
59
+ df = pd.DataFrame({k: [v] for k, v in res.items() if k in ["BLEU","BERTScore","BERTurk","COMET"]})
60
+ st.table(df)
61
+
62
+ # ────────── Batch‐CSV Processing ──────────
63
+ def process_file(uploaded, loader, evaluator, metrics, batch_size):
64
+ df = pd.read_csv(uploaded)
65
+ if not {"src","ref_tr"}.issubset(df.columns):
66
+ raise ValueError("CSV must have `src` and `ref_tr` columns")
67
+ prog = st.progress(0)
68
+ results = []
69
+ total = len(df)
70
+ for i in range(0, total, batch_size):
71
+ batch = df.iloc[i : i + batch_size]
72
+ srcs = batch["src"].tolist()
73
+ refs = batch["ref_tr"].tolist()
74
+ # translate batch
75
+ outs = loader.translate(srcs, tgt_lang="tur_Latn")
76
+ hyps = [o["translation_text"] for o in outs]
77
+ # evaluate each item individually
78
+ for s, r, h in zip(srcs, refs, hyps):
79
+ sc = evaluator.evaluate([s], [r], [h])
80
+ entry = {"src": s, "ref_tr": r, "hyp_tr": h}
81
+ entry.update({m: sc[m] for m in metrics})
82
+ results.append(entry)
83
+ prog.progress(min(i + batch_size, total) / total)
84
+ return pd.DataFrame(results)
85
+
86
+ def _show_batch_viz(df, metrics):
87
+ for m in metrics:
88
+ st.markdown(f"#### {m} Distribution")
89
+ fig = px.histogram(df, x=m)
90
+ st.plotly_chart(fig, use_container_width=True)
91
+
92
+ # ────────── Main ──────────
93
+ def main():
94
+ st.set_page_config(page_title="πŸ”€ Translationβ†’Turkish Quality", layout="wide")
95
+ st.title("πŸ”€ Translation β†’ TR Quality & COMET")
96
+ st.markdown("Enter text or upload a CSV to translate into Turkish and evaluate with BLEU, BERTScore, BERTurk & COMET.")
97
+
98
+ # Sidebar
99
+ with st.sidebar:
100
+ st.header("Settings")
101
+ metrics = st.multiselect(
102
+ "Select metrics",
103
+ ["BLEU", "BERTScore", "BERTurk", "COMET"],
104
+ default=["BLEU","BERTScore","COMET"]
105
+ )
106
+ batch_size = st.slider("Batch size", 1, 32, 8)
107
+ loader, evaluator = load_resources()
108
+ display_model_info(loader.get_info())
109
 
110
+ # Tabs
111
+ tab1, tab2 = st.tabs(["Single Sentence", "Batch CSV"])
112
 
113
+ with tab1:
114
+ src = st.text_area("Source sentence (any language):", height=150)
115
+ ref = st.text_area("Turkish reference (optional):", height=100)
116
+ if st.button("Evaluate"):
117
+ with st.spinner("Translating & evaluating…"):
118
+ res = process_text(src, ref, loader, evaluator, metrics)
119
+ _show_single_results(res)
120
 
121
+ with tab2:
122
+ uploaded = st.file_uploader("Upload CSV with `src` & `ref_tr` columns", type=["csv"])
123
+ if uploaded:
124
+ with st.spinner("Processing file…"):
125
+ df_res = process_file(uploaded, loader, evaluator, metrics, batch_size)
126
+ st.markdown("### Batch Results")
127
+ st.dataframe(df_res, use_container_width=True)
128
+ _show_batch_viz(df_res, metrics)
129
+ st.download_button("Download CSV", df_res.to_csv(index=False), "results.csv")
130
 
131
+ if __name__ == "__main__":
132
+ try:
133
+ main()
134
+ except Exception as e:
135
+ st.error(f"Unexpected error: {e}")
136
+ logger.exception("Unhandled exception")