kleervoyans commited on
Commit
8f2c84b
·
verified ·
1 Parent(s): f4de56c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import evaluate
4
 
5
  # Page configuration
@@ -12,9 +12,9 @@ st.set_page_config(
12
  # Load model and tokenizer
13
  @st.cache_resource
14
  def load_model():
15
- model_name = "facebook/nllb-200-distilled-600M"
16
- tokenizer = AutoTokenizer.from_pretrained(model_name)
17
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18
  return tokenizer, model
19
 
20
  tokenizer, model = load_model()
@@ -23,7 +23,6 @@ tokenizer, model = load_model()
23
  bleu = evaluate.load("bleu")
24
  bertscore = evaluate.load("bertscore")
25
  comet = evaluate.load("comet", module_type="metric")
26
- # For BERTurk, use Turkish BERT for BERTScore
27
  bertturk = evaluate.load("bertscore")
28
 
29
  # UI
@@ -37,7 +36,10 @@ if st.button("Translate & Evaluate"):
37
  else:
38
  # Tokenize and generate
39
  inputs = tokenizer(input_text, return_tensors="pt")
40
- outputs = model.generate(**inputs, forced_bos_token_id=tokenizer.get_lang_id("tur_TUR"))
 
 
 
41
  translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
42
 
43
  # Display translation
@@ -55,7 +57,7 @@ if st.button("Translate & Evaluate"):
55
  else:
56
  st.info("No reference provided: skipping BLEU.")
57
 
58
- # Compute BERTScore (multilingual)
59
  bs = bertscore.compute(
60
  predictions=predictions,
61
  references=[ref_text] if ref_text.strip() else [translation],
@@ -71,7 +73,7 @@ if st.button("Translate & Evaluate"):
71
  )
72
  st.metric("BERTurk (f1)", f"{bt['f1'][0]*100:.2f}")
73
 
74
- # Compute COMET if reference
75
  if references:
76
  comet_score = comet.compute(
77
  model="Unbabel/wmt22-comet-da",
 
1
  import streamlit as st
2
+ from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
3
  import evaluate
4
 
5
  # Page configuration
 
12
  # Load model and tokenizer
13
  @st.cache_resource
14
  def load_model():
15
+ model_name = "facebook/m2m100_418M"
16
+ tokenizer = M2M100Tokenizer.from_pretrained(model_name)
17
+ model = M2M100ForConditionalGeneration.from_pretrained(model_name)
18
  return tokenizer, model
19
 
20
  tokenizer, model = load_model()
 
23
  bleu = evaluate.load("bleu")
24
  bertscore = evaluate.load("bertscore")
25
  comet = evaluate.load("comet", module_type="metric")
 
26
  bertturk = evaluate.load("bertscore")
27
 
28
  # UI
 
36
  else:
37
  # Tokenize and generate
38
  inputs = tokenizer(input_text, return_tensors="pt")
39
+ outputs = model.generate(
40
+ **inputs,
41
+ forced_bos_token_id=tokenizer.get_lang_id("tur")
42
+ )
43
  translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
44
 
45
  # Display translation
 
57
  else:
58
  st.info("No reference provided: skipping BLEU.")
59
 
60
+ # Compute BERTScore (general multilingual)
61
  bs = bertscore.compute(
62
  predictions=predictions,
63
  references=[ref_text] if ref_text.strip() else [translation],
 
73
  )
74
  st.metric("BERTurk (f1)", f"{bt['f1'][0]*100:.2f}")
75
 
76
+ # Compute COMET if reference provided
77
  if references:
78
  comet_score = comet.compute(
79
  model="Unbabel/wmt22-comet-da",