# app.py
import streamlit as st
import streamlit.components.v1 as components
import logging
import torch
import random
import numpy as np
import pandas as pd
import plotly.express as px
import difflib
from typing import List, Optional, Union
from langdetect import detect
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
pipeline,
BitsAndBytesConfig,
)
import evaluate
from sacrebleu import sentence_bleu
# ────────── Page Config (MUST be first) ──────────
st.set_page_config(page_title="πŸ”€ Translateβ†’Eval+", layout="wide")
# ────────── Global CSS ──────────
st.markdown("""
<style>
.main .block-container { max-width: 900px; padding: 1rem 2rem; }
.stButton>button { background-color: #4A90E2; color: white; border-radius: 4px; }
.stButton>button:hover { background-color: #357ABD; }
textarea { border-radius: 4px; }
.stTable table { border-radius: 4px; overflow: hidden; }
</style>
""", unsafe_allow_html=True)
# ────────── Logging ──────────
logging.basicConfig(
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ────────── Utilities ──────────
def bootstrap(
    fn, predictions: List[str], references: List[str],
    sources: Optional[List[str]] = None, n_resamples: int = 200, seed: int = 42,
) -> List[float]:
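    """Paired bootstrap resampling: draw N indices with replacement, re-score
    each resample with `fn`, and return the resulting score distribution
    (percentiles of which yield confidence intervals)."""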
random.seed(seed)
scores = []
N = len(predictions)
for _ in range(n_resamples):
idxs = [random.randrange(N) for _ in range(N)]
ps = [predictions[i] for i in idxs]
rs = [references[i] for i in idxs]
if sources:
ss = [sources[i] for i in idxs]
scores.append(fn(ps, rs, ss))
else:
scores.append(fn(ps, rs))
return scores
# ────────── Model Manager ──────────
class ModelManager:
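    """Loads the first model in `candidates` whose tokenizer exposes
    `lang_code_to_id` (multilingual MT models such as NLLB-200 or M2M-100),
    optionally with 8-bit quantization when CUDA is available."""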
def __init__(self, candidates=None, quantize=True, default_tgt=None):
if quantize and not torch.cuda.is_available():
logger.warning("CUDA unavailable; disabling 8-bit quantization")
quantize = False
self.quantize = quantize
self.candidates = candidates or [
"facebook/nllb-200-distilled-600M",
"facebook/m2m100_418M",
]
self.default_tgt = default_tgt
self._load_best()
def _load_best(self):
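        """Try each candidate in order and keep the first that loads cleanly;
        tokenizers without `lang_code_to_id` are skipped in favor of the next
        candidate (newer transformers releases drop this attribute for some
        tokenizers, e.g. NLLB)."""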
last_err = None
for name in self.candidates:
try:
tok = AutoTokenizer.from_pretrained(name, use_fast=True)
if not hasattr(tok, "lang_code_to_id"):
raise AttributeError("no lang_code_to_id")
logger.info(f"Loading {name} (8-bit={self.quantize})")
if self.quantize:
bnb = BitsAndBytesConfig(load_in_8bit=True)
mdl = AutoModelForSeq2SeqLM.from_pretrained(
name, device_map="auto", quantization_config=bnb
)
else:
mdl = AutoModelForSeq2SeqLM.from_pretrained(name, device_map="auto")
pipe = pipeline("translation", model=mdl, tokenizer=tok)
self.model_name = name
self.tokenizer = tok
self.model = mdl
self.pipeline = pipe
self.lang_codes = list(tok.lang_code_to_id.keys())
if not self.default_tgt:
                    # NLLB uses "tur_Latn" (ISO 639-3 + script); M2M-100 uses plain "tr"
                    tur = [c for c in self.lang_codes if c.lower() == "tr" or c.lower().startswith("tur")]
if not tur:
raise ValueError("No Turkish code found")
self.default_tgt = tur[0]
logger.info(f"default_tgt = {self.default_tgt}")
return
except Exception as e:
logger.warning(f"Failed to load {name}: {e}")
last_err = e
raise RuntimeError(f"No model loaded: {last_err}")
def translate(self, text: Union[str, List[str]], src_lang: str=None, tgt_lang: str=None):
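        """Translate a string or list of strings. When `src_lang` is omitted,
        langdetect's ISO 639-1 guess is prefix-matched against the model's
        language codes, falling back to English if nothing matches."""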
tgt = tgt_lang or self.default_tgt
if not src_lang:
sample = text[0] if isinstance(text, list) else text
try:
                iso = detect(sample).lower()
                # prefix matching suits ISO 639-1 codes (M2M-100); NLLB's ISO 639-3
                # codes (e.g. "tur_Latn") may not match and will fall back to English
                cand = [c for c in self.lang_codes if c.lower().startswith(iso)]
                if not cand:
                    raise ValueError(f"no model language code matches '{iso}'")
                exact = [c for c in cand if c.lower() == iso]
                src = exact[0] if exact else cand[0]
logger.info(f"Detected src_lang={src}")
except Exception:
eng = [c for c in self.lang_codes if c.lower().startswith("en")]
src = eng[0] if eng else self.lang_codes[0]
logger.warning(f"Falling back src_lang={src}")
else:
src = src_lang
return self.pipeline(text, src_lang=src, tgt_lang=tgt)
def get_info(self):
        # HF models expose a .device property even when loaded with device_map="auto"
        dev = str(getattr(self.model, "device", "cpu"))
return {
"model": self.model_name,
"quantized": self.quantize,
"device": dev,
"default_tgt": self.default_tgt,
"langs": self.lang_codes,
}
# ────────── Evaluator ──────────
class TranslationEvaluator:
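    """Thin wrapper around HF `evaluate` metrics: BLEU, ChrF, TER, BERTScore
    (multilingual, plus Turkish BERT via lang="tr"), reference-based COMET,
    and reference-free COMET-QE."""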
def __init__(self):
self.bleu = evaluate.load("bleu")
self.chrf = evaluate.load("chrf")
self.ter = evaluate.load("ter")
self.bertscore = evaluate.load("bertscore")
        # COMET checkpoints are selected via evaluate's config_name argument;
        # both are published Unbabel checkpoints on the HF Hub
        self.comet_ref = evaluate.load("comet", "Unbabel/wmt20-comet-da")
        self.comet_qe = evaluate.load("comet", "Unbabel/wmt20-comet-qe-da")
logger.info("Loaded BLEU, ChrF, TER, BERTScore, COMET")
def compute_metrics(self, srcs, refs, hyps, metrics, ci=True):
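        """Score parallel lists of sources, references, and hypotheses for the
        requested `metrics`; `CI_*` entries add 95% bootstrap confidence
        intervals (2.5th to 97.5th percentile) when `ci` is True."""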
out = {}
if "BLEU_doc" in metrics:
out["BLEU_doc"] = float(self.bleu.compute(predictions=hyps, references=[[r] for r in refs])["bleu"])
if "BLEU_seg" in metrics:
            segs = [sentence_bleu(p, [r]).score for p, r in zip(hyps, refs)]
out["BLEU_seg"] = float(sum(segs)/len(segs))
if "ChrF" in metrics:
out["ChrF"] = float(self.chrf.compute(predictions=hyps, references=[[r] for r in refs])["score"])
if "TER" in metrics:
out["TER"] = float(self.ter.compute(predictions=hyps, references=[[r] for r in refs], normalized=True)["score"])
if "BERTScore" in metrics:
bs = self.bertscore.compute(predictions=hyps, references=refs, lang="xx")["f1"]
out["BERTScore"] = float(sum(bs)/len(bs)) if bs else 0.0
if "BERTurk" in metrics:
bt = self.bertscore.compute(predictions=hyps, references=refs, lang="tr")["f1"]
out["BERTurk"] = float(sum(bt)/len(bt)) if bt else 0.0
if "COMET" in metrics:
            sc = self.comet_ref.compute(sources=srcs, predictions=hyps, references=refs)["scores"]
            out["COMET"] = float(sum(sc) / len(sc))
if "QE" in metrics:
            # evaluate's COMET wrapper requires references even for QE
            # checkpoints, which simply ignore them
            q = self.comet_qe.compute(sources=srcs, predictions=hyps, references=refs)["scores"]
            out["QE"] = float(sum(q) / len(q))
if ci:
if "CI_BLEU_doc" in metrics:
bsamp = bootstrap(lambda ps,rs: self.bleu.compute(predictions=ps,references=[[r] for r in rs])["bleu"], hyps, refs)
out["CI_BLEU_doc"] = (float(np.percentile(bsamp,2.5)), float(np.percentile(bsamp,97.5)))
if "CI_BERTScore" in metrics:
bsamp = bootstrap(lambda ps,rs: sum(self.bertscore.compute(predictions=ps,references=rs,lang="xx")["f1"])/len(ps), hyps, refs)
out["CI_BERTScore"] = (float(np.percentile(bsamp,2.5)), float(np.percentile(bsamp,97.5)))
if "CI_COMET" in metrics:
bsamp = bootstrap(lambda ps,rs,ss: float(self.comet_ref.compute(srcs=ss,hyps=ps,refs=rs).get("scores",[0.0])[0]), hyps, refs, srcs)
out["CI_COMET"] = (float(np.percentile(bsamp,2.5)), float(np.percentile(bsamp,97.5)))
return out
# ────────── Error Categorizer ──────────
class ErrorCategorizer:
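    """Optional error typing via a text-classification pipeline; expects a
    classifier fine-tuned on SRC/HYP prompts and stays disabled when
    `model_name` is None."""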
    def __init__(self, model_name=None):
        device = 0 if torch.cuda.is_available() else -1
        self.pipe = (
            pipeline("text-classification", model=model_name, device=device)
            if model_name
            else None
        )
def categorize(self, src, hyp):
if not self.pipe: return []
inp = f"SRC: {src}\nHYP: {hyp}\nError types:"
return self.pipe(inp, top_k=None)
# ────────── Streamlit App ──────────
@st.cache_resource
def load_resources():
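    """Build the heavyweight singletons once per process; `st.cache_resource`
    reuses them across Streamlit reruns."""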
mgr = ModelManager(quantize=True)
ev = TranslationEvaluator()
err = ErrorCategorizer(model_name=None) # set your HF model here
return mgr, ev, err
def display_model_info(info: dict):
st.sidebar.markdown("### Model Info")
st.sidebar.write(f"β€’ **Model:** {info['model']}")
st.sidebar.write(f"β€’ **Quantized:** {info['quantized']}")
st.sidebar.write(f"β€’ **Device:** {info['device']}")
st.sidebar.write(f"β€’ **Default tgt:** {info['default_tgt']}")
def show_diff(ref, hyp):
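    """Render a word-level HTML diff of reference vs. hypothesis with
    difflib.HtmlDiff, embedded via Streamlit components."""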
differ = difflib.HtmlDiff(tabsize=4, wrapcolumn=60)
html = differ.make_table(ref.split(), hyp.split(), fromdesc="Reference", todesc="Hypothesis", context=True, numlines=1)
components.html(html, height=200, scrolling=True)
def main():
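    """Two workflows: single-text translate-and-eval, and batch CSV scoring
    with per-metric histograms and a downloadable results file."""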
st.title("🌐 Translate β†’ Evaluate & Analyze")
st.write("Translate any language, choose target, eval with advanced metrics, and inspect errors.")
with st.sidebar:
st.header("Settings")
mgr, ev, err = load_resources()
info = mgr.get_info()
display_model_info(info)
tgt = st.selectbox("Target language", info["langs"], index=info["langs"].index(info["default_tgt"]))
metric_opts = ["BLEU_doc","BLEU_seg","ChrF","TER","BERTScore","BERTurk","COMET","QE","CI_BLEU_doc","CI_BERTScore","CI_COMET"]
metrics = st.multiselect("Metrics & CIs", metric_opts, default=["BLEU_doc","BERTScore","COMET"])
batch_size = st.slider("Batch size", 1, 32, 8)
tab1, tab2 = st.tabs(["Single","Batch CSV"])
with tab1:
src = st.text_area("Source text:", height=120)
ref = st.text_area("Gold reference (optional):", height=80)
if st.button("Translate & Eval"):
with st.spinner("⏳ Translating…"):
out = mgr.translate(src, tgt_lang=tgt)
hyp = out[0]["translation_text"]
st.markdown(f"**Hypothesis ({tgt}):** {hyp}")
scores = ev.compute_metrics([src],[ref or ""],[hyp], metrics)
sd = {}
for m in metrics:
v = scores.get(m)
if m.startswith("CI_") and v:
sd[m] = f"{v[0]:.3f} – {v[1]:.3f}"
else:
sd[m] = f"{v:.4f}" if v is not None else "N/A"
st.markdown("### Scores")
st.table(pd.DataFrame([sd]))
if ref.strip():
st.markdown("### Diff View")
show_diff(ref, hyp)
cats = err.categorize(src, hyp)
if cats:
st.markdown("### Error Categories")
st.json(cats)
with tab2:
uploaded = st.file_uploader("Upload CSV with `src`,`ref_tr`", type=["csv"])
if uploaded:
df = pd.read_csv(uploaded)
if not {"src","ref_tr"}.issubset(df.columns):
st.error("CSV must have `src` and `ref_tr` columns.")
else:
with st.spinner("⏳ Batch processing…"):
all_rows = []
prog = st.progress(0)
N = len(df)
for i in range(0, N, batch_size):
batch = df.iloc[i:i+batch_size]
srcs, refs = batch["src"].tolist(), batch["ref_tr"].tolist()
outs = mgr.translate(srcs, tgt_lang=tgt)
hyps = [o["translation_text"] for o in outs]
for s,r,h in zip(srcs,refs,hyps):
base = {"src":s,"ref_tr":r,"hyp_tr":h}
                            if isinstance(r, str) and r.strip():  # guard against NaN from empty CSV cells
sc = ev.compute_metrics([s],[r],[h], metrics)
for m in metrics:
if m.startswith("CI_") and sc.get(m):
low, high = sc[m]
base[m] = f"{low:.3f}–{high:.3f}"
else:
base[m] = sc.get(m)
else:
for m in metrics:
base[m] = None
all_rows.append(base)
prog.progress(min(i+batch_size,N)/N)
res_df = pd.DataFrame(all_rows)
st.markdown("### Results")
st.dataframe(res_df, use_container_width=True)
for m in metrics:
st.markdown(f"#### {m} Distribution")
col = pd.to_numeric(res_df[m], errors="coerce").dropna()
if col.empty:
st.write("No valid data.")
else:
                        fig = px.histogram(col, nbins=20)
st.plotly_chart(fig, use_container_width=True)
st.download_button("Download CSV", res_df.to_csv(index=False), "results.csv")
if __name__=="__main__":
main()