# bm-translator / main.py
# Page header copied from the Hugging Face file viewer (user: Gaoussin;
# commit 9511ac6 "Update main.py", verified; file size 3.49 kB).
import os
import threading
import traceback

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Note: Keep the imports together for clarity
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

from normalize_bm_words import normalize_text
# =====================
# 1️⃣ Environment / Cache
# =====================
# Point every Hugging Face cache location at a single writable scratch
# directory, then make sure that directory exists.
# NOTE(review): this runs after `transformers` is imported at the top of the
# file. huggingface_hub resolves its default cache paths at import time, so
# these variables may not take effect. Consider moving this section above the
# imports (or passing cache_dir explicitly to from_pretrained).
os.environ.update(
    dict.fromkeys(("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE"), "/tmp/hf")
)
os.makedirs("/tmp/hf", exist_ok=True)
# =====================
# 2️⃣ Device
# =====================
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# =====================
# 3️⃣ Load Model & Tokenizer
# =====================
# Load the fine-tuned NLLB model and its tokenizer.
model_name = "Gaoussin/bamalingua-4"

# Bind both names up front. If loading fails, later code then sees `None`
# instead of crashing with a NameError on an undefined global.
tokenizer = None
model = None

# The HF_* cache variables are assigned after `transformers` is imported, by
# which point the Hugging Face libraries may already have resolved their
# default cache location. Passing cache_dir explicitly makes downloads land
# in the directory the environment section intends.
_cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/hf")

try:
    tokenizer = NllbTokenizer.from_pretrained(model_name, cache_dir=_cache_dir)
    # Move model to the selected device (CPU or GPU)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=_cache_dir).to(device)
    print(f"Model '{model_name}' loaded successfully on {device}.")
except Exception as e:
    # Best-effort startup: the API still comes up, but translation requests
    # cannot succeed until the model is available. Keep the full traceback in
    # the log so the failure can be diagnosed.
    print(f"Error loading model or tokenizer: {e}")
    traceback.print_exc()
# =====================
# 4️⃣ FastAPI setup - Define Input and Output Schemas
# =====================
# The FastAPI application object; the endpoints below register their routes on it.
app = FastAPI()
# Input schema: request body accepted by POST /translate.
class TranslationRequest(BaseModel):
    """Request body for POST /translate."""

    text: str  # Text to translate; the endpoint runs it through normalize_text() first.
    src_lang: str  # Source language code, e.g. "bam_Latn"
    tgt_lang: str  # Target language code, e.g. "fra_Latn"
# Output schema: used as response_model for POST /translate, so successful
# responses are validated against, and serialized with, both fields below.
class TranslationResponse(BaseModel):
    """
    Response body for POST /translate: the translated text together with
    the app version ID.
    """
    translation: str  # Translated text produced by the model.
    appVersionId: str  # Version label; this camelCase name is also the JSON key, so renaming it changes the API.
# =====================
# 5️⃣ Translation function
# =====================
# `tokenizer.src_lang` is mutable state on a shared, module-level object.
# FastAPI runs plain `def` endpoints in a worker threadpool, so two concurrent
# requests could overwrite each other's language setting between assignment
# and encoding. This lock serializes that critical section.
_tokenizer_lock = threading.Lock()


def _lang_token_id(code):
    """Return the vocabulary id of the language token *code*.

    Raises:
        ValueError: if *code* resolves to the tokenizer's unknown token,
            i.e. it is not a token the tokenizer knows.
    """
    token_id = tokenizer.convert_tokens_to_ids(code)
    if token_id is None or token_id == tokenizer.unk_token_id:
        raise ValueError(f"Unsupported language code: {code!r}")
    return token_id


def translateTo(text, src, tgt, max_length=128):
    """Translate *text* from language *src* into language *tgt*.

    Args:
        text: Source text to translate.
        src: Source language code, e.g. ``"bam_Latn"``.
        tgt: Target language code, e.g. ``"fra_Latn"``.
        max_length: Maximum generated sequence length in tokens, forwarded to
            ``model.generate`` (defaults to the previous hard-coded 128).

    Returns:
        The decoded translation, with special tokens stripped.

    Raises:
        RuntimeError: if the model or tokenizer failed to load.
        ValueError: if *src* or *tgt* is not a language token the tokenizer knows.
    """
    if tokenizer is None or model is None:
        raise RuntimeError("Translation model is not loaded")

    # Validate both codes before touching shared tokenizer state.
    _lang_token_id(src)
    tgt_id = _lang_token_id(tgt)

    with _tokenizer_lock:
        tokenizer.src_lang = src
        tokenizer.tgt_lang = tgt
        print(tokenizer.src_lang, tokenizer.tgt_lang)
        # Encode while src_lang is guaranteed to be ours; move inputs to the model's device.
        inputs = tokenizer(text, return_tensors="pt").to(device)

    # Setting tokenizer.tgt_lang does not influence generation. NLLB selects the
    # output language through the first decoded token, so force it to *tgt*.
    output = model.generate(**inputs, forced_bos_token_id=tgt_id, max_length=max_length)
    return tokenizer.decode(output[0], skip_special_tokens=True)
# =====================
# 6️⃣ API Endpoints - Applying the Response Model
# =====================
@app.post("/translate", response_model=TranslationResponse)
def translate(request: TranslationRequest):
    """Normalize the request text and translate it.

    Returns:
        A dict matching ``TranslationResponse``: the translated text and a
        hard-coded app version string.

    Raises:
        HTTPException: status 500 if normalization or translation raises; the
            underlying error message is included in ``detail``.
    """
    try:
        # normalize_text comes from the local normalize_bm_words module.
        text = normalize_text(request.text)
        result = translateTo(text, request.src_lang, request.tgt_lang)
    except Exception as e:
        print(f"An error occurred during translation: {e}")
        # Keep the full stack trace in the server log; the client only sees `detail`.
        traceback.print_exc()
        # When raising an HTTPException, the response model is bypassed,
        # and a standard JSON error is returned.
        raise HTTPException(
            status_code=500,
            detail=f"Translation failed: {e}",
        ) from e

    app_version_id = "App Version id = 2"
    # Return the dictionary matching the TranslationResponse schema
    return {"translation": result, "appVersionId": app_version_id}
@app.get("/")
def root():
    """Liveness endpoint: report that the API process is up."""
    status = {"message": "API is running 🚀"}
    return status