import os

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Note: keep the imports together for clarity
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
from normalize_bm_words import normalize_text
# =====================
# 1️⃣ Environment / Cache
# =====================
# Point the Hugging Face caches at a writable directory
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
os.makedirs("/tmp/hf", exist_ok=True)
# =====================
# 2️⃣ Device
# =====================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# =====================
# 3️⃣ Load Model & Tokenizer
# =====================
# Load the NLLB model and tokenizer
try:
    model_name = "Gaoussin/bamalingua-4"
    tokenizer = NllbTokenizer.from_pretrained(model_name)
    # Move model to the selected device (CPU or GPU)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    print(f"Model '{model_name}' loaded successfully on {device}.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    # Fail fast: without a loaded model the endpoints below cannot work.
    # In a real application, you might handle this more gracefully.
    raise
# =====================
# 4️⃣ FastAPI setup - Define Input and Output Schemas
# =====================
app = FastAPI()

# Input schema
class TranslationRequest(BaseModel):
    text: str
    src_lang: str  # e.g., "bam_Latn"
    tgt_lang: str  # e.g., "fra_Latn"
# Output schema (THE FIX: ensures both fields are returned)
class TranslationResponse(BaseModel):
    """
    Ensures both the translated text and the app version ID are included
    in the response JSON.
    """
    translation: str
    appVersionId: str
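# For reference, a request/response pair matching these schemas would look
# roughly like this (the Bambara sample text is only an illustration):
#   request:  {"text": "I ni ce", "src_lang": "bam_Latn", "tgt_lang": "fra_Latn"}
#   response: {"translation": "...", "appVersionId": "App Version id = 2"}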
# =====================
# 5️⃣ Translation function - Restored to user's original logic
# =====================
def translateTo(text, src, tgt):
    tokenizer.src_lang = src
    tokenizer.tgt_lang = tgt
    print(tokenizer.src_lang, tokenizer.tgt_lang)
    # Prepare input for the model
    # We explicitly move the inputs to the same device as the model
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # Generate the translation; forcing the first generated token to be the
    # target-language code makes NLLB decode into the requested language
    # (setting tokenizer.tgt_lang alone does not affect generate())
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt),
            max_length=128,
        )
    # Decode the output
    return tokenizer.decode(output[0], skip_special_tokens=True)
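# Quick sanity-check sketch (assumes "bam_Latn" -> "fra_Latn" is a pair the
# fine-tuned checkpoint supports; the sample text is illustrative only):
#   translateTo("I ni ce", "bam_Latn", "fra_Latn")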
# =====================
# 6️⃣ API Endpoints - Applying the Response Model
# =====================
@app.post("/translate", response_model=TranslationResponse)  # <-- Fix remains here
def translate(request: TranslationRequest):
    try:
        # normalize_text from imported file
        text = normalize_text(request.text)
        result = translateTo(text, request.src_lang, request.tgt_lang)
        appVersionId = "App Version id = 2"
        # Return the dictionary matching the TranslationResponse schema
        return {"translation": result, "appVersionId": appVersionId}
    except Exception as e:
        print(f"An error occurred during translation: {e}")
        # When raising an HTTPException, the response model is bypassed,
        # and a standard JSON error is returned.
        raise HTTPException(
            status_code=500,
            detail=f"Translation failed: {str(e)}",
        )
@app.get("/")
def root():
    return {"message": "API is running 🚀"}