Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,10 @@
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import logging
|
| 3 |
import pandas as pd
|
| 4 |
import plotly.express as px
|
|
|
|
| 5 |
from typing import Union, List
|
| 6 |
|
| 7 |
from langdetect import detect, LangDetectException
|
|
@@ -22,36 +25,44 @@ logging.basicConfig(
|
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
|
| 25 |
-
# ββββββββββ Model
|
| 26 |
class ModelManager:
|
| 27 |
"""
|
| 28 |
-
|
| 29 |
-
|
|
|
|
| 30 |
"""
|
| 31 |
-
|
| 32 |
def __init__(
|
| 33 |
self,
|
| 34 |
candidates: List[str] = None,
|
| 35 |
quantize: bool = True,
|
| 36 |
default_tgt: str = None,
|
| 37 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
self.candidates = candidates or [
|
| 39 |
"facebook/nllb-200-distilled-600M",
|
| 40 |
"facebook/m2m100_418M",
|
| 41 |
]
|
| 42 |
-
self.
|
| 43 |
-
|
|
|
|
| 44 |
self.tokenizer = None
|
| 45 |
self.model = None
|
| 46 |
self.pipeline = None
|
| 47 |
self.lang_codes: List[str] = []
|
|
|
|
| 48 |
self._select_and_load()
|
| 49 |
|
| 50 |
def _select_and_load(self):
|
| 51 |
last_err = None
|
| 52 |
for model_name in self.candidates:
|
| 53 |
try:
|
| 54 |
-
#
|
| 55 |
logger.info(f"Loading tokenizer for {model_name}")
|
| 56 |
tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
| 57 |
if not hasattr(tok, "lang_code_to_id"):
|
|
@@ -59,53 +70,55 @@ class ModelManager:
|
|
| 59 |
f"Tokenizer for {model_name} missing lang_code_to_id"
|
| 60 |
)
|
| 61 |
|
| 62 |
-
#
|
| 63 |
logger.info(
|
| 64 |
-
f"Loading model {model_name} "
|
| 65 |
-
f"(8-bit={'on' if self.quantize else 'off'})"
|
| 66 |
-
)
|
| 67 |
-
bnb_cfg = BitsAndBytesConfig(load_in_8bit=self.quantize)
|
| 68 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 69 |
-
model_name,
|
| 70 |
-
device_map="auto",
|
| 71 |
-
quantization_config=bnb_cfg,
|
| 72 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
logger.info(f"Model {model_name} loaded successfully")
|
| 74 |
|
| 75 |
-
#
|
| 76 |
pipe = pipeline(
|
| 77 |
"translation",
|
| 78 |
-
model=
|
| 79 |
tokenizer=tok,
|
| 80 |
)
|
| 81 |
|
| 82 |
-
#
|
|
|
|
| 83 |
self.tokenizer = tok
|
| 84 |
-
self.model =
|
| 85 |
self.pipeline = pipe
|
| 86 |
self.lang_codes = list(tok.lang_code_to_id.keys())
|
| 87 |
-
logger.info(f"Available language codes: {self.lang_codes[:5]}β¦")
|
| 88 |
|
| 89 |
-
#
|
| 90 |
if not self.default_tgt:
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
for code in self.lang_codes
|
| 94 |
-
if code.lower().startswith("tr")
|
| 95 |
]
|
| 96 |
-
if not
|
| 97 |
-
raise ValueError(f"No Turkish code in {model_name}")
|
| 98 |
-
self.default_tgt =
|
| 99 |
logger.info(f"Default target language: {self.default_tgt}")
|
| 100 |
|
| 101 |
return
|
| 102 |
-
|
| 103 |
except Exception as e:
|
| 104 |
logger.warning(f"Failed to load {model_name}: {e}")
|
| 105 |
last_err = e
|
| 106 |
|
| 107 |
raise RuntimeError(
|
| 108 |
-
f"Could not load any model from
|
| 109 |
)
|
| 110 |
|
| 111 |
def translate(
|
|
@@ -116,13 +129,11 @@ class ModelManager:
|
|
| 116 |
):
|
| 117 |
"""
|
| 118 |
Translate `text` from src_lang β tgt_lang.
|
| 119 |
-
|
| 120 |
-
If tgt_lang is None: use default_tgt (Turkish).
|
| 121 |
-
Returns the pipeline output (list of dicts with 'translation_text').
|
| 122 |
"""
|
| 123 |
tgt = tgt_lang or self.default_tgt
|
| 124 |
|
| 125 |
-
# Auto-detect source
|
| 126 |
if not src_lang:
|
| 127 |
sample = text[0] if isinstance(text, list) else text
|
| 128 |
try:
|
|
@@ -132,41 +143,41 @@ class ModelManager:
|
|
| 132 |
]
|
| 133 |
if not candidates:
|
| 134 |
raise LangDetectException(f"No code for ISO '{iso}'")
|
| 135 |
-
# prefer exact match
|
| 136 |
exact = [c for c in candidates if c.lower() == iso]
|
| 137 |
src = exact[0] if exact else candidates[0]
|
| 138 |
logger.info(f"Auto-detected src_lang={src}")
|
| 139 |
except Exception as e:
|
| 140 |
logger.warning(f"langdetect failed ({e}); defaulting to English")
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
| 143 |
else:
|
| 144 |
src = src_lang
|
| 145 |
|
| 146 |
-
# Call the pipeline with both src_lang and tgt_lang
|
| 147 |
return self.pipeline(text, src_lang=src, tgt_lang=tgt)
|
| 148 |
|
| 149 |
def get_info(self):
|
| 150 |
-
"""Return metadata for sidebar display."""
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
return {
|
| 156 |
-
"model":
|
| 157 |
-
"quantized":
|
| 158 |
-
"device":
|
| 159 |
"default_tgt": self.default_tgt,
|
| 160 |
}
|
| 161 |
|
| 162 |
|
| 163 |
-
# ββββββββββ
|
| 164 |
class TranslationEvaluator:
|
| 165 |
def __init__(self):
|
| 166 |
self.bleu = evaluate.load("bleu")
|
| 167 |
self.bertscore = evaluate.load("bertscore")
|
| 168 |
self.comet = evaluate.load("comet", model_id="unbabel/comet-mqm-qe-da")
|
| 169 |
-
|
| 170 |
|
| 171 |
def evaluate(
|
| 172 |
self,
|
|
@@ -175,36 +186,27 @@ class TranslationEvaluator:
|
|
| 175 |
predictions: List[str],
|
| 176 |
):
|
| 177 |
results = {}
|
| 178 |
-
|
| 179 |
# BLEU
|
| 180 |
results["BLEU"] = self.bleu.compute(
|
| 181 |
predictions=predictions,
|
| 182 |
references=[[r] for r in references],
|
| 183 |
)["bleu"]
|
| 184 |
-
|
| 185 |
# BERTScore (general)
|
| 186 |
bs = self.bertscore.compute(
|
| 187 |
predictions=predictions, references=references, lang="xx"
|
| 188 |
)
|
| 189 |
results["BERTScore"] = sum(bs["f1"]) / len(bs["f1"]) if bs["f1"] else 0.0
|
| 190 |
-
|
| 191 |
# BERTurk (Turkish)
|
| 192 |
bs_tr = self.bertscore.compute(
|
| 193 |
predictions=predictions, references=references, lang="tr"
|
| 194 |
)
|
| 195 |
results["BERTurk"] = sum(bs_tr["f1"]) / len(bs_tr["f1"]) if bs_tr["f1"] else 0.0
|
| 196 |
-
|
| 197 |
# COMET
|
| 198 |
-
|
| 199 |
srcs=sources, hyps=predictions, refs=references
|
| 200 |
)
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
if isinstance(score, list):
|
| 204 |
-
results["COMET"] = score[0] if score else 0.0
|
| 205 |
-
else:
|
| 206 |
-
results["COMET"] = score or 0.0
|
| 207 |
-
|
| 208 |
return results
|
| 209 |
|
| 210 |
|
|
@@ -212,9 +214,6 @@ class TranslationEvaluator:
|
|
| 212 |
|
| 213 |
@st.cache_resource
|
| 214 |
def load_resources():
|
| 215 |
-
"""
|
| 216 |
-
Load and cache ModelManager & TranslationEvaluator on first run.
|
| 217 |
-
"""
|
| 218 |
mgr = ModelManager(quantize=True)
|
| 219 |
ev = TranslationEvaluator()
|
| 220 |
return mgr, ev
|
|
@@ -235,7 +234,7 @@ def process_text(
|
|
| 235 |
ev: TranslationEvaluator,
|
| 236 |
metrics: List[str],
|
| 237 |
):
|
| 238 |
-
out = mgr.translate(src)
|
| 239 |
hyp = out[0]["translation_text"]
|
| 240 |
scores = ev.evaluate([src], [ref or ""], [hyp])
|
| 241 |
return {
|
|
@@ -258,7 +257,7 @@ def _show_single_results(res: dict):
|
|
| 258 |
st.write(res["reference"])
|
| 259 |
with right:
|
| 260 |
st.markdown("### Scores")
|
| 261 |
-
df = pd.DataFrame([{k: v for k, v in res.items() if k in
|
| 262 |
st.table(df)
|
| 263 |
|
| 264 |
|
|
@@ -279,7 +278,7 @@ def process_file(
|
|
| 279 |
batch = df.iloc[i : i + batch_size]
|
| 280 |
srcs = batch["src"].tolist()
|
| 281 |
refs = batch["ref_tr"].tolist()
|
| 282 |
-
outs = mgr.translate(srcs)
|
| 283 |
hyps = [o["translation_text"] for o in outs]
|
| 284 |
for s, r, h in zip(srcs, refs, hyps):
|
| 285 |
sc = ev.evaluate([s], [r], [h])
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
|
| 3 |
import streamlit as st
|
| 4 |
import logging
|
| 5 |
import pandas as pd
|
| 6 |
import plotly.express as px
|
| 7 |
+
import torch
|
| 8 |
from typing import Union, List
|
| 9 |
|
| 10 |
from langdetect import detect, LangDetectException
|
|
|
|
| 25 |
logger = logging.getLogger(__name__)
|
| 26 |
|
| 27 |
|
| 28 |
+
# ββββββββββ Model Manager ββββββββββ
|
| 29 |
class ModelManager:
|
| 30 |
"""
|
| 31 |
+
Selects and loads a translation model (NLLB-200 or M2M100),
|
| 32 |
+
using 8-bit quantization only if CUDA is available.
|
| 33 |
+
Auto-detects source language and defaults target to Turkish.
|
| 34 |
"""
|
|
|
|
| 35 |
def __init__(
|
| 36 |
self,
|
| 37 |
candidates: List[str] = None,
|
| 38 |
quantize: bool = True,
|
| 39 |
default_tgt: str = None,
|
| 40 |
):
|
| 41 |
+
# If user requested quantization but CUDA isn't available, disable it
|
| 42 |
+
if quantize and not torch.cuda.is_available():
|
| 43 |
+
logger.warning("CUDA unavailable; disabling 8-bit quantization")
|
| 44 |
+
quantize = False
|
| 45 |
+
self.quantize = quantize
|
| 46 |
+
|
| 47 |
self.candidates = candidates or [
|
| 48 |
"facebook/nllb-200-distilled-600M",
|
| 49 |
"facebook/m2m100_418M",
|
| 50 |
]
|
| 51 |
+
self.default_tgt = default_tgt # will auto-pick if None
|
| 52 |
+
|
| 53 |
+
self.selected_model_name: str = None
|
| 54 |
self.tokenizer = None
|
| 55 |
self.model = None
|
| 56 |
self.pipeline = None
|
| 57 |
self.lang_codes: List[str] = []
|
| 58 |
+
|
| 59 |
self._select_and_load()
|
| 60 |
|
| 61 |
def _select_and_load(self):
|
| 62 |
last_err = None
|
| 63 |
for model_name in self.candidates:
|
| 64 |
try:
|
| 65 |
+
# Load tokenizer
|
| 66 |
logger.info(f"Loading tokenizer for {model_name}")
|
| 67 |
tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
| 68 |
if not hasattr(tok, "lang_code_to_id"):
|
|
|
|
| 70 |
f"Tokenizer for {model_name} missing lang_code_to_id"
|
| 71 |
)
|
| 72 |
|
| 73 |
+
# Load model (with or without 8-bit)
|
| 74 |
logger.info(
|
| 75 |
+
f"Loading model {model_name} (8-bit={self.quantize})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
)
|
| 77 |
+
if self.quantize:
|
| 78 |
+
bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
|
| 79 |
+
mdl = AutoModelForSeq2SeqLM.from_pretrained(
|
| 80 |
+
model_name,
|
| 81 |
+
device_map="auto",
|
| 82 |
+
quantization_config=bnb_cfg,
|
| 83 |
+
)
|
| 84 |
+
else:
|
| 85 |
+
mdl = AutoModelForSeq2SeqLM.from_pretrained(
|
| 86 |
+
model_name,
|
| 87 |
+
device_map="auto",
|
| 88 |
+
)
|
| 89 |
logger.info(f"Model {model_name} loaded successfully")
|
| 90 |
|
| 91 |
+
# Wrap in a translation pipeline
|
| 92 |
pipe = pipeline(
|
| 93 |
"translation",
|
| 94 |
+
model=mdl,
|
| 95 |
tokenizer=tok,
|
| 96 |
)
|
| 97 |
|
| 98 |
+
# Store and break
|
| 99 |
+
self.selected_model_name = model_name
|
| 100 |
self.tokenizer = tok
|
| 101 |
+
self.model = mdl
|
| 102 |
self.pipeline = pipe
|
| 103 |
self.lang_codes = list(tok.lang_code_to_id.keys())
|
|
|
|
| 104 |
|
| 105 |
+
# Auto-pick Turkish target code if none specified
|
| 106 |
if not self.default_tgt:
|
| 107 |
+
tur_codes = [
|
| 108 |
+
c for c in self.lang_codes if c.lower().startswith("tr")
|
|
|
|
|
|
|
| 109 |
]
|
| 110 |
+
if not tur_codes:
|
| 111 |
+
raise ValueError(f"No Turkish code found in {model_name}")
|
| 112 |
+
self.default_tgt = tur_codes[0]
|
| 113 |
logger.info(f"Default target language: {self.default_tgt}")
|
| 114 |
|
| 115 |
return
|
|
|
|
| 116 |
except Exception as e:
|
| 117 |
logger.warning(f"Failed to load {model_name}: {e}")
|
| 118 |
last_err = e
|
| 119 |
|
| 120 |
raise RuntimeError(
|
| 121 |
+
f"Could not load any model from {self.candidates}: {last_err}"
|
| 122 |
)
|
| 123 |
|
| 124 |
def translate(
|
|
|
|
| 129 |
):
|
| 130 |
"""
|
| 131 |
Translate `text` from src_lang β tgt_lang.
|
| 132 |
+
Auto-detects src_lang if not given.
|
|
|
|
|
|
|
| 133 |
"""
|
| 134 |
tgt = tgt_lang or self.default_tgt
|
| 135 |
|
| 136 |
+
# Auto-detect source language if missing
|
| 137 |
if not src_lang:
|
| 138 |
sample = text[0] if isinstance(text, list) else text
|
| 139 |
try:
|
|
|
|
| 143 |
]
|
| 144 |
if not candidates:
|
| 145 |
raise LangDetectException(f"No code for ISO '{iso}'")
|
|
|
|
| 146 |
exact = [c for c in candidates if c.lower() == iso]
|
| 147 |
src = exact[0] if exact else candidates[0]
|
| 148 |
logger.info(f"Auto-detected src_lang={src}")
|
| 149 |
except Exception as e:
|
| 150 |
logger.warning(f"langdetect failed ({e}); defaulting to English")
|
| 151 |
+
eng_codes = [
|
| 152 |
+
c for c in self.lang_codes if c.lower().startswith("en")
|
| 153 |
+
]
|
| 154 |
+
src = eng_codes[0] if eng_codes else self.lang_codes[0]
|
| 155 |
else:
|
| 156 |
src = src_lang
|
| 157 |
|
|
|
|
| 158 |
return self.pipeline(text, src_lang=src, tgt_lang=tgt)
|
| 159 |
|
| 160 |
def get_info(self):
|
| 161 |
+
"""Return metadata for the sidebar display."""
|
| 162 |
+
device = "cpu"
|
| 163 |
+
if torch.cuda.is_available() and hasattr(self.model, "device"):
|
| 164 |
+
idx = self.model.device.index if hasattr(self.model.device, "index") else None
|
| 165 |
+
device = f"cuda:{idx}" if idx is not None else "cuda"
|
| 166 |
return {
|
| 167 |
+
"model": self.selected_model_name,
|
| 168 |
+
"quantized": self.quantize,
|
| 169 |
+
"device": device,
|
| 170 |
"default_tgt": self.default_tgt,
|
| 171 |
}
|
| 172 |
|
| 173 |
|
| 174 |
+
# ββββββββββ Evaluator ββββββββββ
|
| 175 |
class TranslationEvaluator:
|
| 176 |
def __init__(self):
|
| 177 |
self.bleu = evaluate.load("bleu")
|
| 178 |
self.bertscore = evaluate.load("bertscore")
|
| 179 |
self.comet = evaluate.load("comet", model_id="unbabel/comet-mqm-qe-da")
|
| 180 |
+
logger.info("Loaded BLEU, BERTScore, COMET metrics")
|
| 181 |
|
| 182 |
def evaluate(
|
| 183 |
self,
|
|
|
|
| 186 |
predictions: List[str],
|
| 187 |
):
|
| 188 |
results = {}
|
|
|
|
| 189 |
# BLEU
|
| 190 |
results["BLEU"] = self.bleu.compute(
|
| 191 |
predictions=predictions,
|
| 192 |
references=[[r] for r in references],
|
| 193 |
)["bleu"]
|
|
|
|
| 194 |
# BERTScore (general)
|
| 195 |
bs = self.bertscore.compute(
|
| 196 |
predictions=predictions, references=references, lang="xx"
|
| 197 |
)
|
| 198 |
results["BERTScore"] = sum(bs["f1"]) / len(bs["f1"]) if bs["f1"] else 0.0
|
|
|
|
| 199 |
# BERTurk (Turkish)
|
| 200 |
bs_tr = self.bertscore.compute(
|
| 201 |
predictions=predictions, references=references, lang="tr"
|
| 202 |
)
|
| 203 |
results["BERTurk"] = sum(bs_tr["f1"]) / len(bs_tr["f1"]) if bs_tr["f1"] else 0.0
|
|
|
|
| 204 |
# COMET
|
| 205 |
+
cm = self.comet.compute(
|
| 206 |
srcs=sources, hyps=predictions, refs=references
|
| 207 |
)
|
| 208 |
+
scores = cm.get("scores", None)
|
| 209 |
+
results["COMET"] = float(scores[0] if isinstance(scores, list) else scores) or 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
return results
|
| 211 |
|
| 212 |
|
|
|
|
| 214 |
|
| 215 |
@st.cache_resource
|
| 216 |
def load_resources():
|
|
|
|
|
|
|
|
|
|
| 217 |
mgr = ModelManager(quantize=True)
|
| 218 |
ev = TranslationEvaluator()
|
| 219 |
return mgr, ev
|
|
|
|
| 234 |
ev: TranslationEvaluator,
|
| 235 |
metrics: List[str],
|
| 236 |
):
|
| 237 |
+
out = mgr.translate(src)
|
| 238 |
hyp = out[0]["translation_text"]
|
| 239 |
scores = ev.evaluate([src], [ref or ""], [hyp])
|
| 240 |
return {
|
|
|
|
| 257 |
st.write(res["reference"])
|
| 258 |
with right:
|
| 259 |
st.markdown("### Scores")
|
| 260 |
+
df = pd.DataFrame([{k: v for k, v in res.items() if k in metrics}])
|
| 261 |
st.table(df)
|
| 262 |
|
| 263 |
|
|
|
|
| 278 |
batch = df.iloc[i : i + batch_size]
|
| 279 |
srcs = batch["src"].tolist()
|
| 280 |
refs = batch["ref_tr"].tolist()
|
| 281 |
+
outs = mgr.translate(srcs)
|
| 282 |
hyps = [o["translation_text"] for o in outs]
|
| 283 |
for s, r, h in zip(srcs, refs, hyps):
|
| 284 |
sc = ev.evaluate([s], [r], [h])
|