import os

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

from normalize_bm_input import normalize_bm_input
from normalize_bm_output import normalize_bm_output

# =====================
# 1️⃣ Environment / Cache
# =====================
# Point the Hugging Face caches at a writable directory
os.makedirs("/tmp/hf", exist_ok=True)
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"

# =====================
# 2️⃣ Device
# =====================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# =====================
# 3️⃣ Load Model & Tokenizer
# =====================
# Load the NLLB model and tokenizer
try:
    model_name = "Gaoussin/Bamalingua-2"
    tokenizer = NllbTokenizer.from_pretrained(model_name)
    # Move the model to the selected device (CPU or GPU)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    print(f"Model '{model_name}' loaded successfully on {device}.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    # Fail fast: nothing below can work without the model
    raise

# =====================
# 4️⃣ FastAPI setup - Define Input and Output Schemas
# =====================
app = FastAPI()

# Input schema
class TranslationRequest(BaseModel):
    text: str
    src_lang: str  # e.g., "bam_Latn"
    tgt_lang: str  # e.g., "fra_Latn"

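# Output schema - a sketch only: the section header above mentions an
# output schema but none appears in the original, so these field names
# are assumptions and the class is not wired into the endpoint below
class TranslationResponse(BaseModel):
    translation: str
    model_name: str
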
# =====================
# 5️⃣ Translation function - Restored to user's original logic
# =====================
def translateTo(text, src, tgt):
    # src_lang must be set before tokenizing so the source language tag
    # is prepended to the input
    tokenizer.src_lang = src
    print(f"Translating {text!r} from {src} to {tgt}")
    # Prepare the input and move it to the same device as the model
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # NLLB selects the output language via a forced BOS token; setting
    # tokenizer.tgt_lang alone has no effect on generate()
    output = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt),
        max_length=128,
    )
    # Decode the output
    return tokenizer.decode(output[0], skip_special_tokens=True)

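# Example direct call (the text is illustrative, not from the original;
# language codes follow the NLLB convention noted in the schema above):
#   translateTo("An ka taa", "bam_Latn", "fra_Latn")
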
# =====================
# 6️⃣ API Endpoints - Applying the Response Model
# =====================
@app.post("/translate")  # route path assumed; the original defines the handler without a decorator
def translate(request: TranslationRequest):
    try:
        # --- 1. Input Normalization (restored around the imported helper;
        #        the exact call site is an assumption) ---
        text = normalize_bm_input(request.text)
        # --- 2. Core Translation ---
        result = translateTo(text, request.src_lang, request.tgt_lang)
        # --- 3. Output Normalization (same assumption as step 1) ---
        result = normalize_bm_output(result)
        # --- 4. Final Output ---
        translation_list = [result, model_name]
        return [translation_list]
    except Exception as e:
        print(f"An error occurred during translation: {e}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")

@app.get("/")  # path assumed from the function name; the original lacks the decorator
def root():
    return {"message": "API is running 🚀"}
