Gaoussin commited on
Commit
af87f0e
·
verified ·
1 Parent(s): dc0f0ce

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +77 -29
main.py CHANGED
@@ -1,50 +1,98 @@
1
  import os
2
  import torch
3
- from fastapi import FastAPI
4
  from pydantic import BaseModel
5
- from transformers import MBartForConditionalGeneration, MBart50Tokenizer
 
6
 
7
- # 1️⃣ Cache (optional)
 
 
 
8
  os.environ["HF_HOME"] = "/tmp/hf"
9
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
10
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
11
  os.makedirs("/tmp/hf", exist_ok=True)
12
 
13
- # 2️⃣ HF TOKEN
14
- HF_TOKEN = os.environ.get("mySpace")
15
- if HF_TOKEN is None:
16
- raise ValueError("HF_TOKEN not found. Please add it in your Space Secrets.")
17
-
18
- # 3️⃣ DEVICE
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
20
 
21
- model_name = "Gaoussin/bamalingua-bm_ml-fr_XX"
22
- tokenizer = MBart50Tokenizer.from_pretrained(model_name)
23
- model = MBartForConditionalGeneration.from_pretrained("Gaoussin/bamalingua-bm_ml-fr_XX")
24
- #####
25
-
26
-
27
- # 5️⃣ Translation function
28
- def translateTo(text, src_lang, tgt_lang):
29
- tokenizer.src_lang = src_lang
30
- inputs = tokenizer(text, return_tensors="pt").to(device)
31
- tgt_id = tokenizer.lang_code_to_id[tgt_lang]
32
- generated = model.generate(**inputs, forced_bos_token_id=tgt_id)
33
- return tokenizer.decode(generated[0], skip_special_tokens=True)
34
 
35
- # 6️⃣ FastAPI
 
 
36
  app = FastAPI()
37
 
 
38
  class TranslationRequest(BaseModel):
39
  text: str
40
- src_lang: str
41
- tgt_lang: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- @app.post("/translate")
 
 
 
44
  def translate(request: TranslationRequest):
45
- output = translateTo(request.text, request.src_lang, request.tgt_lang)
46
- return {"translation": output}
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  @app.get("/")
49
  def root():
50
- return {"message": "API is running "}
 
1
  import os
2
  import torch
3
+ from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
+ # Note: Keep the imports together for clarity
6
+ from transformers import NllbTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
7
 
8
+ # =====================
9
+ # 1️⃣ Environment / Cache
10
+ # =====================
11
+ # Setting cache environment variables for Hugging Face
12
  os.environ["HF_HOME"] = "/tmp/hf"
13
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
14
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
15
  os.makedirs("/tmp/hf", exist_ok=True)
16
 
17
+ # =====================
18
+ # 2️⃣ Device
19
+ # =====================
 
 
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ print(f"Using device: {device}")
22
 
23
+ # =====================
24
+ # 3️⃣ Load Model & Tokenizer
25
+ # =====================
26
+ # Charger le modèle et le tokenizer NLLB
27
+ try:
28
+ model_name = "Gaoussin/bamalingua-2"
29
+ tokenizer = NllbTokenizer.from_pretrained(model_name)
30
+ # Move model to the selected device (CPU or GPU)
31
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
32
+ print(f"Model '{model_name}' loaded successfully on {device}.")
33
+ except Exception as e:
34
+ print(f"Error loading model or tokenizer: {e}")
35
+ # In a real application, you might exit or handle this more gracefully
36
 
37
+ # =====================
38
+ # 4️⃣ FastAPI setup - Define Input and Output Schemas
39
+ # =====================
40
  app = FastAPI()
41
 
42
+ # Input schema
43
  class TranslationRequest(BaseModel):
44
  text: str
45
+ src_lang: str # e.g., "bam_Latn"
46
+ tgt_lang: str # e.g., "fra_Latn"
47
+
48
+ # Output schema (THE FIX: ensures both fields are returned)
49
+ class TranslationResponse(BaseModel):
50
+ """
51
+ Ensures both the translated text and the app version ID are included
52
+ in the response JSON.
53
+ """
54
+ translation: str
55
+ appVersionId: str
56
+
57
+ # =====================
58
+ # 5️⃣ Translation function - Restored to user's original logic
59
+ # =====================
60
+ def translateTo(text, src, tgt):
61
+ tokenizer.src_lang = src
62
+ tokenizer.tgt_lang = tgt
63
+ print(tokenizer.src_lang, tokenizer.tgt_lang)
64
+
65
+ # Prepare input for the model
66
+ # We explicitly move the inputs to the same device as the model
67
+ inputs = tokenizer(text, return_tensors="pt").to(device)
68
+
69
+ # Generate translation using the user's logic
70
+ output = model.generate(**inputs, max_length=128)
71
+
72
+ # Decode the output
73
+ return tokenizer.decode(output[0], skip_special_tokens=True)
74
 
75
+ # =====================
76
+ # 6️⃣ API Endpoints - Applying the Response Model
77
+ # =====================
78
+ @app.post("/translate", response_model=TranslationResponse) # <-- Fix remains here
79
  def translate(request: TranslationRequest):
80
+ try:
81
+ result = translateTo(request.text, request.src_lang, request.tgt_lang)
82
+ appVersionId = "App Version id = 2"
83
+
84
+ # Return the dictionary matching the TranslationResponse schema
85
+ return {"translation": result, "appVersionId": appVersionId}
86
+
87
+ except Exception as e:
88
+ print(f"An error occurred during translation: {e}")
89
+ # When raising an HTTPException, the response model is bypassed,
90
+ # and a standard JSON error is returned.
91
+ raise HTTPException(
92
+ status_code=500,
93
+ detail=f"Translation failed: {str(e)}"
94
+ )
95
 
96
  @app.get("/")
97
  def root():
98
+ return {"message": "API is running 🚀"}