Gaoussin committed on
Commit
9464e73
·
verified ·
1 Parent(s): 85b5b85

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +40 -63
main.py CHANGED
@@ -1,86 +1,63 @@
1
  import os
2
-
3
- # cache dirs for HF
4
  os.environ["HF_HOME"] = "/tmp/hf"
5
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
6
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
7
  os.makedirs("/tmp/hf", exist_ok=True)
8
 
9
 
10
- from fastapi import FastAPI, Request, HTTPException
11
  from pydantic import BaseModel
12
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
- import torch
14
- import secrets
15
-
16
 
17
- # -----------------------------
18
- # 1️⃣ Generate / load API key
19
- # -----------------------------
20
- # Generate a new key once (uncomment to create)
21
- # print(secrets.token_hex(32))
22
 
23
- # Or load from environment variable
24
- API_KEY = os.getenv("MY_API_KEY", "ec1a464f3948d7e9e0484efad4f71d0a0aa9f3fb37560697c42da0568b9fbac5")
 
 
 
 
25
 
26
- # -----------------------------
27
- # 2️⃣ Initialize FastAPI & model
28
- # -----------------------------
29
- app = FastAPI()
30
 
 
 
 
 
 
31
 
32
- MODEL_NAME = "facebook/nllb-200-3.3B"
33
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
34
- model = AutoModelForSeq2SeqLM.from_pretrained(
35
- MODEL_NAME, torch_dtype=torch.float16
36
- ).to("cuda" if torch.cuda.is_available() else "cpu")
37
 
38
- # 🔑 Build lang_code_to_id manually
39
- lang_code_to_id = {
40
- tok.strip("_"): idx
41
- for tok, idx in zip(tokenizer.additional_special_tokens,
42
- tokenizer.additional_special_tokens_ids)
43
- }
44
 
45
  class TranslationRequest(BaseModel):
46
  text: str
47
- src_lang: str
48
- tgt_lang: str
49
-
50
-
51
 
52
  @app.post("/translate")
53
- async def translate(req: TranslationRequest, request: Request):
54
- # ---- Check API key ----
55
- api_key = request.headers.get("X-API-KEY")
56
- if api_key != API_KEY:
57
- raise HTTPException(status_code=403, detail="Unauthorized API key")
58
-
59
- # ---- Optional IP restriction ----
60
- #client_ip = request.client.host
61
- #if ALLOWED_IPS and client_ip not in ALLOWED_IPS:
62
- # raise HTTPException(status_code=403, detail="IP not allowed")
63
-
64
-
65
- # always set source language
66
- tokenizer.src_lang = req.src_lang # 👈 force the source language
67
 
68
  inputs = tokenizer(
69
- req.text,
70
- return_tensors="pt",
71
- ).to(model.device)
72
-
73
- # force target language
74
- tgt_lang = req.tgt_lang
75
- forced_bos_id = lang_code_to_id[tgt_lang]
76
 
77
  outputs = model.generate(
78
- **inputs,
79
- forced_bos_token_id=forced_bos_id,
80
- max_length=512#,
81
- #num_beams=5, # good for quality
82
- #early_stopping=True
83
- )
84
-
85
- translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
86
- return {"translation": translation}
 
 
 
1
  import os
2
+ # 2️⃣ Optional: force cache to writable directory
 
3
  os.environ["HF_HOME"] = "/tmp/hf"
4
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
5
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
6
  os.makedirs("/tmp/hf", exist_ok=True)
7
 
8
 
9
+ from fastapi import FastAPI
10
  from pydantic import BaseModel
11
+ #from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
 
12
 
13
+ #from replacer import replace_words, replace_dict
14
+ #from datasets import Dataset
15
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
 
 
16
 
17
+ # -------------------------
18
+ # 1️⃣ Get your HF token from Space Secrets
19
+ # In your Space, go to Settings → Secrets → add HF_TOKEN
20
+ #HF_TOKEN = os.environ.get("HF_TOKEN")
21
+ #if HF_TOKEN is None:
22
+ # raise ValueError("HF_TOKEN not found. Please add it in your Space Secrets.")
23
 
24
+ # -------------------------
 
 
 
25
 
26
+ # -------------------------
27
+ # 3️⃣ Load private model
28
+ model_name = "Gaoussin/bamalingua-bm_ml-fr_XX"
29
+ model = MBartForConditionalGeneration.from_pretrained(model_name)
30
+ tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
31
 
32
+ tgt_lang = "bm_ml"
 
 
 
 
33
 
34
+ # -------------------------
35
+ # 4️⃣ FastAPI app
36
+ app = FastAPI()
 
 
 
37
 
38
  class TranslationRequest(BaseModel):
39
  text: str
 
 
 
 
40
 
41
  @app.post("/translate")
42
+ def translate(request: TranslationRequest):
43
+ #reverse_dict = {v: k for k, v in replace_dict.items()}
44
+ #text_for_ai = replace_words(request.text, reverse_dict)
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  inputs = tokenizer(
47
+ request.text,
48
+ return_tensors="pt",
49
+ max_length=128,
50
+ truncation=True)
 
 
 
51
 
52
  outputs = model.generate(
53
+ **inputs,
54
+ forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])
55
+ text2 = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
56
+
57
+ #text_for_user = replace_words(text2, replace_dict)
58
+
59
+ return {"translation": text2[0].upper() + text2[1:]}
60
+
61
+ @app.get("/")
62
+ def root():
63
+ return {"message": "API is running"}