h-rand commited on
Commit
11e203c
·
verified ·
1 Parent(s): 92d4711

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Response, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from transformers import VitsModel, AutoTokenizer
4
+ import torch
5
+ import scipy.io.wavfile
6
+ import io
7
+ import os
8
+
9
+ app = FastAPI()
10
+
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"],
14
+ allow_methods=["*"],
15
+ allow_headers=["*"],
16
+ )
17
+
18
+ # --- CONFIGURATION ---
19
+ MODEL_ID = "facebook/mms-tts-fra"
20
+
21
+ model = None
22
+ tokenizer = None
23
+
24
+ print("⏳ Démarrage du serveur MMS Français...")
25
+
26
+ def load_model():
27
+ global model, tokenizer
28
+ try:
29
+ if model is not None: return True
30
+
31
+ print(f"📥 Chargement du modèle {MODEL_ID}...")
32
+ # CPU est suffisant pour MMS (très léger)
33
+ model = VitsModel.from_pretrained(MODEL_ID)
34
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
35
+
36
+ print("✅ Modèle MMS Français chargé !")
37
+ return True
38
+ except Exception as e:
39
+ print(f"❌ Erreur critique chargement : {e}")
40
+ return False
41
+
42
+ load_model()
43
+
44
+ @app.post("/tts")
45
+ async def generate_speech(data: dict):
46
+ if model is None:
47
+ if not load_model():
48
+ raise HTTPException(status_code=500, detail="Modèle indisponible")
49
+
50
+ text = data.get("text", "")
51
+ if not text:
52
+ raise HTTPException(status_code=400, detail="Texte vide")
53
+
54
+ try:
55
+ # 1. Tokenization
56
+ inputs = tokenizer(text, return_tensors="pt")
57
+
58
+ # 2. Inférence (Sans gradient = moins de RAM)
59
+ with torch.no_grad():
60
+ output = model(**inputs).waveform
61
+
62
+ # 3. Conversion Audio
63
+ audio_array = output.float().numpy().squeeze()
64
+ sample_rate = model.config.sampling_rate
65
+
66
+ # 4. Écriture WAV
67
+ buffer = io.BytesIO()
68
+ scipy.io.wavfile.write(buffer, rate=sample_rate, data=audio_array)
69
+ buffer.seek(0)
70
+
71
+ return Response(content=buffer.read(), media_type="audio/wav")
72
+
73
+ except Exception as e:
74
+ print(f"❌ Erreur génération : {e}")
75
+ return Response(content=str(e), status_code=500)
76
+
77
+ @app.get("/")
78
+ def home():
79
+ return {"status": "MMS French Ready 🇫🇷"}