Markuspierre committed on
Commit
d9000ea
·
verified ·
1 Parent(s): 8aede60

Create asr-tts_service.py

Browse files
Files changed (1) hide show
  1. asr-tts_service.py +126 -0
asr-tts_service.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from flask import Flask, request, jsonify
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
4
+ import torch
5
+ import soundfile as sf
6
+ import numpy as np
7
+ import io
8
+ import re
9
+
10
+ # Parler-TTS imports
11
+ from parler_tts import ParlerTTSForConditionalGeneration
12
+
13
# Flask app instance.
app = Flask(__name__)

# Wolof automatic speech recognition pipeline.
asr = pipeline("automatic-speech-recognition", model="bilalfaye/wav2vec2-large-mms-1b-wolof")

# Translation Wolof <-> French (distilled NLLB-200 checkpoint).
model_name = "bilalfaye/nllb-200-distilled-600M-wo-fr-en"
tokenizer = AutoTokenizer.from_pretrained(model_name)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# French -> Wolof uses the very same checkpoint. Reuse the already-loaded
# weights instead of loading a second identical copy (the original loaded
# the 600M-parameter model twice). The slow tokenizer is kept separate
# because it was created with use_fast=False.
fr_trans_model_name = model_name
tokenizer_fr_trans = AutoTokenizer.from_pretrained(fr_trans_model_name, use_fast=False)
fr_trans_model = translation_model

# Parler-TTS Wolof voice model (GPU when available).
device = "cuda" if torch.cuda.is_available() else "cpu"
tts_model = ParlerTTSForConditionalGeneration.from_pretrained("CONCREE/Adia_TTS").to(device)
tts_tokenizer = AutoTokenizer.from_pretrained("CONCREE/Adia_TTS")
# Free-text style prompt consumed by Parler-TTS to shape the voice.
tts_description = "A professional, clear and composed voice, perfect for formal presentations"
33
+
34
# Helpers
def wolofToFrench(wolof_text):
    """Translate Wolof text into French via the shared NLLB model.

    Sets the tokenizer source language to Wolof, forces the French BOS
    token on generation, and returns the first decoded hypothesis.
    """
    tokenizer.src_lang = "wol_Latn"
    encoded = tokenizer(wolof_text, return_tensors="pt", padding=True)
    french_bos = tokenizer.convert_tokens_to_ids("fra_Latn")
    output_ids = translation_model.generate(
        **encoded,
        forced_bos_token_id=french_bos,
        max_new_tokens=200,
    )
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
41
+
42
def frenchToWolof(fr_text):
    """Translate French text into Wolof via the NLLB model.

    Mirror of wolofToFrench with the language direction reversed:
    source language French, generation forced to start with Wolof.
    """
    tokenizer_fr_trans.src_lang = "fra_Latn"
    batch = tokenizer_fr_trans(fr_text, return_tensors="pt", padding=True)
    wolof_bos = tokenizer_fr_trans.convert_tokens_to_ids("wol_Latn")
    generated = fr_trans_model.generate(
        **batch,
        forced_bos_token_id=wolof_bos,
        max_new_tokens=200,
    )
    return tokenizer_fr_trans.batch_decode(generated, skip_special_tokens=True)[0]
48
+
49
def convert_digits_in_text(text):
    """Spell out every ASCII digit in *text* as its French word.

    Digits are replaced independently ("12" -> "undeux"), not parsed as
    whole numbers — this matches the original per-character behaviour
    and is only meant to keep the TTS model from stumbling on digits.
    """
    # One C-level translate pass instead of ten chained str.replace() calls.
    digit_words = str.maketrans({
        "0": "zéro", "1": "un", "2": "deux", "3": "trois", "4": "quatre",
        "5": "cinq", "6": "six", "7": "sept", "8": "huit", "9": "neuf",
    })
    return text.translate(digit_words)
55
+
56
def split_text(text, max_chars=170):
    """Split *text* into chunks of roughly *max_chars* characters.

    Splits on sentence boundaries ('.', '!', '?' followed by spaces) and
    greedily packs consecutive sentences into one chunk. A single
    sentence longer than *max_chars* becomes its own (oversized) chunk
    rather than being cut mid-sentence.

    Returns a list of non-empty, stripped chunk strings (empty list for
    empty input).
    """
    sentences = re.split(r'(?<=[.!?]) +', text)
    chunks, current = [], ""
    for sentence in sentences:
        if len(current) + len(sentence) < max_chars:
            current = f"{current} {sentence}" if current else sentence
        else:
            # Bug fix: when the very first sentence already exceeds
            # max_chars, `current` is still empty — the original appended
            # an empty-string chunk here. Only flush non-empty content.
            if current.strip():
                chunks.append(current.strip())
            current = sentence
    if current.strip():
        chunks.append(current.strip())
    return chunks
68
+
69
+
70
# Routes
@app.route("/", methods=["GET"])
def racine():
    """Root health-check endpoint: confirms the service is up."""
    return "Flask Asr-Tts Service is running!"
74
+
75
+ @app.route("/accueil", methods=["GET"])
76
+ def accueil():
77
+ return "Flask Asr-Tts accueil endpoint is working!"
78
+
79
@app.route("/transcribe", methods=["POST"])
def transcribe():
    """Transcribe an uploaded Wolof audio file and return its French translation.

    Expects a multipart upload under the form key "file". Returns the
    French translation as plain text (HTTP 200), or a JSON error with
    HTTP 400 when no file was provided. The "Bonjour Adama" fallback for
    an empty translation is kept from the original implementation.
    """
    if "file" not in request.files:
        return jsonify({"error": "Aucun fichier audio trouvé"}), 400
    audio_file = request.files["file"]
    data, samplerate = sf.read(audio_file)
    data = np.asarray(data)
    # Down-mix multi-channel audio to mono before feeding the pipeline.
    if data.ndim > 1:
        data = data.mean(axis=1)
    # Bug fix: forward the file's real sampling rate so the pipeline can
    # resample; the original passed the bare array and the rate from
    # sf.read was silently discarded.
    text = asr({"raw": data, "sampling_rate": samplerate})["text"]
    translated = wolofToFrench(text)
    return translated or "Bonjour Adama"
88
+
89
# TTS Route
@app.route("/tts", methods=["POST"])
def tts_route():
    """Synthesize Wolof speech from a French text payload.

    Expects JSON {"text": "<french text>"}. The text is translated to
    Wolof, digits are spelled out, the result is split into TTS-sized
    chunks, audio is generated per chunk and concatenated, and the WAV
    is returned as {"audio": "data:audio/wav;base64,..."}.
    Returns HTTP 400 when the payload is missing or the text yields no
    speakable chunks.
    """
    payload = request.get_json()
    if not payload or "text" not in payload:
        return jsonify({"error": "Champ 'text' manquant"}), 400

    text_fr = payload["text"]
    text_wolof = frenchToWolof(text_fr)
    text_wolof = convert_digits_in_text(text_wolof)
    chunks = split_text(text_wolof)

    # Bug fix: empty/whitespace-only text produces no chunks, and
    # np.concatenate([]) below would raise (HTTP 500). Fail fast instead.
    if not chunks:
        return jsonify({"error": "Texte vide après traitement"}), 400

    print("TTS chunks:", chunks)

    audio_segments = []

    # The style description is chunk-invariant: tokenize it once.
    tts_input_ids = tts_tokenizer(tts_description, return_tensors="pt").input_ids.to(device)

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        for chunk in chunks:
            prompt_ids = tts_tokenizer(chunk, return_tensors="pt").input_ids.to(device)
            audio_tensor = tts_model.generate(input_ids=tts_input_ids, prompt_input_ids=prompt_ids)
            audio_segments.append(audio_tensor.cpu().numpy().squeeze())

    final_audio = np.concatenate(audio_segments)

    buffer = io.BytesIO()
    sf.write(buffer, final_audio, tts_model.config.sampling_rate, format="WAV")
    buffer.seek(0)

    audio_b64 = f"data:audio/wav;base64,{base64.b64encode(buffer.read()).decode('utf-8')}"

    return jsonify({"audio": audio_b64})
122
+
123
# Run Flask
# use_reloader=False prevents the dev server from re-importing the module
# (and re-loading all the heavy models) a second time.
# NOTE(review): host 0.0.0.0 / port 7860 matches the Hugging Face Spaces
# convention — confirm the deployment target.
if __name__ == "__main__":
    app.run(debug=False, host='0.0.0.0', port=7860, use_reloader=False)