pluviouse commited on
Commit
dc3a626
·
verified ·
1 Parent(s): fdbb4ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -25,6 +25,13 @@ JAPANESE = {
25
  "onnx_dir": "./ONNX_net/G_jp/"
26
  }
27
 
 
 
 
 
 
 
 
28
  models_tts = []
29
  models_info = [
30
  TRILINGUAL,
@@ -62,10 +69,13 @@ def get_text(text, hps, is_symbol):
62
  text_norm = intersperse(text_norm, 0)
63
  return LongTensor(text_norm)
64
 
65
- def tts_process(text, speaker, speed, model_data, is_symbol):
66
  model = model_data["model"]
67
  hps = model_data["hps"]
68
  speaker_id = model_data["speaker_ids"][speaker]
 
 
 
69
  stn_tst = get_text(text, hps, is_symbol)
70
  with no_grad():
71
  x_tst = stn_tst.unsqueeze(0)
@@ -110,16 +120,24 @@ def generate(model):
110
  speed = float(data.get("speed", 1.0))
111
  is_symbol = data.get("is_symbol", False)
112
  speaker_id = data.get("speaker_id")
113
-
114
  if not text:
115
  return jsonify({"error": "Missing parameter 'text'"}), 400
116
 
117
  model_data = get_model_data(model)
118
  if not model_data:
119
  return jsonify({"error": "Model not found"}), 404
120
-
121
  speaker_ids = { str(id): speaker for speaker, id in model_data["speaker_ids"].items() }
122
 
 
 
 
 
 
 
 
 
123
  if not speaker:
124
  if speaker_id is not None:
125
  speaker = speaker_ids.get(str(speaker_id), None)
@@ -132,7 +150,7 @@ def generate(model):
132
  return jsonify({"error": f"Speaker `{speaker}` not found"}), 404
133
 
134
  try:
135
- audio, sampling_rate = tts_process(text, speaker, speed, model_data, is_symbol)
136
  temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
137
  sf.write(temp_wav.name, audio, sampling_rate, format="wav")
138
  temp_wav.close()
 
25
  "onnx_dir": "./ONNX_net/G_jp/"
26
  }
27
 
28
+ language_marks = {
29
+ "JA": "[JA]",
30
+ "ZH": "[ZH]",
31
+ "ENG": "[EN]",
32
+ "MIX": "",
33
+ }
34
+
35
  models_tts = []
36
  models_info = [
37
  TRILINGUAL,
 
69
  text_norm = intersperse(text_norm, 0)
70
  return LongTensor(text_norm)
71
 
72
+ def tts_process(text, speaker, speed, model_data, is_symbol, language = None):
73
  model = model_data["model"]
74
  hps = model_data["hps"]
75
  speaker_id = model_data["speaker_ids"][speaker]
76
+ if language is not None:
77
+ text = language_marks[language] + text + language_marks[language]
78
+
79
  stn_tst = get_text(text, hps, is_symbol)
80
  with no_grad():
81
  x_tst = stn_tst.unsqueeze(0)
 
120
  speed = float(data.get("speed", 1.0))
121
  is_symbol = data.get("is_symbol", False)
122
  speaker_id = data.get("speaker_id")
123
+ language = data.get("lang")
124
  if not text:
125
  return jsonify({"error": "Missing parameter 'text'"}), 400
126
 
127
  model_data = get_model_data(model)
128
  if not model_data:
129
  return jsonify({"error": "Model not found"}), 404
130
+
131
  speaker_ids = { str(id): speaker for speaker, id in model_data["speaker_ids"].items() }
132
 
133
+ if language is not None:
134
+ is_ja = model.lower() == "japanese"
135
+ if is_ja:
136
+ language = None
137
+ elif not is_ja and language_marks.get(language) is None:
138
+ return jsonify({ "error": "language not available", "language": language_marks.keys() })
139
+
140
+
141
  if not speaker:
142
  if speaker_id is not None:
143
  speaker = speaker_ids.get(str(speaker_id), None)
 
150
  return jsonify({"error": f"Speaker `{speaker}` not found"}), 404
151
 
152
  try:
153
+ audio, sampling_rate = tts_process(text, speaker, speed, model_data, is_symbol, language)
154
  temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
155
  sf.write(temp_wav.name, audio, sampling_rate, format="wav")
156
  temp_wav.close()