Gaoussin commited on
Commit
12d2d84
·
verified ·
1 Parent(s): 11c4660

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +14 -10
main.py CHANGED
@@ -1,22 +1,23 @@
1
  import os
 
 
 
 
 
2
  # 2️⃣ Optional: force cache to writable directory
3
  os.environ["HF_HOME"] = "/tmp/hf"
4
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
5
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
6
  os.makedirs("/tmp/hf", exist_ok=True)
7
 
8
-
9
- from fastapi import FastAPI
10
- from pydantic import BaseModel
11
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
12
- import torch
13
-
14
  app = FastAPI()
15
 
16
  # Load model once on startup
17
- MODEL_NAME = "facebook/nllb-200-1.3B" # 3B version
18
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,use_fast=False)
19
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to("cuda" if torch.cuda.is_available() else "cpu")
 
 
20
 
21
  class TranslationRequest(BaseModel):
22
  text: str
@@ -30,9 +31,12 @@ def translate(req: TranslationRequest):
30
  return_tensors="pt",
31
  ).to(model.device)
32
 
 
 
 
33
  outputs = model.generate(
34
  **inputs,
35
- forced_bos_token_id=tokenizer.lang_code_to_id[req.tgt_lang],
36
  max_length=512
37
  )
38
  translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
1
  import os
2
+ from fastapi import FastAPI
3
+ from pydantic import BaseModel
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ import torch
6
+
7
  # 2️⃣ Optional: force cache to writable directory
8
  os.environ["HF_HOME"] = "/tmp/hf"
9
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
10
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
11
  os.makedirs("/tmp/hf", exist_ok=True)
12
 
 
 
 
 
 
 
13
  app = FastAPI()
14
 
15
  # Load model once on startup
16
+ MODEL_NAME = "facebook/nllb-200-1.3B"
17
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
18
+ model = AutoModelForSeq2SeqLM.from_pretrained(
19
+ MODEL_NAME, torch_dtype=torch.float16
20
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
  class TranslationRequest(BaseModel):
23
  text: str
 
31
  return_tensors="pt",
32
  ).to(model.device)
33
 
34
+ # ✅ add "__" around lang codes
35
+ tgt_lang = "__" + req.tgt_lang + "__"
36
+
37
  outputs = model.generate(
38
  **inputs,
39
+ forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
40
  max_length=512
41
  )
42
  translation = tokenizer.decode(outputs[0], skip_special_tokens=True)