# bm-translator / main.py
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
# Project-local helpers for normalizing Bambara ("bm") text
from normalize_bm_input import normalize_bm_input
from normalize_bm_output import normalize_bm_output
# =====================
# 1️⃣ Environment / Cache
# =====================
# Point all Hugging Face caches at a writable directory (TRANSFORMERS_CACHE
# is deprecated in recent transformers releases but kept for older versions)
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
os.makedirs("/tmp/hf", exist_ok=True)
# =====================
# 2️⃣ Device
# =====================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# =====================
# 3️⃣ Load Model & Tokenizer
# =====================
# Load the NLLB model and tokenizer
try:
    model_name = "Gaoussin/Bamalingua-2"
    tokenizer = NllbTokenizer.from_pretrained(model_name)
    # Move the model to the selected device (CPU or GPU)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    print(f"Model '{model_name}' loaded successfully on {device}.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    # Without a model the API cannot serve requests, so fail fast instead of
    # letting later code hit an undefined `model`/`tokenizer`
    raise
# =====================
# 4️⃣ FastAPI setup - Request Schema
# =====================
app = FastAPI()

# Input schema
class TranslationRequest(BaseModel):
    text: str
    src_lang: str  # e.g., "bam_Latn"
    tgt_lang: str  # e.g., "fra_Latn"
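# Illustrative JSON body for POST /translate, using NLLB-style language
# codes ("bam_Latn" = Bambara, "fra_Latn" = French); the sample sentence
# is only an example, not from the model card:
#
#   {"text": "I ni ce", "src_lang": "bam_Latn", "tgt_lang": "fra_Latn"}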
# =====================
# 5️⃣ Translation function
# =====================
def translate_to(text: str, src: str, tgt: str) -> str:
    tokenizer.src_lang = src
    print(f"Translating {text!r} from {src} to {tgt}")
    # Tokenize and move the inputs to the same device as the model
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # For NLLB models the target language is selected by forcing its
    # language token to be the first generated token
    output = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt),
        max_length=128,
    )
    # Decode the generated tokens back to text
    return tokenizer.decode(output[0], skip_special_tokens=True)
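# Illustrative direct call (the actual output depends on the model):
#   translate_to("I ni ce", "bam_Latn", "fra_Latn")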
# =====================
# 6️⃣ API Endpoints
# =====================
@app.post("/translate")
def translate(request: TranslationRequest):
    try:
        # Normalize Bambara text on the way in and out. Assumption: both
        # project-local helpers take and return a plain string.
        text = request.text
        if request.src_lang == "bam_Latn":
            text = normalize_bm_input(text)

        result = translate_to(text, request.src_lang, request.tgt_lang)

        if request.tgt_lang == "bam_Latn":
            result = normalize_bm_output(result)

        # The response is a nested list: [[translation, model_name]]
        return [[result, model_name]]
    except Exception as e:
        print(f"An error occurred during translation: {e}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {e}")
@app.get("/")
def root():
    return {"message": "API is running 🚀"}