kleervoyans commited on
Commit
f4de56c
·
verified ·
1 Parent(s): 4ee9144

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ import evaluate
4
+
5
+ # Page configuration
6
+ st.set_page_config(
7
+ page_title="Translation Evaluator",
8
+ layout="wide",
9
+ initial_sidebar_state="collapsed"
10
+ )
11
+
12
+ # Load model and tokenizer
13
+ @st.cache_resource
14
+ def load_model():
15
+ model_name = "facebook/nllb-200-distilled-600M"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18
+ return tokenizer, model
19
+
20
+ tokenizer, model = load_model()
21
+
22
+ # Load metrics
23
+ bleu = evaluate.load("bleu")
24
+ bertscore = evaluate.load("bertscore")
25
+ comet = evaluate.load("comet", module_type="metric")
26
+ # For BERTurk, use Turkish BERT for BERTScore
27
+ bertturk = evaluate.load("bertscore")
28
+
29
+ # UI
30
+ st.title("Minimalistic Translation & Evaluation")
31
+ input_text = st.text_area("Input text (any language)", height=150)
32
+ ref_text = st.text_area("Reference translation in Turkish (optional)", height=150)
33
+
34
+ if st.button("Translate & Evaluate"):
35
+ if not input_text.strip():
36
+ st.error("Please enter some input text to translate.")
37
+ else:
38
+ # Tokenize and generate
39
+ inputs = tokenizer(input_text, return_tensors="pt")
40
+ outputs = model.generate(**inputs, forced_bos_token_id=tokenizer.get_lang_id("tur_TUR"))
41
+ translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
42
+
43
+ # Display translation
44
+ st.subheader("Model Translation (Turkish):")
45
+ st.write(translation)
46
+
47
+ # Prepare references and predictions
48
+ predictions = [translation]
49
+ references = [[ref_text]] if ref_text.strip() else None
50
+
51
+ # Compute BLEU
52
+ if references:
53
+ bleu_result = bleu.compute(predictions=predictions, references=references)
54
+ st.metric("BLEU-4", f"{bleu_result['bleu']*100:.2f}")
55
+ else:
56
+ st.info("No reference provided: skipping BLEU.")
57
+
58
+ # Compute BERTScore (multilingual)
59
+ bs = bertscore.compute(
60
+ predictions=predictions,
61
+ references=[ref_text] if ref_text.strip() else [translation],
62
+ lang="tr"
63
+ )
64
+ st.metric("BERTScore (f1)", f"{bs['f1'][0]*100:.2f}")
65
+
66
+ # Compute BERTurk specifically
67
+ bt = bertturk.compute(
68
+ predictions=predictions,
69
+ references=[ref_text] if ref_text.strip() else [translation],
70
+ model_type="dbmdz/bert-base-turkish-cased"
71
+ )
72
+ st.metric("BERTurk (f1)", f"{bt['f1'][0]*100:.2f}")
73
+
74
+ # Compute COMET if reference
75
+ if references:
76
+ comet_score = comet.compute(
77
+ model="Unbabel/wmt22-comet-da",
78
+ src=[input_text],
79
+ mt=predictions,
80
+ ref=[ref_text]
81
+ )
82
+ st.metric("COMET score", f"{comet_score['score'][0]:.2f}")
83
+ else:
84
+ st.info("No reference provided: skipping COMET.")