Gaoussin commited on
Commit
8d78e99
·
verified ·
1 Parent(s): bd57210

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +53 -63
main.py CHANGED
@@ -1,88 +1,78 @@
1
  import os
2
- # 2️⃣ Optional: force cache to writable directory
 
 
 
 
 
3
  os.environ["HF_HOME"] = "/tmp/hf"
4
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
5
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
6
  os.makedirs("/tmp/hf", exist_ok=True)
7
 
 
 
 
 
8
 
9
- from fastapi import FastAPI
10
- from pydantic import BaseModel
11
- #from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
12
 
13
- #from replacer import replace_words, replace_dict
14
- #from datasets import Dataset
15
- from transformers import MBartForConditionalGeneration, MBart50Tokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # -------------------------
18
- # 1️⃣ Get your HF token from Space Secrets
19
- # In your Space, go to Settings → Secrets → add HF_TOKEN
20
- #HF_TOKEN = os.environ.get("HF_TOKEN")
21
- #if HF_TOKEN is None:
22
- # raise ValueError("HF_TOKEN not found. Please add it in your Space Secrets.")
 
 
 
 
 
23
 
24
- # -------------------------
 
 
25
 
26
- # -------------------------
27
- # 3️⃣ Load private model
28
- model_name = "Gaoussin/bamalingua-bm_ml-fr_XX"
29
- model = MBartForConditionalGeneration.from_pretrained(model_name)
30
- tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50")
31
  #####
32
- def fix_tokenizer(tokenizer, new_lang='bm_ml'):
33
- """
34
- Add a new language token to the tokenizer vocabulary
35
- (this should be done each time after its initialization)
36
- """
37
- # Check if the language token already exists
38
- if new_lang not in tokenizer.lang_code_to_id:
39
- # Add the new language as an additional special token
40
- tokenizer.add_special_tokens({'additional_special_tokens': [new_lang]})
41
- # Update the internal language code mappings
42
- # Note: This is a workaround as MBart50Tokenizer doesn't have a direct way to add lang codes
43
- # The new token will be added at the end of the vocabulary
44
- new_id = len(tokenizer) - 1
45
- tokenizer.lang_code_to_id[new_lang] = new_id
46
- tokenizer.id_to_lang_code[new_id] = new_lang
47
- print(f"Added new language token '{new_lang}' with ID {new_id}")
48
- else:
49
- print(f"Language token '{new_lang}' already exists in tokenizer.")
50
-
51
 
52
- fix_tokenizer(tokenizer, new_lang='bm_ml')
53
 
54
- model.resize_token_embeddings(len(tokenizer))
55
- print("model resized")
56
- ######
57
- tgt_lang = "bm_ml"
 
 
 
58
 
59
- # -------------------------
60
- # 4️⃣ FastAPI app
61
  app = FastAPI()
62
 
63
  class TranslationRequest(BaseModel):
64
  text: str
 
 
65
 
66
  @app.post("/translate")
67
  def translate(request: TranslationRequest):
68
- #reverse_dict = {v: k for k, v in replace_dict.items()}
69
- #text_for_ai = replace_words(request.text, reverse_dict)
70
-
71
- inputs = tokenizer(
72
- request.text,
73
- return_tensors="pt",
74
- max_length=128,
75
- truncation=True)
76
 
77
- outputs = model.generate(
78
- **inputs,
79
- forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])
80
- text2 = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
81
-
82
- #text_for_user = replace_words(text2, replace_dict)
83
-
84
- return {"translation": text2[0].upper() + text2[1:]}
85
-
86
  @app.get("/")
87
  def root():
88
- return {"message": "API is running"}
 
1
  import os
2
+ import torch
3
+ from fastapi import FastAPI
4
+ from pydantic import BaseModel
5
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
6
+
7
+ # 1️⃣ Cache (optional)
8
  os.environ["HF_HOME"] = "/tmp/hf"
9
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
10
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
11
  os.makedirs("/tmp/hf", exist_ok=True)
12
 
13
+ # 2️⃣ HF TOKEN
14
+ HF_TOKEN = os.environ.get("HF_TOKEN")
15
+ if HF_TOKEN is None:
16
+ raise ValueError("HF_TOKEN not found. Please add it in your Space Secrets.")
17
 
18
+ # 3️⃣ DEVICE
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
20
 
21
+ # 4️⃣ Load model + tokenizer (PRIVATE REPO)
22
+ #model_name = "Gaoussin/bamalingua-bm-fr"
23
+ #tokenizer = MBart50TokenizerFast.from_pretrained(model_name, token=HF_TOKEN)
24
+ #model = MBartForConditionalGeneration.from_pretrained(model_name, token=HF_TOKEN).to(device)
25
+ ####
26
+ # 3. Load tokenizer & add Bambara token
27
+ # ========================================
28
+ model_name = "my_tokenizer"
29
+ # Load the tokenizer with a default language and suppress the error
30
+ try:
31
+ tokenizer = MBart50Tokenizer.from_pretrained(model_name, src_lang="en_XX")
32
+ except KeyError:
33
+ # If loading with en_XX fails, try without specifying src_lang and fix afterwards
34
+ tokenizer = MBart50Tokenizer.from_pretrained(model_name)
35
 
36
+ # Add the new language as an additional special token and update mappings
37
+ new_lang = 'bm_ml'
38
+ if new_lang not in tokenizer.lang_code_to_id:
39
+ tokenizer.add_special_tokens({'additional_special_tokens': [new_lang]})
40
+ # Update the internal language code mappings
41
+ new_id = len(tokenizer) - 1
42
+ tokenizer.lang_code_to_id[new_lang] = new_id
43
+ tokenizer.id_to_lang_code[new_id] = new_lang
44
+ print(f"Added new language token '{new_lang}' with ID {new_id}")
45
+ else:
46
+ print(f"Language token '{new_lang}' already exists in tokenizer.")
47
 
48
+ # Load model
49
+ model = MBartForConditionalGeneration.from_pretrained("Gaoussin/bamalingua-bm_ml-fr_XX")
50
+ model.resize_token_embeddings(len(tokenizer))
51
 
 
 
 
 
 
52
  #####
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
 
54
 
55
+ # 5️⃣ Translation function
56
+ def translateTo(text, src_lang, tgt_lang):
57
+ tokenizer.src_lang = src_lang
58
+ inputs = tokenizer(text, return_tensors="pt").to(device)
59
+ tgt_id = tokenizer.lang_code_to_id[tgt_lang]
60
+ generated = model.generate(**inputs, forced_bos_token_id=tgt_id)
61
+ return tokenizer.decode(generated[0], skip_special_tokens=True)
62
 
63
+ # 6️⃣ FastAPI
 
64
  app = FastAPI()
65
 
66
  class TranslationRequest(BaseModel):
67
  text: str
68
+ src_lang: str
69
+ tgt_lang: str
70
 
71
  @app.post("/translate")
72
  def translate(request: TranslationRequest):
73
+ output = translateTo(request.text, request.src_lang, request.tgt_lang)
74
+ return {"translation": output}
 
 
 
 
 
 
75
 
 
 
 
 
 
 
 
 
 
76
  @app.get("/")
77
  def root():
78
+ return {"message": "API is running"}