# ngananlp-v2-api / app.py
# Uploaded by robsonrtp — commit 00180d1 ("Upload app.py"), verified
import io
import os
import threading
from contextlib import asynccontextmanager
from typing import Optional

import edge_tts
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from peft import PeftModel
from pydantic import BaseModel, field_validator
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ─── Config ───────────────────────────────────────────────────────────────────
# Hugging Face repo of the LoRA adapter (the tokenizer is loaded from it too).
ADAPTER = os.environ.get("ADAPTER", "robsonrtp/ngananlp-v2")
BASE_MODEL = "facebook/nllb-200-distilled-1.3B"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Optional Hugging Face access token; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Zulu neural voice — Bantu language, sounds natural for Angolan Bantu content
DEFAULT_TTS_VOICE = os.environ.get("TTS_VOICE", "zu-ZA-ThandoNeural")

# Per-language voice overrides; codes missing here use DEFAULT_TTS_VOICE.
TTS_VOICE_MAP: dict[str, str] = {"por_Latn": "pt-PT-RaquelNeural"}

# Supported language codes mapped to their display names.
LANGUAGES = {
    "por_Latn": "Português",
    "lin_Latn": "Lingala",
    "umb_Latn": "Umbundu",
    "kmb_Latn": "Kimbundu",
    "cjk_Latn": "Tchokwe",
    "lue_Latn": "Luvale",
}

# Low-confidence translation directions (X → lue_Latn)
LOW_CONFIDENCE_TARGETS = {"lue_Latn"}
# ─── Model loading ────────────────────────────────────────────────────────────
# Module-level handles assigned by load_model(); both stay None until it has run.
model = None
tokenizer = None
def load_model():
    """Load the tokenizer, base model and LoRA adapter into the module globals.

    Assigns the module-level ``tokenizer`` and ``model``. Progress is printed
    with ``flush=True`` so each message appears immediately.
    """
    global tokenizer, model

    print(f"Loading tokenizer from {ADAPTER}...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER, token=HF_TOKEN)

    print(f"Loading base model {BASE_MODEL}...", flush=True)
    base_kwargs = {"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True}
    backbone = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL, **base_kwargs)
    # Size the embedding matrix to the adapter tokenizer's vocabulary before
    # the adapter is attached.
    backbone.resize_token_embeddings(len(tokenizer))
    backbone = backbone.to(DEVICE)

    print(f"Loading LoRA adapter from {ADAPTER}...", flush=True)
    adapted = PeftModel.from_pretrained(backbone, ADAPTER, token=HF_TOKEN)
    model = adapted.eval()  # inference only: switch to eval mode
    print("Model ready.", flush=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook: load the model once, before the app starts serving.

    Everything before the ``yield`` runs at startup. Nothing follows the
    ``yield``, so no cleanup is performed at shutdown.
    """
    load_model()
    yield
# ─── App ──────────────────────────────────────────────────────────────────────
# The FastAPI application; the model is loaded through the lifespan hook.
app = FastAPI(
    title="NganaNLP v2",
    description=(
        "Multilingual translation API for Portuguese and Angolan Bantu languages.\n\n"
        'Example: POST /translate `{"text": "Como tomar boas decisões sobre educação?", '
        '"src_lang": "por_Latn", "tgt_lang": "kmb_Latn"}`'
    ),
    version="2.0.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ─── Schemas ──────────────────────────────────────────────────────────────────
# Beam width used when the client omits ``num_beams`` or sends an explicit null.
_DEFAULT_NUM_BEAMS = 4


class TranslateRequest(BaseModel):
    """Request body for POST /translate."""

    # Text to translate.
    text: str
    # Source and target language codes; each must be a key of LANGUAGES.
    src_lang: str
    tgt_lang: str
    # Beam-search width (1-8). The field is Optional, so an explicit JSON null
    # is accepted; validate_beams maps it to the default.
    num_beams: Optional[int] = _DEFAULT_NUM_BEAMS

    @field_validator("src_lang", "tgt_lang")
    @classmethod
    def validate_lang(cls, v):
        """Reject language codes that are not keys of LANGUAGES.

        Raises:
            ValueError: if ``v`` is not a supported language code.
        """
        if v not in LANGUAGES:
            raise ValueError(f"Unsupported language '{v}'. Use one of: {list(LANGUAGES)}")
        return v

    @field_validator("num_beams")
    @classmethod
    def validate_beams(cls, v):
        """Check that ``num_beams`` lies in 1..8, treating null as the default.

        Without the ``None`` guard, ``1 <= None`` raises ``TypeError``, which
        pydantic does not convert into a validation error, so the request
        would fail with a 500 instead of a clean response.

        Raises:
            ValueError: if ``v`` is outside the inclusive range 1..8.
        """
        if v is None:
            return _DEFAULT_NUM_BEAMS
        if not (1 <= v <= 8):
            raise ValueError("num_beams must be between 1 and 8")
        return v
# Request body for POST /speak.
class SpeakRequest(BaseModel):
    # Text to synthesize.
    text: str
    # Language code (e.g. por_Latn) used to auto-select a voice.
    lang: Optional[str] = None
    # Explicit voice override; takes precedence over lang.
    voice: Optional[str] = None
# Response body of POST /translate.
class TranslateResponse(BaseModel):
    # Translated text.
    translation: str
    # Language codes echoed back from the request.
    src_lang: str
    tgt_lang: str
    # True when the target language is listed in LOW_CONFIDENCE_TARGETS.
    low_confidence: bool
# ─── Endpoints ────────────────────────────────────────────────────────────────
# Landing endpoint: service metadata, the supported languages and an index of
# the available routes.
@app.get("/")
def root():
    endpoint_index = {
        "translate": "POST /translate",
        "speak": "POST /speak",
        "languages": "GET /languages",
        "docs": "GET /docs",
        "health": "GET /health",
    }
    info = {
        "name": "NganaNLP v2",
        "description": "Multilingual translation API for Portuguese and Angolan Bantu languages",
        "message": "A language unheard is a thought untranslated. We build bridges.",
        "languages": LANGUAGES,
        "endpoints": endpoint_index,
    }
    return info
# Health probe: reports a fixed "ok" status plus the configured device and
# adapter id.
@app.get("/health")
def health():
    payload = dict(status="ok", device=DEVICE, model=ADAPTER)
    return payload
# Lists every supported language with its code and display name, and flags the
# codes that are in LOW_CONFIDENCE_TARGETS.
@app.get("/languages")
def languages():
    entries = []
    for code, name in LANGUAGES.items():
        entries.append(
            {
                "code": code,
                "name": name,
                "unreliable_as_target": code in LOW_CONFIDENCE_TARGETS,
            }
        )
    return {"languages": entries}
# Guards the "set tokenizer.src_lang, then tokenize" step. `translate` is a
# plain `def` endpoint, which FastAPI runs on a worker threadpool, so without
# the lock two concurrent requests could interleave and tokenize with each
# other's source language.
_TOKENIZER_LOCK = threading.Lock()


@app.post("/translate", response_model=TranslateResponse)
def translate(req: TranslateRequest):
    """Translate text from src_lang to tgt_lang with the fine-tuned model."""
    # Responds 400 when src_lang == tgt_lang or the text is blank, and 503
    # when the model has not been loaded yet.
    if req.src_lang == req.tgt_lang:
        raise HTTPException(status_code=400, detail="src_lang and tgt_lang must be different")
    # Same rule as /speak: whitespace-only input is rejected instead of being
    # sent to the model.
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty")
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # src_lang is state on the shared tokenizer, so setting it and encoding
    # must happen atomically.
    with _TOKENIZER_LOCK:
        tokenizer.src_lang = req.src_lang
        inputs = tokenizer(
            req.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=192,  # longer inputs are truncated
        )
    inputs = inputs.to(DEVICE)

    # Force decoding to start with the target-language token.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(req.tgt_lang)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=256,
            num_beams=req.num_beams,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
        )
    translation = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return TranslateResponse(
        translation=translation,
        src_lang=req.src_lang,
        tgt_lang=req.tgt_lang,
        low_confidence=req.tgt_lang in LOW_CONFIDENCE_TARGETS,
    )
# ─── TTS ──────────────────────────────────────────────────────────────────────
@app.post(
    "/speak",
    summary="Text-to-speech via edge-tts (Microsoft Neural TTS)",
    responses={200: {"content": {"audio/mpeg": {}}}},
)
async def speak(req: SpeakRequest):
    """Convert text to speech and return an MP3 audio file."""
    if req.text.strip() == "":
        raise HTTPException(status_code=400, detail="text must not be empty")

    # Voice precedence: explicit override, then the per-language map, then the
    # default voice.
    if req.voice:
        voice = req.voice
    else:
        voice = TTS_VOICE_MAP.get(req.lang or "", DEFAULT_TTS_VOICE)

    audio_parts = []
    try:
        # The stream yields several chunk types; only "audio" chunks carry
        # payload bytes.
        async for chunk in edge_tts.Communicate(req.text, voice).stream():
            if chunk["type"] == "audio":
                audio_parts.append(chunk["data"])
    except Exception as e:
        # Any failure while talking to the TTS service is reported as 502.
        raise HTTPException(status_code=502, detail=f"TTS error: {e}")

    audio_bytes = b"".join(audio_parts)
    if not audio_bytes:
        raise HTTPException(status_code=502, detail="TTS returned no audio")

    return Response(
        content=audio_bytes,
        media_type="audio/mpeg",
        headers={"Content-Disposition": "inline; filename=speech.mp3"},
    )