"""FastAPI service exposing the Bamalingua NLLB translation model.

Endpoints:
    POST /translate -- translate ``text`` from ``src_lang`` to ``tgt_lang``.
    GET  /          -- liveness check.
"""

import os

# =====================
# 1️⃣ Environment / Cache
# =====================
# Hugging Face libraries read these variables when they are first imported,
# so they must be set BEFORE importing transformers (or any module that
# imports it). Setting them afterwards does not change the cache location.
HF_CACHE_DIR = "/tmp/hf"
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = HF_CACHE_DIR
os.makedirs(HF_CACHE_DIR, exist_ok=True)

import threading  # noqa: E402

import torch  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from pydantic import BaseModel  # noqa: E402

# Note: Keep the imports together for clarity
# NOTE(review): the Seq2Seq* training classes and the normalize_bm_* helpers
# are not referenced in this file -- confirm whether they are still needed.
from transformers import (  # noqa: E402
    NllbTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
from normalize_bm_input import normalize_bm_input  # noqa: E402
from normalize_bm_output import normalize_bm_output  # noqa: E402

# =====================
# 2️⃣ Device
# =====================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# =====================
# 3️⃣ Load Model & Tokenizer
# =====================
# Load the NLLB model and tokenizer once at startup. On failure both are left
# as None so the endpoint can report "service unavailable" instead of crashing
# with a NameError on every request.
model_name = "Gaoussin/Bamalingua-2"
tokenizer = None
model = None
try:
    tokenizer = NllbTokenizer.from_pretrained(model_name)
    # Move model to the selected device (CPU or GPU)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    print(f"Model '{model_name}' loaded successfully on {device}.")
except Exception as e:
    # Never leave a half-loaded pair behind (e.g. tokenizer OK, model failed).
    tokenizer = None
    model = None
    print(f"Error loading model or tokenizer: {e}")

# FastAPI runs plain ``def`` endpoints in a worker thread pool, so several
# requests can call translateTo concurrently. ``tokenizer.src_lang`` is shared
# mutable state, so setting it and tokenizing must happen atomically.
_tokenizer_lock = threading.Lock()

# =====================
# 4️⃣ FastAPI setup - Define Input Schema
# =====================
app = FastAPI()


class TranslationRequest(BaseModel):
    """Body of a POST /translate request."""

    text: str
    src_lang: str  # e.g., "bam_Latn"
    tgt_lang: str  # e.g., "fra_Latn"


class UnsupportedLanguageError(ValueError):
    """Raised when a language code is not known to the tokenizer."""


# =====================
# 5️⃣ Translation function
# =====================
def translateTo(text, src, tgt, max_length=128):
    """Translate ``text`` from language ``src`` into language ``tgt``.

    Args:
        text: Source text to translate.
        src: NLLB language code of ``text`` (e.g. ``"bam_Latn"``).
        tgt: NLLB language code to translate into (e.g. ``"fra_Latn"``).
        max_length: Maximum length, in tokens, of the generated sequence.

    Returns:
        The decoded translation with special tokens removed.

    Raises:
        RuntimeError: If the model/tokenizer failed to load at startup.
        UnsupportedLanguageError: If ``src`` or ``tgt`` is not a language
            code known to the tokenizer.
    """
    if model is None or tokenizer is None:
        raise RuntimeError(f"Model '{model_name}' is not loaded.")

    # Unknown codes convert to the unk id; reject them instead of silently
    # translating with a meaningless language tag.
    lang_ids = {}
    for code in (src, tgt):
        token_id = tokenizer.convert_tokens_to_ids(code)
        if token_id is None or token_id == tokenizer.unk_token_id:
            raise UnsupportedLanguageError(f"Unsupported language code: {code!r}")
        lang_ids[code] = token_id

    with _tokenizer_lock:
        tokenizer.src_lang = src
        tokenizer.tgt_lang = tgt
        # Explicitly move the inputs to the same device as the model.
        inputs = tokenizer(text, return_tensors="pt").to(device)

    output = model.generate(
        **inputs,
        # NLLB chooses the output language from the first generated token;
        # tokenizer.tgt_lang alone does not influence generate().
        forced_bos_token_id=lang_ids[tgt],
        max_length=max_length,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


# =====================
# 6️⃣ API Endpoints
# =====================
@app.post("/translate")
def translate(request: TranslationRequest):
    """Translate the request text and return ``[[translation, model_name]]``.

    Responds with 503 if the model is not loaded, 400 for unknown language
    codes and 500 for any other translation failure.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail=f"Model '{model_name}' is not loaded.")
    try:
        result = translateTo(request.text, request.src_lang, request.tgt_lang)
    except UnsupportedLanguageError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        print(f"An error occurred during translation: {e}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}") from e

    # Response shape kept unchanged for existing clients.
    translation_list = [result, model_name]
    return [translation_list]


@app.get("/")
def root():
    """Liveness check."""
    return {"message": "API is running 🚀"}