Spaces:
Sleeping
Sleeping
| # app.py | |
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| import logging | |
| import torch | |
| import random | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.express as px | |
| import time | |
| import difflib | |
| from typing import List, Union | |
| from langdetect import detect, LangDetectException | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM, | |
| pipeline, | |
| BitsAndBytesConfig, | |
| ) | |
| import evaluate | |
| from sacrebleu import corpus_bleu, sentence_bleu | |
# ────────── Page Config (MUST be first) ──────────
# Kept ahead of every other st.* call, as the section title demands.
st.set_page_config(layout="wide", page_title="π€ TranslateβEval+")

# ────────── Global CSS ──────────
# App-wide styling: narrower main column, blue buttons, rounded inputs/tables.
_GLOBAL_CSS = """
<style>
.main .block-container { max-width: 900px; padding: 1rem 2rem; }
.stButton>button { background-color: #4A90E2; color: white; border-radius: 4px; }
.stButton>button:hover { background-color: #357ABD; }
textarea { border-radius: 4px; }
.stTable table { border-radius: 4px; overflow: hidden; }
</style>
"""
st.markdown(_GLOBAL_CSS, unsafe_allow_html=True)
# ────────── Logging ──────────
# Root logging: timestamped records at INFO and above; `logger` is the
# module-level channel used by the classes below.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(
    level=logging.INFO,
    datefmt=_LOG_DATE_FORMAT,
    format=_LOG_FORMAT,
)
logger = logging.getLogger(__name__)
# ────────── Utilities ──────────
def bootstrap(
    fn, predictions: List[str], references: List[str],
    sources: Union[List[str], None] = None,
    n_resamples: int = 200, seed: int = 42,
) -> List[float]:
    """Return bootstrap-resampled scores of a corpus-level metric.

    Each resample draws ``len(predictions)`` indices with replacement and
    evaluates ``fn`` on the aligned items at those indices.

    Args:
        fn: Callable invoked as ``fn(preds, refs)`` or, when ``sources`` is
            a non-empty list, as ``fn(preds, refs, srcs)``.
        predictions: System outputs.
        references: Gold references aligned with ``predictions``.
        sources: Optional source segments aligned with ``predictions``.
        n_resamples: Number of resamples to draw.
        seed: Seed for the private random generator (results are
            reproducible for a given seed).

    Returns:
        One ``fn`` result per resample, in draw order.

    Raises:
        ValueError: If the input lists do not have the same length.
    """
    n = len(predictions)
    if len(references) != n:
        raise ValueError("predictions and references must have the same length")
    if sources and len(sources) != n:
        raise ValueError("sources must have the same length as predictions")

    # A private generator keeps the draws reproducible without reseeding the
    # process-wide `random` state that unrelated code may rely on.
    rng = random.Random(seed)
    scores = []
    for _ in range(n_resamples):
        idxs = [rng.randrange(n) for _ in range(n)]
        ps = [predictions[i] for i in idxs]
        rs = [references[i] for i in idxs]
        if sources:
            scores.append(fn(ps, rs, [sources[i] for i in idxs]))
        else:
            scores.append(fn(ps, rs))
    return scores
# ────────── Model Manager ──────────
class ModelManager:
    """Load the first usable seq2seq translation model and wrap it in a pipeline.

    Candidates are tried in order; the first one whose tokenizer exposes
    ``lang_code_to_id`` and for which a default target code can be resolved
    is kept. 8-bit quantization is requested only when CUDA is available.
    """

    def __init__(self, candidates=None, quantize=True, default_tgt=None):
        """Pick and load a model.

        Args:
            candidates: Model names to try in order; defaults to an NLLB and
                an M2M100 checkpoint.
            quantize: Request 8-bit loading (silently disabled without CUDA).
            default_tgt: Target language code used when ``translate`` is
                called without ``tgt_lang``. When omitted, the first language
                code starting with ``"tr"`` is used.

        Raises:
            RuntimeError: If no candidate could be loaded.
        """
        if quantize and not torch.cuda.is_available():
            logger.warning("CUDA unavailable; disabling 8-bit quantization")
            quantize = False
        self.quantize = quantize
        self.candidates = candidates or [
            "facebook/nllb-200-distilled-600M",
            "facebook/m2m100_418M",
        ]
        self.default_tgt = default_tgt
        self._load_best()

    def _load_best(self):
        """Try each candidate in turn and keep the first that loads completely."""
        last_err = None
        for name in self.candidates:
            try:
                tok = AutoTokenizer.from_pretrained(name, use_fast=True)
                # NOTE(review): not every tokenizer / transformers version may
                # expose `lang_code_to_id`; confirm against the pinned version.
                if not hasattr(tok, "lang_code_to_id"):
                    raise AttributeError("no lang_code_to_id")
                logger.info("Loading %s (8-bit=%s)", name, self.quantize)
                if self.quantize:
                    bnb = BitsAndBytesConfig(load_in_8bit=True)
                    mdl = AutoModelForSeq2SeqLM.from_pretrained(
                        name, device_map="auto", quantization_config=bnb
                    )
                else:
                    mdl = AutoModelForSeq2SeqLM.from_pretrained(name, device_map="auto")
                pipe = pipeline("translation", model=mdl, tokenizer=tok)
                lang_codes = list(tok.lang_code_to_id.keys())

                default_tgt = self.default_tgt
                if not default_tgt:
                    # NOTE(review): prefix match on "tr" only; codes spelled
                    # differently (e.g. a three-letter form) would not match.
                    tur = [c for c in lang_codes if c.lower().startswith("tr")]
                    if not tur:
                        raise ValueError("No Turkish code found")
                    default_tgt = tur[0]
            except Exception as e:
                # Deliberate best-effort: any failure moves on to the next candidate.
                logger.warning("Failed to load %s: %s", name, e)
                last_err = e
                continue

            # Commit state only once the candidate is fully validated, so a
            # rejected candidate never leaves half-initialised attributes.
            self.model_name = name
            self.tokenizer = tok
            self.model = mdl
            self.pipeline = pipe
            self.lang_codes = lang_codes
            self.default_tgt = default_tgt
            logger.info("default_tgt = %s", self.default_tgt)
            return
        raise RuntimeError(f"No model loaded: {last_err}") from last_err

    def _match_lang_code(self, iso):
        """Map a detected language tag to one of the model's language codes.

        Tries the full lower-cased tag first, then its primary subtag
        (``"zh-cn"`` -> ``"zh"``). An exact code match wins over a prefix
        match. Returns ``None`` when nothing matches.
        """
        for prefix in dict.fromkeys((iso, iso.split("-")[0])):
            if not prefix:
                continue
            cand = [c for c in self.lang_codes if c.lower().startswith(prefix)]
            if cand:
                exact = [c for c in cand if c.lower() == prefix]
                return exact[0] if exact else cand[0]
        return None

    def _detect_src_lang(self, sample):
        """Guess the source language code of ``sample``, falling back to English.

        Falls back to the first code starting with ``"en"`` (or the first
        code overall) when detection fails or yields no matching code.
        """
        src = None
        if isinstance(sample, str):
            try:
                src = self._match_lang_code(detect(sample).lower())
            except LangDetectException:
                src = None
        if src is not None:
            logger.info("Detected src_lang=%s", src)
            return src
        eng = [c for c in self.lang_codes if c.lower().startswith("en")]
        src = eng[0] if eng else self.lang_codes[0]
        logger.warning("Falling back src_lang=%s", src)
        return src

    def translate(self, text: Union[str, List[str]], src_lang: str = None, tgt_lang: str = None):
        """Translate one string or a list of strings.

        Args:
            text: A string or a list of strings.
            src_lang: Source language code; auto-detected from the first
                item when omitted.
            tgt_lang: Target language code; defaults to ``self.default_tgt``.

        Returns:
            Whatever the underlying pipeline returns; an empty list for an
            empty input list.
        """
        tgt = tgt_lang or self.default_tgt
        if isinstance(text, list) and not text:
            # Nothing to translate, and no sample to detect a language from.
            return []
        if src_lang:
            src = src_lang
        else:
            sample = text[0] if isinstance(text, list) else text
            src = self._detect_src_lang(sample)
        return self.pipeline(text, src_lang=src, tgt_lang=tgt)

    def get_info(self):
        """Return a summary dict: model, quantized, device, default_tgt, langs."""
        dev = "cpu"
        if torch.cuda.is_available() and hasattr(self.model, "device"):
            d = self.model.device
            dev = str(d) if isinstance(d, torch.device) else f"cuda:{getattr(d,'index','')}"
        return {
            "model": self.model_name,
            "quantized": self.quantize,
            "device": dev,
            "default_tgt": self.default_tgt,
            "langs": self.lang_codes,
        }
# ────────── Evaluator ──────────
class TranslationEvaluator:
    """Compute translation quality metrics and optional bootstrap intervals."""

    def __init__(self):
        """Load the BLEU, ChrF, TER, BERTScore and COMET metric modules."""
        self.bleu = evaluate.load("bleu")
        self.chrf = evaluate.load("chrf")
        self.ter = evaluate.load("ter")
        self.bertscore = evaluate.load("bertscore")
        # NOTE(review): `evaluate.load` documents the checkpoint as its
        # `config_name` argument; confirm that `model_id=` actually selects
        # these checkpoints and that both ids exist before changing it.
        self.comet_ref = evaluate.load("comet", model_id="unbabel/comet-mqm-qe-da")
        self.comet_qe = evaluate.load("comet", model_id="unbabel/wmt20-comet-qe-da")
        logger.info("Loaded BLEU, ChrF, TER, BERTScore, COMET")

    @staticmethod
    def _mean(values):
        """Arithmetic mean as a float; 0.0 for an empty sequence."""
        values = list(values)
        return float(sum(values) / len(values)) if values else 0.0

    @staticmethod
    def _interval(samples):
        """2.5th / 97.5th percentile pair of the bootstrap samples."""
        return (float(np.percentile(samples, 2.5)), float(np.percentile(samples, 97.5)))

    def _comet(self, module, srcs, hyps, refs):
        """Run a COMET module and return the mean segment score.

        The module is called with `sources` / `predictions` / `references`.
        NOTE(review): references are passed to the QE module as well, on the
        assumption that the wrapper requires the input and QE checkpoints
        ignore it — confirm.
        """
        result = module.compute(sources=srcs, predictions=hyps, references=refs)
        scores = result.get("scores", 0.0)
        if isinstance(scores, (list, tuple)):
            return self._mean(scores)
        return float(scores)

    def compute_metrics(self, srcs, refs, hyps, metrics, ci=True):
        """Compute the requested metrics.

        Args:
            srcs: Source segments.
            refs: Reference translations aligned with ``hyps``.
            hyps: System translations.
            metrics: Names to compute; any of BLEU_doc, BLEU_seg, ChrF, TER,
                BERTScore, BERTurk, COMET, QE, CI_BLEU_doc, CI_BERTScore,
                CI_COMET. Unknown names are ignored.
            ci: When False, the ``CI_*`` entries are skipped.

        Returns:
            Dict mapping metric name to a float, or to a ``(low, high)``
            tuple for ``CI_*`` entries.
        """
        out = {}
        wrapped_refs = [[r] for r in refs]
        if "BLEU_doc" in metrics:
            out["BLEU_doc"] = float(self.bleu.compute(predictions=hyps, references=wrapped_refs)["bleu"])
        if "BLEU_seg" in metrics:
            # sacrebleu's signature is sentence_bleu(hypothesis, references).
            segs = [sentence_bleu(p, [r]).score for p, r in zip(hyps, refs)]
            out["BLEU_seg"] = self._mean(segs)
        if "ChrF" in metrics:
            out["ChrF"] = float(self.chrf.compute(predictions=hyps, references=wrapped_refs)["score"])
        if "TER" in metrics:
            out["TER"] = float(self.ter.compute(predictions=hyps, references=wrapped_refs, normalized=True)["score"])
        if "BERTScore" in metrics:
            bs = self.bertscore.compute(predictions=hyps, references=refs, lang="xx")["f1"]
            out["BERTScore"] = self._mean(bs)
        if "BERTurk" in metrics:
            bt = self.bertscore.compute(predictions=hyps, references=refs, lang="tr")["f1"]
            out["BERTurk"] = self._mean(bt)
        if "COMET" in metrics:
            out["COMET"] = self._comet(self.comet_ref, srcs, hyps, refs)
        if "QE" in metrics:
            out["QE"] = self._comet(self.comet_qe, srcs, hyps, refs)
        if ci:
            if "CI_BLEU_doc" in metrics:
                samples = bootstrap(
                    lambda ps, rs: self.bleu.compute(predictions=ps, references=[[r] for r in rs])["bleu"],
                    hyps, refs,
                )
                out["CI_BLEU_doc"] = self._interval(samples)
            if "CI_BERTScore" in metrics:
                samples = bootstrap(
                    lambda ps, rs: self._mean(self.bertscore.compute(predictions=ps, references=rs, lang="xx")["f1"]),
                    hyps, refs,
                )
                out["CI_BERTScore"] = self._interval(samples)
            if "CI_COMET" in metrics:
                samples = bootstrap(
                    lambda ps, rs, ss: self._comet(self.comet_ref, ss, ps, rs),
                    hyps, refs, srcs,
                )
                out["CI_COMET"] = self._interval(samples)
        return out
# ────────── Error Categorizer ──────────
class ErrorCategorizer:
    """Optional text-classification wrapper that labels translation errors.

    Without a model name the categorizer is inert and reports no categories.
    """

    def __init__(self, model_name=None):
        """Build the classification pipeline, on GPU 0 when CUDA is available."""
        if model_name:
            device = 0 if torch.cuda.is_available() else -1
            self.pipe = pipeline("text-classification", model=model_name, device=device)
        else:
            self.pipe = None

    def categorize(self, src, hyp):
        """Return the classifier output for a source/hypothesis pair, or []."""
        if self.pipe:
            prompt = "SRC: {}\nHYP: {}\nError types:".format(src, hyp)
            return self.pipe(prompt, top_k=None)
        return []
# ────────── Streamlit App ──────────
@st.cache_resource
def load_resources():
    """Build the model manager, evaluator and error categorizer once.

    Cached as a Streamlit resource so the models and metric modules are not
    reloaded on every script rerun (each widget interaction reruns ``main``).

    Returns:
        Tuple ``(ModelManager, TranslationEvaluator, ErrorCategorizer)``.
    """
    mgr = ModelManager(quantize=True)
    ev = TranslationEvaluator()
    err = ErrorCategorizer(model_name=None)  # set your HF model here
    return mgr, ev, err
def display_model_info(info: dict):
    """Write the model summary (model, quantized, device, default target) to the sidebar."""
    sidebar = st.sidebar
    sidebar.markdown("### Model Info")
    fields = (
        ("Model", "model"),
        ("Quantized", "quantized"),
        ("Device", "device"),
        ("Default tgt", "default_tgt"),
    )
    # One bullet line per field, in the fixed order above.
    for label, key in fields:
        sidebar.write(f"β’ **{label}:** {info[key]}")
def show_diff(ref, hyp):
    """Render a word-level HTML diff table of reference vs. hypothesis."""
    ref_tokens = ref.split()
    hyp_tokens = hyp.split()
    table_html = difflib.HtmlDiff(tabsize=4, wrapcolumn=60).make_table(
        ref_tokens,
        hyp_tokens,
        fromdesc="Reference",
        todesc="Hypothesis",
        context=True,
        numlines=1,
    )
    components.html(table_html, scrolling=True, height=200)
# NOTE(review): the emoji / dash characters inside the UI string literals
# below look mis-encoded; they are kept as found — restore from the original
# file if available.

def _format_score(name, value):
    """Format one metric value for the single-sentence score table."""
    if name.startswith("CI_") and value:
        return f"{value[0]:.3f} β {value[1]:.3f}"
    return f"{value:.4f}" if value is not None else "N/A"


def _render_single_tab(mgr, ev, err, tgt, metrics):
    """Single-sentence tab: translate, score, diff and categorize errors."""
    src = st.text_area("Source text:", height=120)
    ref = st.text_area("Gold reference (optional):", height=80)
    if not st.button("Translate & Eval"):
        return
    if not src.strip():
        # Nothing to translate; avoid sending blank text to the model.
        st.warning("Please enter some source text to translate.")
        return

    with st.spinner("β³ Translatingβ¦"):
        out = mgr.translate(src, tgt_lang=tgt)
    hyp = out[0]["translation_text"]
    st.markdown(f"**Hypothesis ({tgt}):** {hyp}")

    has_ref = bool(ref.strip())
    # The reference is optional: without one, only the reference-free QE
    # metric is meaningful, so every other selected metric is shown as N/A
    # instead of being scored against an empty string.
    active = list(metrics) if has_ref else [m for m in metrics if m == "QE"]
    scores = ev.compute_metrics([src], [ref], [hyp], active) if active else {}
    sd = {m: _format_score(m, scores.get(m)) for m in metrics}
    st.markdown("### Scores")
    st.table(pd.DataFrame([sd]))

    if has_ref:
        st.markdown("### Diff View")
        show_diff(ref, hyp)
    cats = err.categorize(src, hyp)
    if cats:
        st.markdown("### Error Categories")
        st.json(cats)


def _render_batch_tab(mgr, ev, tgt, metrics, batch_size):
    """Batch tab: translate and score every row of an uploaded CSV."""
    uploaded = st.file_uploader("Upload CSV with `src`,`ref_tr`", type=["csv"])
    if not uploaded:
        return
    df = pd.read_csv(uploaded)
    if not {"src", "ref_tr"}.issubset(df.columns):
        st.error("CSV must have `src` and `ref_tr` columns.")
        return
    if df.empty:
        st.warning("The uploaded CSV has no rows.")
        return
    # Blank CSV cells load as NaN (a float); normalise both text columns to
    # strings so `.strip()` and the translator always receive text.
    df = df.assign(
        src=df["src"].fillna("").astype(str),
        ref_tr=df["ref_tr"].fillna("").astype(str),
    )

    with st.spinner("β³ Batch processingβ¦"):
        all_rows = []
        prog = st.progress(0)
        total = len(df)
        for start in range(0, total, batch_size):
            batch = df.iloc[start:start + batch_size]
            srcs, refs = batch["src"].tolist(), batch["ref_tr"].tolist()
            outs = mgr.translate(srcs, tgt_lang=tgt)
            hyps = [o["translation_text"] for o in outs]
            for s, r, h in zip(srcs, refs, hyps):
                row = {"src": s, "ref_tr": r, "hyp_tr": h}
                # Rows without a reference get None for every metric.
                sc = ev.compute_metrics([s], [r], [h], metrics) if r.strip() else {}
                for m in metrics:
                    v = sc.get(m)
                    if m.startswith("CI_") and v:
                        low, high = v
                        row[m] = f"{low:.3f}β{high:.3f}"
                    else:
                        row[m] = v
                all_rows.append(row)
            prog.progress(min(start + batch_size, total) / total)

    res_df = pd.DataFrame(all_rows)
    st.markdown("### Results")
    st.dataframe(res_df, use_container_width=True)
    for m in metrics:
        st.markdown(f"#### {m} Distribution")
        # CI columns hold "low–high" strings and coerce to NaN here.
        col = pd.to_numeric(res_df[m], errors="coerce").dropna()
        if col.empty:
            st.write("No valid data.")
        else:
            fig = px.histogram(x=col)
            st.plotly_chart(fig, use_container_width=True)
    st.download_button("Download CSV", res_df.to_csv(index=False), "results.csv")


def main():
    """Streamlit entry point: sidebar settings plus the Single and Batch tabs.

    Page configuration is done at module import time, before this runs.
    """
    st.title("π Translate β Evaluate & Analyze")
    st.write("Translate any language, choose target, eval with advanced metrics, and inspect errors.")

    # NOTE(review): the original indentation was lost; the settings widgets
    # are assumed to all belong to the sidebar block.
    with st.sidebar:
        st.header("Settings")
        mgr, ev, err = load_resources()
        info = mgr.get_info()
        display_model_info(info)
        tgt = st.selectbox("Target language", info["langs"], index=info["langs"].index(info["default_tgt"]))
        metric_opts = ["BLEU_doc","BLEU_seg","ChrF","TER","BERTScore","BERTurk","COMET","QE","CI_BLEU_doc","CI_BERTScore","CI_COMET"]
        metrics = st.multiselect("Metrics & CIs", metric_opts, default=["BLEU_doc","BERTScore","COMET"])
        batch_size = st.slider("Batch size", 1, 32, 8)

    tab1, tab2 = st.tabs(["Single","Batch CSV"])
    with tab1:
        _render_single_tab(mgr, ev, err, tgt, metrics)
    with tab2:
        _render_batch_tab(mgr, ev, tgt, metrics, batch_size)
# Script entry point: run the app only when executed directly.
if __name__ == "__main__":
    main()