NetherQuartz commited on
Commit
fb9637d
Β·
verified Β·
1 Parent(s): 5561a52

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +10 -9
src/streamlit_app.py CHANGED
@@ -3,12 +3,17 @@ import streamlit as st
3
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
- MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged"
7
- DEVICE = "mps" if torch.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
8
 
9
 
10
  @st.cache_resource
11
- def get_model():
12
  model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
13
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
14
  return model, tokenizer
@@ -17,17 +22,13 @@ def get_model():
17
  model, tokenizer = get_model()
18
 
19
 
20
- def translate(src_lang, tgt_lang, query):
21
  text = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
22
  tokens = tokenizer(text, return_tensors="pt").to(DEVICE)
23
  outputs = model.generate(**tokens)
24
  ans = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
- return ans.removeprefix(text)
26
 
27
- st.set_page_config(
28
- page_icon="πŸ’¬",
29
- page_title="ilo toki"
30
- )
31
 
32
  st.title("πŸ’¬ ilo toki")
33
 
 
3
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
+ MODEL_PATH = "NetherQuartz/tatoeba-tok-multi-gemma-2-2b-merged-int4"
7
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ st.set_page_config(
10
+ page_icon="πŸ’¬",
11
+ page_title="ilo toki"
12
+ )
13
 
14
 
15
  @st.cache_resource
16
+ def get_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
17
  model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
19
  return model, tokenizer
 
22
  model, tokenizer = get_model()
23
 
24
 
25
+ def translate(src_lang: str, tgt_lang: str, query: str) -> str:
26
  text = f"Translate {src_lang} to {tgt_lang}.\nQuery: {query}\nAnswer:"
27
  tokens = tokenizer(text, return_tensors="pt").to(DEVICE)
28
  outputs = model.generate(**tokens)
29
  ans = tokenizer.decode(outputs[0], skip_special_tokens=True)
30
+ return ans.removeprefix(text).strip()
31
 
 
 
 
 
32
 
33
  st.title("πŸ’¬ ilo toki")
34