junaid17 commited on
Commit
45ca0d4
·
verified ·
1 Parent(s): 626ae0d

Upload backend.py

Browse files
Files changed (1) hide show
  1. backend.py +53 -0
backend.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ from peft import PeftModel
4
+
5
+ BASE = "facebook/nllb-200-distilled-600M"
6
+ LORA = "junaid17/nllb-kurdish-lora"
7
+
8
+ tokenizer = AutoTokenizer.from_pretrained(BASE)
9
+
10
+
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
+ _model = None
14
+
15
+ def load_model():
16
+ global _model
17
+ if _model is None:
18
+ try:
19
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE)
20
+ _model = PeftModel.from_pretrained(base_model, LORA).eval()
21
+ print("Model loaded succesfully...")
22
+ except Exception as e:
23
+ print(f"Error while loading the model : {str(e)}")
24
+ return _model.to(device)
25
+
26
+ #model = load_model()
27
+
28
+ def translate(src_lang, tgt_lang, model, text):
29
+ try:
30
+ encoded = tokenizer(
31
+ text,
32
+ return_tensors="pt",
33
+ padding=True,
34
+ truncation=True
35
+ ).to(device)
36
+
37
+ forced_bos = tokenizer.convert_tokens_to_ids(tgt_lang)
38
+
39
+ output_tokens = model.generate(
40
+ **encoded,
41
+ forced_bos_token_id=forced_bos,
42
+ max_length=256,
43
+ num_beams=4
44
+ )
45
+
46
+ return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
47
+ except Exception as e:
48
+ print(f"Could't translate due to unexpected error : {str(e)}")
49
+
50
+
51
+
52
+ #text = "hello, my name is junaid"
53
+ #print(translate(src_lang='eng_Latn', tgt_lang='ckb_Arab', model=model, text=text))