# app.py
# Streamlit app: translate text with a multilingual seq2seq model (NLLB / M2M100)
# and evaluate the output with BLEU, ChrF, TER, BERTScore and COMET.
import streamlit as st
import streamlit.components.v1 as components
import logging
import torch
import random
import numpy as np
import pandas as pd
import plotly.express as px
import time
import difflib
from typing import List, Optional, Union
from langdetect import detect, LangDetectException
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
    BitsAndBytesConfig,
)
import evaluate
from sacrebleu import corpus_bleu, sentence_bleu

# ────────── Page Config (MUST be first) ──────────
st.set_page_config(page_title="🔤 Translate→Eval+", layout="wide")

# ────────── Global CSS ──────────
# Placeholder for custom <style> rules (currently empty).
st.markdown(""" """, unsafe_allow_html=True)

# ────────── Logging ──────────
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ────────── Language codes ──────────
# langdetect emits ISO 639-1 codes (plus "zh-cn" / "zh-tw"). M2M100 tokenizers key
# their language table by those bare codes ("tr"), while NLLB uses FLORES-200 codes
# ("tur_Latn"), so prefix matching on the ISO code is wrong for NLLB
# (e.g. "es" would match "est_Latn"). Extend this table as needed.
_LANGDETECT_TO_FLORES = {
    "af": "afr_Latn", "ar": "arb_Arab", "bg": "bul_Cyrl", "bn": "ben_Beng",
    "ca": "cat_Latn", "cs": "ces_Latn", "cy": "cym_Latn", "da": "dan_Latn",
    "de": "deu_Latn", "el": "ell_Grek", "en": "eng_Latn", "es": "spa_Latn",
    "et": "est_Latn", "fa": "pes_Arab", "fi": "fin_Latn", "fr": "fra_Latn",
    "gu": "guj_Gujr", "he": "heb_Hebr", "hi": "hin_Deva", "hr": "hrv_Latn",
    "hu": "hun_Latn", "id": "ind_Latn", "it": "ita_Latn", "ja": "jpn_Jpan",
    "kn": "kan_Knda", "ko": "kor_Hang", "lt": "lit_Latn", "lv": "lvs_Latn",
    "mk": "mkd_Cyrl", "ml": "mal_Mlym", "mr": "mar_Deva", "ne": "npi_Deva",
    "nl": "nld_Latn", "no": "nob_Latn", "pa": "pan_Guru", "pl": "pol_Latn",
    "pt": "por_Latn", "ro": "ron_Latn", "ru": "rus_Cyrl", "sk": "slk_Latn",
    "sl": "slv_Latn", "so": "som_Latn", "sq": "als_Latn", "sv": "swe_Latn",
    "sw": "swh_Latn", "ta": "tam_Taml", "te": "tel_Telu", "th": "tha_Thai",
    "tl": "tgl_Latn", "tr": "tur_Latn", "uk": "ukr_Cyrl", "ur": "urd_Arab",
    "vi": "vie_Latn", "zh-cn": "zho_Hans", "zh-tw": "zho_Hant",
}


def _resolve_lang_code(codes: List[str], iso: str) -> Optional[str]:
    """Map an ISO 639-1 code (as returned by langdetect) onto one of ``codes``.

    Candidates are tried in order, case-insensitively: the code itself ("tr",
    "zh-cn"), its base before any "-" ("zh"), then its FLORES-200 equivalent
    ("tur_Latn"). Returns the matching entry of ``codes`` with its original
    casing, or ``None`` when nothing matches.
    """
    by_lower = {c.lower(): c for c in codes}
    iso = iso.lower()
    for key in (iso, iso.split("-")[0], _LANGDETECT_TO_FLORES.get(iso, "").lower()):
        if key and key in by_lower:
            return by_lower[key]
    return None


# ────────── Utilities ──────────
def bootstrap(
    fn,
    predictions: List[str],
    references: List[str],
    sources: Optional[List[str]] = None,
    n_resamples: int = 200,
    seed: int = 42,
) -> List[float]:
    """Score ``n_resamples`` bootstrap resamples (with replacement) of parallel lists.

    Each resample draws ``len(predictions)`` indices and applies them to every
    list, then calls ``fn(ps, rs)``, or ``fn(ps, rs, ss)`` when ``sources`` is
    non-empty.

    A private ``random.Random(seed)`` is used so the module-level RNG is not
    reseeded as a side effect. It draws the same index sequence as seeding the
    global RNG would.

    Returns:
        One score per resample, or an empty list when ``predictions`` is empty.
    """
    n = len(predictions)
    if n == 0:
        return []
    rng = random.Random(seed)
    scores = []
    for _ in range(n_resamples):
        idxs = [rng.randrange(n) for _ in range(n)]
        ps = [predictions[i] for i in idxs]
        rs = [references[i] for i in idxs]
        if sources:
            ss = [sources[i] for i in idxs]
            scores.append(fn(ps, rs, ss))
        else:
            scores.append(fn(ps, rs))
    return scores


def _percentile_interval(samples, lower: float = 2.5, upper: float = 97.5):
    """Return the ``(lower, upper)`` percentiles of ``samples`` as a float pair."""
    return (float(np.percentile(samples, lower)), float(np.percentile(samples, upper)))


def _bootstrap_mean(seg_scores: List[float]) -> List[float]:
    """Bootstrap the mean of precomputed per-segment scores.

    For metrics whose corpus value is the mean of independent per-segment scores
    (BERTScore F1, COMET), this gives the same distribution as re-running the
    metric on every resample, at a fraction of the cost. The scores are passed as
    both ``predictions`` and ``references`` only so ``bootstrap`` resamples them.
    """
    return bootstrap(lambda ps, rs: sum(ps) / len(ps), seg_scores, seg_scores)


# ────────── Model Manager ──────────
class ModelManager:
    """Load the first usable multilingual seq2seq model and translate with it.

    ``candidates`` are tried in order. A candidate is skipped (with a warning)
    when its tokenizer has no ``lang_code_to_id`` table, when the target
    language cannot be resolved in that table, or when loading fails.
    """

    def __init__(self, candidates=None, quantize=True, default_tgt=None):
        if quantize and not torch.cuda.is_available():
            logger.warning("CUDA unavailable; disabling 8-bit quantization")
            quantize = False
        self.quantize = quantize
        self.candidates = candidates or [
            "facebook/nllb-200-distilled-600M",
            "facebook/m2m100_418M",
        ]
        self.default_tgt = default_tgt
        self._load_best()

    def _pick_default_tgt(self, lang_codes: List[str]) -> str:
        """Choose the default target code for a tokenizer's language table.

        An explicit ``default_tgt`` is used verbatim if it is a known code, and is
        otherwise resolved as an ISO code. Without one, Turkish is chosen.

        Raises:
            ValueError: if no matching code exists in ``lang_codes``.
        """
        wanted = self.default_tgt or "tr"
        if wanted in lang_codes:
            return wanted
        code = _resolve_lang_code(lang_codes, wanted)
        if code is None:
            if self.default_tgt:
                raise ValueError(f"Target language {wanted!r} not supported")
            raise ValueError("No Turkish code found")
        return code

    def _load_best(self):
        """Load the first candidate that works; raise RuntimeError if none do."""
        last_err = None
        for name in self.candidates:
            try:
                tok = AutoTokenizer.from_pretrained(name, use_fast=True)
                if not hasattr(tok, "lang_code_to_id"):
                    raise AttributeError("no lang_code_to_id")
                lang_codes = list(tok.lang_code_to_id.keys())
                # Resolve the target before the expensive model load so that an
                # unusable candidate is rejected cheaply.
                default_tgt = self._pick_default_tgt(lang_codes)
                logger.info(f"Loading {name} (8-bit={self.quantize})")
                if self.quantize:
                    bnb = BitsAndBytesConfig(load_in_8bit=True)
                    mdl = AutoModelForSeq2SeqLM.from_pretrained(
                        name, device_map="auto", quantization_config=bnb
                    )
                else:
                    mdl = AutoModelForSeq2SeqLM.from_pretrained(name, device_map="auto")
                pipe = pipeline("translation", model=mdl, tokenizer=tok)
            except Exception as e:  # fallback loop: log and try the next candidate
                logger.warning(f"Failed to load {name}: {e}")
                last_err = e
                continue
            # State is committed only for the candidate that fully succeeded.
            self.model_name = name
            self.tokenizer = tok
            self.model = mdl
            self.pipeline = pipe
            self.lang_codes = lang_codes
            self.default_tgt = default_tgt
            logger.info(f"default_tgt = {self.default_tgt}")
            return
        raise RuntimeError(f"No model loaded: {last_err}") from last_err

    def _detect_src(self, text: Union[str, List[str]]) -> str:
        """Guess the source-language code with langdetect, falling back to English.

        For a list, only the first non-blank string is inspected, so mixed-language
        batches are all translated as that item's language.
        """
        samples = text if isinstance(text, list) else [text]
        sample = next((s for s in samples if isinstance(s, str) and s.strip()), "")
        try:
            iso = detect(sample)
        except LangDetectException:  # raised e.g. for empty / feature-less text
            iso = None
        src = _resolve_lang_code(self.lang_codes, iso) if iso else None
        if src is not None:
            logger.info(f"Detected src_lang={src}")
            return src
        src = _resolve_lang_code(self.lang_codes, "en") or self.lang_codes[0]
        logger.warning(f"Falling back src_lang={src}")
        return src

    def translate(self, text: Union[str, List[str]], src_lang: str = None, tgt_lang: str = None):
        """Translate a string or a list of strings.

        Args:
            text: one text or a batch of texts.
            src_lang: tokenizer language code; auto-detected when omitted.
            tgt_lang: tokenizer language code; ``self.default_tgt`` when omitted.

        Returns:
            The translation pipeline's output (dicts with ``"translation_text"``),
            or ``[]`` for an empty list.
        """
        if isinstance(text, list) and not text:
            return []
        tgt = tgt_lang or self.default_tgt
        src = src_lang or self._detect_src(text)
        return self.pipeline(text, src_lang=src, tgt_lang=tgt)

    def get_info(self):
        """Return a summary dict: model name, quantization, device, target, languages."""
        dev = "cpu"
        if torch.cuda.is_available() and hasattr(self.model, "device"):
            d = self.model.device
            dev = str(d) if isinstance(d, torch.device) else f"cuda:{getattr(d, 'index', '')}"
        return {
            "model": self.model_name,
            "quantized": self.quantize,
            "device": dev,
            "default_tgt": self.default_tgt,
            "langs": self.lang_codes,
        }


# ────────── Evaluator ──────────
class TranslationEvaluator:
    """Compute translation-quality metrics through Hugging Face ``evaluate``."""

    # Metrics that compare against a gold reference. They are skipped when every
    # reference is blank, because scoring against "" is meaningless (and BLEU then
    # divides by a zero reference length).
    REFERENCE_METRICS = frozenset({
        "BLEU_doc", "BLEU_seg", "ChrF", "TER", "BERTScore", "BERTurk", "COMET",
        "CI_BLEU_doc", "CI_BERTScore", "CI_COMET",
    })

    def __init__(self):
        self.bleu = evaluate.load("bleu")
        self.chrf = evaluate.load("chrf")
        self.ter = evaluate.load("ter")
        self.bertscore = evaluate.load("bertscore")
        # evaluate selects the COMET checkpoint through the positional config_name.
        # The default config is the metric's built-in reference-based checkpoint.
        self.comet_ref = evaluate.load("comet")
        # NOTE(review): confirm this reference-free checkpoint id resolves with the
        # installed COMET version.
        self.comet_qe = evaluate.load("comet", "Unbabel/wmt20-comet-qe-da")
        logger.info("Loaded BLEU, ChrF, TER, BERTScore, COMET")

    def compute_metrics(self, srcs, refs, hyps, metrics, ci=True):
        """Score hypotheses against their sources and references.

        Args:
            srcs, refs, hyps: parallel lists of source, reference and hypothesis
                strings.
            metrics: metric names to compute (``REFERENCE_METRICS`` plus "QE").
            ci: when False, "CI_*" entries are skipped.

        Returns:
            A dict mapping metric name to a float, or to a ``(low, high)`` 95%
            percentile-bootstrap interval for "CI_*". Metrics that cannot be
            computed are absent: all of them when ``hyps`` is empty, and the
            reference-based ones when every reference is blank.

            Scales differ: BLEU_doc is evaluate's 0–1 BLEU, while BLEU_seg is
            sacreBLEU's 0–100 sentence BLEU.
        """
        wanted = {m for m in metrics if ci or not m.startswith("CI_")}
        out = {}
        if not hyps:
            return out

        if "QE" in wanted:
            # evaluate's COMET wrapper always requires a `references` column.
            # Reference-free checkpoints are expected to ignore it; verify for
            # other checkpoints.
            q = self.comet_qe.compute(
                sources=srcs, predictions=hyps, references=[""] * len(hyps)
            )
            out["QE"] = float(q["mean_score"])

        if not any(isinstance(r, str) and r.strip() for r in refs):
            return out
        ref_lists = [[r] for r in refs]

        if "BLEU_doc" in wanted:
            out["BLEU_doc"] = float(
                self.bleu.compute(predictions=hyps, references=ref_lists)["bleu"]
            )
        if "BLEU_seg" in wanted:
            # sacreBLEU signature: sentence_bleu(hypothesis, references).
            segs = [sentence_bleu(h, [r]).score for h, r in zip(hyps, refs)]
            out["BLEU_seg"] = float(sum(segs) / len(segs))
        if "ChrF" in wanted:
            out["ChrF"] = float(
                self.chrf.compute(predictions=hyps, references=ref_lists)["score"]
            )
        if "TER" in wanted:
            out["TER"] = float(
                self.ter.compute(predictions=hyps, references=ref_lists, normalized=True)["score"]
            )
        if wanted & {"BERTScore", "CI_BERTScore"}:
            bs = self.bertscore.compute(predictions=hyps, references=refs, lang="xx")["f1"]
            if "BERTScore" in wanted:
                out["BERTScore"] = float(sum(bs) / len(bs)) if bs else 0.0
            if "CI_BERTScore" in wanted and bs:
                out["CI_BERTScore"] = _percentile_interval(_bootstrap_mean(bs))
        if "BERTurk" in wanted:
            bt = self.bertscore.compute(predictions=hyps, references=refs, lang="tr")["f1"]
            out["BERTurk"] = float(sum(bt) / len(bt)) if bt else 0.0
        if wanted & {"COMET", "CI_COMET"}:
            res = self.comet_ref.compute(sources=srcs, predictions=hyps, references=refs)
            if "COMET" in wanted:
                out["COMET"] = float(res["mean_score"])
            if "CI_COMET" in wanted and res["scores"]:
                out["CI_COMET"] = _percentile_interval(_bootstrap_mean(res["scores"]))
        if "CI_BLEU_doc" in wanted:
            # Corpus BLEU is not a mean of segment scores, so resample and rescore.
            samples = bootstrap(
                lambda ps, rs: self.bleu.compute(
                    predictions=ps, references=[[r] for r in rs]
                )["bleu"],
                hyps,
                refs,
            )
            out["CI_BLEU_doc"] = _percentile_interval(samples)
        return out


# ────────── Error Categorizer ──────────
class ErrorCategorizer:
    """Optional text-classification pass that labels translation errors.

    With ``model_name=None`` no pipeline is built and ``categorize`` returns ``[]``.
    """

    def __init__(self, model_name=None):
        self.pipe = (
            pipeline(
                "text-classification",
                model=model_name,
                device=0 if torch.cuda.is_available() else -1,
            )
            if model_name
            else None
        )

    def categorize(self, src, hyp):
        """Return the classifier output for all labels (``top_k=None``), or ``[]``."""
        if not self.pipe:
            return []
        inp = f"SRC: {src}\nHYP: {hyp}\nError types:"
        return self.pipe(inp, top_k=None)


# ────────── Streamlit App ──────────
_BATCH_STATE_KEY = "batch_results"


@st.cache_resource
def load_resources():
    """Build the model manager, evaluator and error categorizer once per process."""
    mgr = ModelManager(quantize=True)
    ev = TranslationEvaluator()
    err = ErrorCategorizer(model_name=None)  # set your HF model here
    return mgr, ev, err


def display_model_info(info: dict):
    """Show the loaded model's summary (from ``ModelManager.get_info``) in the sidebar."""
    st.sidebar.markdown("### Model Info")
    st.sidebar.write(f"• **Model:** {info['model']}")
    st.sidebar.write(f"• **Quantized:** {info['quantized']}")
    st.sidebar.write(f"• **Device:** {info['device']}")
    st.sidebar.write(f"• **Default tgt:** {info['default_tgt']}")


def show_diff(ref, hyp, height: int = 200):
    """Render a word-level, side-by-side diff of reference vs. hypothesis.

    ``make_file`` is used rather than ``make_table`` because only the full
    document includes difflib's stylesheet and legend. Without it the
    add/change/delete classes render unhighlighted.
    """
    differ = difflib.HtmlDiff(tabsize=4, wrapcolumn=60)
    html = differ.make_file(
        ref.split(), hyp.split(),
        fromdesc="Reference", todesc="Hypothesis",
        context=True, numlines=1,
    )
    components.html(html, height=height, scrolling=True)


def _format_score(metric, value):
    """Format one score for the single-text table; "N/A" when it is missing."""
    if value is None:
        return "N/A"
    if metric.startswith("CI_"):
        low, high = value
        return f"{low:.3f} – {high:.3f}"
    return f"{value:.4f}"


def _run_single(mgr, ev, err, tgt, metrics, src, ref):
    """Translate one text, then show its scores, diff and error categories."""
    if not src.strip():
        st.warning("Please enter some source text.")
        return
    with st.spinner("⏳ Translating…"):
        out = mgr.translate(src, tgt_lang=tgt)
    hyp = out[0]["translation_text"]
    st.markdown(f"**Hypothesis ({tgt}):** {hyp}")

    # With a blank reference, compute_metrics omits reference-based metrics,
    # so they show as "N/A".
    scores = ev.compute_metrics([src], [ref or ""], [hyp], metrics)
    st.markdown("### Scores")
    st.table(pd.DataFrame([{m: _format_score(m, scores.get(m)) for m in metrics}]))

    if ref.strip():
        st.markdown("### Diff View")
        show_diff(ref, hyp)

    cats = err.categorize(src, hyp)
    if cats:
        st.markdown("### Error Categories")
        st.json(cats)


def _run_batch(mgr, ev, df, tgt, metrics, batch_size):
    """Translate and score every row of ``df`` (``src``/``ref_tr`` columns).

    Returns one row per input with ``src``, ``ref_tr``, ``hyp_tr`` and one
    column per metric. "CI_*" metrics are formatted as "low–high" strings.
    """
    # Missing cells become "" so .strip() and the tokenizer never see NaN floats.
    all_srcs = df["src"].fillna("").astype(str).tolist()
    all_refs = df["ref_tr"].fillna("").astype(str).tolist()
    n = len(all_srcs)
    rows = []
    prog = st.progress(0)
    for start in range(0, n, batch_size):
        srcs = all_srcs[start:start + batch_size]
        refs = all_refs[start:start + batch_size]
        hyps = [o["translation_text"] for o in mgr.translate(srcs, tgt_lang=tgt)]
        for s, r, h in zip(srcs, refs, hyps):
            row = {"src": s, "ref_tr": r, "hyp_tr": h}
            sc = ev.compute_metrics([s], [r], [h], metrics)
            for m in metrics:
                v = sc.get(m)
                if m.startswith("CI_") and v:
                    low, high = v
                    row[m] = f"{low:.3f}–{high:.3f}"
                else:
                    row[m] = v
            rows.append(row)
        prog.progress(min(start + batch_size, n) / n)
    return pd.DataFrame(rows)


def _render_batch_results(res_df, metrics):
    """Show the batch table, per-metric histograms and a CSV download button."""
    st.markdown("### Results")
    st.dataframe(res_df, use_container_width=True)
    for m in metrics:
        if m.startswith("CI_"):
            continue  # interval strings, not plottable numbers
        st.markdown(f"#### {m} Distribution")
        col = pd.to_numeric(res_df[m], errors="coerce").dropna()
        if col.empty:
            st.write("No valid data.")
        else:
            fig = px.histogram(col.to_frame(name=m), x=m)
            st.plotly_chart(fig, use_container_width=True)
    st.download_button(
        "Download CSV", res_df.to_csv(index=False), "results.csv", mime="text/csv"
    )


def main():
    # Note: set_page_config has been moved to the top!
    st.title("🌐 Translate → Evaluate & Analyze")
    st.write("Translate any language, choose target, eval with advanced metrics, and inspect errors.")

    with st.sidebar:
        st.header("Settings")
        mgr, ev, err = load_resources()
        info = mgr.get_info()
        display_model_info(info)
        tgt = st.selectbox(
            "Target language", info["langs"], index=info["langs"].index(info["default_tgt"])
        )
        metric_opts = [
            "BLEU_doc", "BLEU_seg", "ChrF", "TER", "BERTScore", "BERTurk",
            "COMET", "QE", "CI_BLEU_doc", "CI_BERTScore", "CI_COMET",
        ]
        metrics = st.multiselect(
            "Metrics & CIs", metric_opts, default=["BLEU_doc", "BERTScore", "COMET"]
        )
        batch_size = st.slider("Batch size", 1, 32, 8)

    tab1, tab2 = st.tabs(["Single", "Batch CSV"])

    with tab1:
        src = st.text_area("Source text:", height=120)
        ref = st.text_area("Gold reference (optional):", height=80)
        if st.button("Translate & Eval"):
            _run_single(mgr, ev, err, tgt, metrics, src, ref)

    with tab2:
        uploaded = st.file_uploader("Upload CSV with `src`,`ref_tr`", type=["csv"])
        if not uploaded:
            st.session_state.pop(_BATCH_STATE_KEY, None)
            return
        df = pd.read_csv(uploaded)
        if not {"src", "ref_tr"}.issubset(df.columns):
            st.error("CSV must have `src` and `ref_tr` columns.")
            return
        if df.empty:
            st.warning("CSV contains no rows.")
            return
        # Streamlit reruns the whole script on every interaction (including the
        # download button). Gate the expensive batch behind a button and keep
        # the results in session_state so reruns reuse them.
        if st.button("Run batch"):
            with st.spinner("⏳ Batch processing…"):
                st.session_state[_BATCH_STATE_KEY] = {
                    "file": uploaded.name,
                    "df": _run_batch(mgr, ev, df, tgt, metrics, batch_size),
                    "metrics": list(metrics),
                }
        state = st.session_state.get(_BATCH_STATE_KEY)
        if state and state["file"] == uploaded.name:
            _render_batch_results(state["df"], state["metrics"])


if __name__ == "__main__":
    main()