h-rand commited on
Commit
eb7d099
·
verified ·
1 Parent(s): fa1d551

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -25
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, Response
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pocket_tts import TTSModel
4
  import scipy.io.wavfile
@@ -7,55 +7,96 @@ import torch
7
 
8
  app = FastAPI()
9
 
10
- # 1. Configuration CORS (Indispensable pour que votre site puisse appeler l'API)
11
  app.add_middleware(
12
  CORSMiddleware,
13
- allow_origins=["*"], # Permet à tous les sites d'accéder (pratique pour tester)
14
  allow_methods=["*"],
15
  allow_headers=["*"],
16
  )
17
 
18
- # 2. Chargement du Modèle au démarrage
19
- # On charge le modèle une seule fois ici pour ne pas attendre à chaque requête
20
- print("⏳ Chargement du modèle Pocket TTS (kyutai/pocket-tts)...")
21
  try:
22
- # CORRECTION ICI : On spécifie explicitement le repo HF
23
- tts_model = TTSModel.load_model("kyutai/pocket-tts")
24
-
25
- # On pré-charge la voix par défaut (Alba - Casual) pour gagner du temps
26
- default_voice_url = "hf://kyutai/tts-voices/alba-mackenna/casual.wav"
27
- default_voice_state = tts_model.get_state_for_audio_prompt(default_voice_url)
28
-
29
- print("✅ Modèle et Voix chargés avec succès !")
30
  except Exception as e:
31
- print(f"❌ ERREUR CRITIQUE au chargement : {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  @app.post("/tts")
34
  async def generate_speech(data: dict):
35
- # Récupération du texte envoyé par votre site
36
  text = data.get("text", "")
37
-
 
38
  if not text:
39
- return Response(content="Erreur: Texte vide", status_code=400)
40
 
41
- print(f"🗣️ Génération pour : {text[:30]}...")
42
 
43
  try:
44
- # Génération de l'audio
45
- audio_tensor = tts_model.generate_audio(default_voice_state, text)
 
 
 
46
 
47
- # Conversion du Tensor Torch -> Fichier WAV en mémoire (RAM)
48
  buffer = io.BytesIO()
49
  scipy.io.wavfile.write(buffer, tts_model.sample_rate, audio_tensor.numpy())
50
  buffer.seek(0)
51
 
52
- # Envoi du fichier audio binaire
53
  return Response(content=buffer.read(), media_type="audio/wav")
54
 
55
  except Exception as e:
56
- print(f"Erreur génération : {e}")
57
  return Response(content=str(e), status_code=500)
58
 
59
  @app.get("/")
60
  def home():
61
- return {"status": "Pocket TTS API is running 🚀"}
 
 
1
+ from fastapi import FastAPI, Response, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pocket_tts import TTSModel
4
  import scipy.io.wavfile
 
7
 
8
  app = FastAPI()
9
 
10
+ # Configuration CORS pour que votre site puisse appeler l'API
11
  app.add_middleware(
12
  CORSMiddleware,
13
+ allow_origins=["*"],
14
  allow_methods=["*"],
15
  allow_headers=["*"],
16
  )
17
 
18
+ # --- CHARGEMENT ---
19
+ print("⏳ Chargement de Pocket TTS...")
 
20
  try:
21
+ # Chargement du modèle sans arguments (utilise le défaut de la lib)
22
+ tts_model = TTSModel.load_model()
23
+ print("✅ Modèle chargé !")
 
 
 
 
 
24
  except Exception as e:
25
+ print(f"❌ ERREUR CRITIQUE CHARGEMENT MODÈLE: {e}")
26
+ # On ne quitte pas pour laisser le serveur démarrer et voir les logs,
27
+ # mais ça ne marchera pas sans modèle.
28
+
29
+ # Dictionnaire des voix officielles
30
+ # Cela permet d'envoyer juste "voice": "marius" depuis le Javascript
31
+ VOICES = {
32
+ "alba": "hf://kyutai/tts-voices/alba-mackenna/casual.wav",
33
+ "marius": "hf://kyutai/tts-voices/marius-kuntze/casual.wav",
34
+ "javert": "hf://kyutai/tts-voices/javert-mccall/casual.wav",
35
+ "jean": "hf://kyutai/tts-voices/jean-valjean/casual.wav",
36
+ "fantine": "hf://kyutai/tts-voices/fantine-becker/casual.wav",
37
+ "cosette": "hf://kyutai/tts-voices/cosette-moore/casual.wav",
38
+ "eponine": "hf://kyutai/tts-voices/eponine-frazier/casual.wav",
39
+ "azelma": "hf://kyutai/tts-voices/azelma-frazier/casual.wav"
40
+ }
41
+
42
+ # Cache pour garder les voix en mémoire (comme recommandé dans la doc)
43
+ loaded_voice_states = {}
44
+
45
+ def get_voice_state(voice_name="alba"):
46
+ # Si on a déjà chargé cette voix, on la renvoie direct (rapide)
47
+ if voice_name in loaded_voice_states:
48
+ return loaded_voice_states[voice_name]
49
+
50
+ # Sinon on la charge (lent la première fois)
51
+ print(f"📥 Chargement de la voix : {voice_name}")
52
+ full_path = VOICES.get(voice_name, VOICES["alba"])
53
+ try:
54
+ state = tts_model.get_state_for_audio_prompt(full_path)
55
+ loaded_voice_states[voice_name] = state
56
+ return state
57
+ except Exception as e:
58
+ print(f"Erreur chargement voix {voice_name}: {e}")
59
+ # Fallback sur Alba si erreur
60
+ if voice_name != "alba":
61
+ return get_voice_state("alba")
62
+ raise e
63
+
64
+ # Pré-chargement de la voix par défaut au démarrage
65
+ try:
66
+ get_voice_state("alba")
67
+ print("✅ Voix par défaut (Alba) prête !")
68
+ except:
69
+ pass
70
 
71
  @app.post("/tts")
72
  async def generate_speech(data: dict):
 
73
  text = data.get("text", "")
74
+ voice_name = data.get("voice", "alba") # Par défaut 'alba'
75
+
76
  if not text:
77
+ raise HTTPException(status_code=400, detail="Texte vide")
78
 
79
+ print(f"🗣️ Génération ({voice_name}): {text[:30]}...")
80
 
81
  try:
82
+ # Récupération de l'état de la voix
83
+ voice_state = get_voice_state(voice_name)
84
+
85
+ # Génération
86
+ audio_tensor = tts_model.generate_audio(voice_state, text)
87
 
88
+ # Conversion WAV
89
  buffer = io.BytesIO()
90
  scipy.io.wavfile.write(buffer, tts_model.sample_rate, audio_tensor.numpy())
91
  buffer.seek(0)
92
 
 
93
  return Response(content=buffer.read(), media_type="audio/wav")
94
 
95
  except Exception as e:
96
+ print(f"Erreur génération : {e}")
97
  return Response(content=str(e), status_code=500)
98
 
99
  @app.get("/")
100
  def home():
101
+ return {"status": "Pocket TTS API Ready", "available_voices": list(VOICES.keys())}
102
+