File size: 2,889 Bytes
e2da8d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
from fastapi import FastAPI, UploadFile
from fastapi.responses import JSONResponse
from transformers import (
    VitsModel, AutoTokenizer,
    Wav2Vec2ForCTC, AutoProcessor,
    NllbTokenizer, AutoModelForSeq2SeqLM
)
import torch, scipy, base64, numpy as np
import soundfile as sf
from io import BytesIO

app = FastAPI()

# ─── Model loading ───
# Every checkpoint is loaded once, at import time, and then shared by the
# request handlers below through these module-level names.
_MMS_LANG = "gud"  # language code handed to the MMS tokenizer and adapter
_TTS_CHECKPOINT = "facebook/mms-tts-" + _MMS_LANG
_ASR_CHECKPOINT = "facebook/mms-1b-all"
_NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"

# Text-to-speech: VITS model and its tokenizer.
print("Chargement TTS Dida Yocoboué...")
tts_model = VitsModel.from_pretrained(_TTS_CHECKPOINT)
tts_tokenizer = AutoTokenizer.from_pretrained(_TTS_CHECKPOINT)

# Speech recognition: load the shared checkpoint, then switch both the
# tokenizer vocabulary and the model adapter to the target language.
print("Chargement ASR...")
asr_processor = AutoProcessor.from_pretrained(_ASR_CHECKPOINT)
asr_model = Wav2Vec2ForCTC.from_pretrained(_ASR_CHECKPOINT)
asr_processor.tokenizer.set_target_lang(_MMS_LANG)
asr_model.load_adapter(_MMS_LANG)

# Translation: NLLB tokenizer and seq2seq model.
print("Chargement Traduction NLLB...")
nllb_tokenizer = NllbTokenizer.from_pretrained(_NLLB_CHECKPOINT)
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(_NLLB_CHECKPOINT)

# ─── TTS ───
@app.post("/api/tts")
async def text_to_speech(payload: dict):
    """Synthesize speech for the given text and return it as a base64 WAV.

    Args:
        payload: JSON body; must contain a non-empty string under ``"text"``.

    Returns:
        On success, a dict with ``"audio_base64"`` (the base64-encoded WAV
        file) and ``"sample_rate"`` (taken from the TTS model config).
        If ``"text"`` is missing, not a string, or blank, a 400
        ``JSONResponse`` with a ``"detail"`` message.

    NOTE(review): model inference runs synchronously inside this coroutine,
    so it occupies the event loop for the duration of the synthesis.
    """
    # Import the submodule explicitly instead of relying on a bare
    # ``import scipy`` to expose ``scipy.io.wavfile`` as an attribute.
    from scipy.io import wavfile

    text = payload.get("text")
    if not isinstance(text, str) or not text.strip():
        # Previously a missing key raised KeyError and surfaced as a 500.
        return JSONResponse(
            status_code=400,
            content={"detail": "Field 'text' must be a non-empty string."},
        )

    inputs = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform.squeeze()

    sample_rate = tts_model.config.sampling_rate
    buffer = BytesIO()
    wavfile.write(buffer, rate=sample_rate, data=waveform.numpy())
    audio_b64 = base64.b64encode(buffer.getvalue()).decode()
    return {"audio_base64": audio_b64, "sample_rate": sample_rate}

# ─── ASR ───
@app.post("/api/asr")
async def speech_to_text(file: UploadFile):
    """Transcribe an uploaded audio file with the ASR model.

    The upload is decoded with ``soundfile``, downmixed to mono when it has
    several channels, resampled to 16 kHz when needed, and then passed
    through the ASR processor and model.

    Args:
        file: The uploaded audio file.

    Returns:
        On success, a dict with the decoded text under ``"transcription"``.
        If the upload cannot be decoded or contains no samples, a 400
        ``JSONResponse`` with a ``"detail"`` message.

    NOTE(review): model inference runs synchronously inside this coroutine,
    so it occupies the event loop for the duration of the transcription.
    """
    target_sr = 16000

    audio_bytes = await file.read()
    try:
        audio_array, sr = sf.read(BytesIO(audio_bytes))
    except RuntimeError:
        # NOTE(review): soundfile's decode errors are expected to derive from
        # RuntimeError — confirm against the installed soundfile version.
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded file is not a readable audio file."},
        )

    # Multichannel files come back as (frames, channels). Average the
    # channels so that resampling and the processor both see a 1-D signal;
    # otherwise resampling would run along the channel axis.
    if audio_array.ndim > 1:
        audio_array = np.mean(audio_array, axis=1)

    if audio_array.size == 0:
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded audio contains no samples."},
        )

    if sr != target_sr:
        # Imported lazily so librosa is only required for non-16 kHz input.
        import librosa
        audio_array = librosa.resample(
            audio_array, orig_sr=sr, target_sr=target_sr
        )

    inputs = asr_processor(
        audio_array, sampling_rate=target_sr, return_tensors="pt"
    )
    with torch.no_grad():
        logits = asr_model(**inputs).logits
    # Greedy decoding: pick the highest-scoring token at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.batch_decode(predicted_ids)[0]
    return {"transcription": transcription}

# ─── Translation ───
@app.post("/api/translate")
async def translate(payload: dict):
    """Translate text between two NLLB language codes.

    Args:
        payload: JSON body with a non-empty string under ``"text"`` and,
            optionally, ``"source_lang"`` and ``"target_lang"`` holding
            NLLB language codes. Both default to ``"fra_Latn"``.

    Returns:
        On success, a dict with the translated text under ``"translation"``.
        If ``"text"`` is missing or blank, or a language code is not known
        to the tokenizer, a 400 ``JSONResponse`` with a ``"detail"`` message.

    NOTE(review): source and target share the same default, so a request
    that omits both asks for French-to-French — confirm the intended
    default target language.
    """
    text = payload.get("text")
    if not isinstance(text, str) or not text.strip():
        # Previously a missing key raised KeyError and surfaced as a 500.
        return JSONResponse(
            status_code=400,
            content={"detail": "Field 'text' must be a non-empty string."},
        )

    source_lang = payload.get("source_lang", "fra_Latn")
    target_lang = payload.get("target_lang", "fra_Latn")

    # A language code is a single token in the tokenizer vocabulary; a code
    # that maps to the unknown-token id is not supported.
    unknown_id = nllb_tokenizer.unk_token_id
    for field, code in (
        ("source_lang", source_lang),
        ("target_lang", target_lang),
    ):
        if (
            not isinstance(code, str)
            or nllb_tokenizer.convert_tokens_to_ids(code) == unknown_id
        ):
            return JSONResponse(
                status_code=400,
                content={"detail": f"Unknown language code for '{field}'."},
            )

    # The source language is tokenizer state, not a per-call argument, so
    # set it on the tokenizer before encoding. Nothing is awaited between
    # this assignment and the encoding call, so another request on the same
    # event loop cannot change it in between.
    nllb_tokenizer.src_lang = source_lang
    inputs = nllb_tokenizer(text, return_tensors="pt")

    # Look the target code up through the vocabulary rather than the
    # ``lang_code_to_id`` mapping, which newer transformers releases no
    # longer provide.
    translated = nllb_model.generate(
        **inputs,
        forced_bos_token_id=nllb_tokenizer.convert_tokens_to_ids(target_lang),
    )
    result = nllb_tokenizer.decode(translated[0], skip_special_tokens=True)
    return {"translation": result}

# ─── Health check ───
@app.get("/")
async def root():
    """Health check: return a fixed status payload."""
    health = dict(status="ok", message="API Dida opérationnelle !")
    return health