Gaoussin committed on
Commit
dc0f0ce
·
verified ·
1 Parent(s): 993c4e8

Updated model and minimized code.

Browse files
Files changed (1) hide show
  1. main.py +1 -32
main.py CHANGED
@@ -18,37 +18,9 @@ if HF_TOKEN is None:
18
  # 3️⃣ DEVICE
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
- # 4️⃣ Load model + tokenizer (PRIVATE REPO)
22
- #model_name = "Gaoussin/bamalingua-bm-fr"
23
- #tokenizer = MBart50TokenizerFast.from_pretrained(model_name, token=HF_TOKEN)
24
- #model = MBartForConditionalGeneration.from_pretrained(model_name, token=HF_TOKEN).to(device)
25
- ####
26
- # 3. Load tokenizer & add Bambara token
27
- # ========================================
28
  model_name = "Gaoussin/bamalingua-bm_ml-fr_XX"
29
- # Load the tokenizer with a default language and suppress the error
30
- try:
31
- tokenizer = MBart50Tokenizer.from_pretrained(model_name, src_lang="fr_XX")
32
- except KeyError:
33
- # If loading with fr_XX fails, try without specifying src_lang and fix afterwards
34
- tokenizer = MBart50Tokenizer.from_pretrained(model_name)
35
-
36
- # Add the new language as an additional special token and update mappings
37
- new_lang = 'bm_ml'
38
- if new_lang not in tokenizer.lang_code_to_id:
39
- tokenizer.add_special_tokens({'additional_special_tokens': [new_lang]})
40
- # Update the internal language code mappings
41
- new_id = len(tokenizer) - 1
42
- tokenizer.lang_code_to_id[new_lang] = new_id
43
- tokenizer.id_to_lang_code[new_id] = new_lang
44
- print(f"Added new language token '{new_lang}' with ID {new_id}")
45
- else:
46
- print(f"Language token '{new_lang}' already exists in tokenizer.")
47
-
48
- # Load model
49
  model = MBartForConditionalGeneration.from_pretrained("Gaoussin/bamalingua-bm_ml-fr_XX")
50
- model.resize_token_embeddings(len(tokenizer))
51
-
52
  #####
53
 
54
 
@@ -71,9 +43,6 @@ class TranslationRequest(BaseModel):
71
  @app.post("/translate")
72
  def translate(request: TranslationRequest):
73
  output = translateTo(request.text, request.src_lang, request.tgt_lang)
74
- # Remove the unwanted token if it's present
75
- if "fr_XX" in output:
76
- output = output.replace("fr_XX", "").strip()
77
  return {"translation": output}
78
 
79
  @app.get("/")
 
18
  # 3️⃣ DEVICE
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
 
 
 
 
 
 
 
21
  model_name = "Gaoussin/bamalingua-bm_ml-fr_XX"
22
+ tokenizer = MBart50Tokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  model = MBartForConditionalGeneration.from_pretrained("Gaoussin/bamalingua-bm_ml-fr_XX")
 
 
24
  #####
25
 
26
 
 
43
  @app.post("/translate")
44
  def translate(request: TranslationRequest):
45
  output = translateTo(request.text, request.src_lang, request.tgt_lang)
 
 
 
46
  return {"translation": output}
47
 
48
  @app.get("/")