Update app.py
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security
|
|
| 3 |
from fastapi.security.api_key import APIKeyHeader, APIKey
|
| 4 |
from fastapi.responses import JSONResponse
|
| 5 |
from pydantic import BaseModel
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import io
|
| 8 |
import soundfile as sf
|
|
@@ -21,10 +22,9 @@ import time
|
|
| 21 |
import tempfile
|
| 22 |
|
| 23 |
# Import functions from other modules
|
| 24 |
-
from asr import transcribe, ASR_LANGUAGES
|
| 25 |
from tts import synthesize, TTS_LANGUAGES
|
| 26 |
from lid import identify
|
| 27 |
-
from asr import ASR_SAMPLING_RATE
|
| 28 |
|
| 29 |
# Configure logging
|
| 30 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -60,15 +60,18 @@ s3_client = boto3.client(
|
|
| 60 |
# Define request models
|
| 61 |
class AudioRequest(BaseModel):
|
| 62 |
audio: str # Base64 encoded audio or video data
|
| 63 |
-
language: str
|
| 64 |
|
| 65 |
class TTSRequest(BaseModel):
|
| 66 |
text: str
|
| 67 |
-
language: str
|
| 68 |
-
speed: float
|
| 69 |
|
| 70 |
class LanguageRequest(BaseModel):
|
| 71 |
-
language: str
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
async def get_api_key(api_key_header: str = Security(api_key_header)):
|
| 74 |
if api_key_header == API_KEY:
|
|
@@ -140,7 +143,13 @@ async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_
|
|
| 140 |
if sample_rate != ASR_SAMPLING_RATE:
|
| 141 |
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
|
| 142 |
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
processing_time = time.time() - start_time
|
| 145 |
return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
|
| 146 |
except Exception as e:
|
|
@@ -156,7 +165,7 @@ async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_
|
|
| 156 |
)
|
| 157 |
|
| 158 |
@app.post("/transcribe_file")
|
| 159 |
-
async def transcribe_audio_file(file: UploadFile = File(...),
|
| 160 |
start_time = time.time()
|
| 161 |
try:
|
| 162 |
contents = await file.read()
|
|
@@ -169,7 +178,13 @@ async def transcribe_audio_file(file: UploadFile = File(...), language: str = ""
|
|
| 169 |
if sample_rate != ASR_SAMPLING_RATE:
|
| 170 |
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
|
| 171 |
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
processing_time = time.time() - start_time
|
| 174 |
return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
|
| 175 |
except Exception as e:
|
|
@@ -189,19 +204,23 @@ async def synthesize_speech(request: TTSRequest, api_key: APIKey = Depends(get_a
|
|
| 189 |
start_time = time.time()
|
| 190 |
logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
|
| 191 |
try:
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
# Input validation
|
| 196 |
if not request.text:
|
| 197 |
raise ValueError("Text cannot be empty")
|
| 198 |
if lang_code not in TTS_LANGUAGES:
|
| 199 |
-
raise ValueError(f"Unsupported language: {
|
| 200 |
if not 0.5 <= request.speed <= 2.0:
|
| 201 |
raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")
|
| 202 |
|
| 203 |
logger.info(f"Calling synthesize function with lang_code: {lang_code}")
|
| 204 |
-
result, filtered_text = synthesize(request.text,
|
| 205 |
logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")
|
| 206 |
|
| 207 |
if result is None:
|
|
@@ -279,8 +298,6 @@ async def synthesize_speech(request: TTSRequest, api_key: APIKey = Depends(get_a
|
|
| 279 |
status_code=500,
|
| 280 |
content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
|
| 281 |
)
|
| 282 |
-
finally:
|
| 283 |
-
logger.info("Synthesize request completed")
|
| 284 |
|
| 285 |
@app.post("/identify")
|
| 286 |
async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
|
|
@@ -328,22 +345,14 @@ async def identify_language_file(file: UploadFile = File(...), api_key: APIKey =
|
|
| 328 |
async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
|
| 329 |
start_time = time.time()
|
| 330 |
try:
|
| 331 |
-
if request.language
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
-
matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
| 335 |
-
processing_time = time.time() - start_time
|
| 336 |
-
return JSONResponse
|
| 337 |
-
matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
| 338 |
processing_time = time.time() - start_time
|
| 339 |
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
|
| 340 |
-
except ValueError as ve:
|
| 341 |
-
logger.error(f"ValueError in get_asr_languages: {str(ve)}", exc_info=True)
|
| 342 |
-
processing_time = time.time() - start_time
|
| 343 |
-
return JSONResponse(
|
| 344 |
-
status_code=400,
|
| 345 |
-
content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
|
| 346 |
-
)
|
| 347 |
except Exception as e:
|
| 348 |
logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
|
| 349 |
error_details = {
|
|
@@ -360,19 +369,14 @@ async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(
|
|
| 360 |
async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
|
| 361 |
start_time = time.time()
|
| 362 |
try:
|
| 363 |
-
if request.language
|
| 364 |
-
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
-
matching_languages = [lang for lang in TTS_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
| 367 |
processing_time = time.time() - start_time
|
| 368 |
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
|
| 369 |
-
except ValueError as ve:
|
| 370 |
-
logger.error(f"ValueError in get_tts_languages: {str(ve)}", exc_info=True)
|
| 371 |
-
processing_time = time.time() - start_time
|
| 372 |
-
return JSONResponse(
|
| 373 |
-
status_code=400,
|
| 374 |
-
content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
|
| 375 |
-
)
|
| 376 |
except Exception as e:
|
| 377 |
logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
|
| 378 |
error_details = {
|
|
|
|
| 3 |
from fastapi.security.api_key import APIKeyHeader, APIKey
|
| 4 |
from fastapi.responses import JSONResponse
|
| 5 |
from pydantic import BaseModel
|
| 6 |
+
from typing import Optional
|
| 7 |
import numpy as np
|
| 8 |
import io
|
| 9 |
import soundfile as sf
|
|
|
|
| 22 |
import tempfile
|
| 23 |
|
| 24 |
# Import functions from other modules
|
| 25 |
+
from asr import transcribe, ASR_LANGUAGES, ASR_SAMPLING_RATE
|
| 26 |
from tts import synthesize, TTS_LANGUAGES
|
| 27 |
from lid import identify
|
|
|
|
| 28 |
|
| 29 |
# Configure logging
|
| 30 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 60 |
# Define request models
|
| 61 |
class AudioRequest(BaseModel):
|
| 62 |
audio: str # Base64 encoded audio or video data
|
| 63 |
+
language: Optional[str] = None
|
| 64 |
|
| 65 |
class TTSRequest(BaseModel):
|
| 66 |
text: str
|
| 67 |
+
language: Optional[str] = None
|
| 68 |
+
speed: float = 1.0
|
| 69 |
|
| 70 |
class LanguageRequest(BaseModel):
|
| 71 |
+
language: Optional[str] = None
|
| 72 |
+
|
| 73 |
+
class TranscribeFileRequest(BaseModel):
|
| 74 |
+
language: Optional[str] = None
|
| 75 |
|
| 76 |
async def get_api_key(api_key_header: str = Security(api_key_header)):
|
| 77 |
if api_key_header == API_KEY:
|
|
|
|
| 143 |
if sample_rate != ASR_SAMPLING_RATE:
|
| 144 |
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
|
| 145 |
|
| 146 |
+
if request.language is None:
|
| 147 |
+
# If no language is provided, use language identification
|
| 148 |
+
identified_language = identify(audio_array)
|
| 149 |
+
result = transcribe(audio_array, identified_language)
|
| 150 |
+
else:
|
| 151 |
+
result = transcribe(audio_array, request.language)
|
| 152 |
+
|
| 153 |
processing_time = time.time() - start_time
|
| 154 |
return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
|
| 155 |
except Exception as e:
|
|
|
|
| 165 |
)
|
| 166 |
|
| 167 |
@app.post("/transcribe_file")
|
| 168 |
+
async def transcribe_audio_file(file: UploadFile = File(...), request: TranscribeFileRequest = Depends(), api_key: APIKey = Depends(get_api_key)):
|
| 169 |
start_time = time.time()
|
| 170 |
try:
|
| 171 |
contents = await file.read()
|
|
|
|
| 178 |
if sample_rate != ASR_SAMPLING_RATE:
|
| 179 |
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
|
| 180 |
|
| 181 |
+
if request.language is None:
|
| 182 |
+
# If no language is provided, use language identification
|
| 183 |
+
identified_language = identify(audio_array)
|
| 184 |
+
result = transcribe(audio_array, identified_language)
|
| 185 |
+
else:
|
| 186 |
+
result = transcribe(audio_array, request.language)
|
| 187 |
+
|
| 188 |
processing_time = time.time() - start_time
|
| 189 |
return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
|
| 190 |
except Exception as e:
|
|
|
|
| 204 |
start_time = time.time()
|
| 205 |
logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
|
| 206 |
try:
|
| 207 |
+
if request.language is None:
|
| 208 |
+
# If no language is provided, default to English
|
| 209 |
+
lang_code = "eng"
|
| 210 |
+
else:
|
| 211 |
+
# Extract the ISO code from the full language name
|
| 212 |
+
lang_code = request.language.split()[0].strip()
|
| 213 |
|
| 214 |
# Input validation
|
| 215 |
if not request.text:
|
| 216 |
raise ValueError("Text cannot be empty")
|
| 217 |
if lang_code not in TTS_LANGUAGES:
|
| 218 |
+
raise ValueError(f"Unsupported language: {lang_code}")
|
| 219 |
if not 0.5 <= request.speed <= 2.0:
|
| 220 |
raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")
|
| 221 |
|
| 222 |
logger.info(f"Calling synthesize function with lang_code: {lang_code}")
|
| 223 |
+
result, filtered_text = synthesize(request.text, lang_code, request.speed)
|
| 224 |
logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")
|
| 225 |
|
| 226 |
if result is None:
|
|
|
|
| 298 |
status_code=500,
|
| 299 |
content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
|
| 300 |
)
|
|
|
|
|
|
|
| 301 |
|
| 302 |
@app.post("/identify")
|
| 303 |
async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
|
|
|
|
| 345 |
async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
|
| 346 |
start_time = time.time()
|
| 347 |
try:
|
| 348 |
+
if request.language is None or request.language == "":
|
| 349 |
+
# If no language is provided, return all languages
|
| 350 |
+
matching_languages = ASR_LANGUAGES
|
| 351 |
+
else:
|
| 352 |
+
matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
| 353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
processing_time = time.time() - start_time
|
| 355 |
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
except Exception as e:
|
| 357 |
logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
|
| 358 |
error_details = {
|
|
|
|
| 369 |
async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
|
| 370 |
start_time = time.time()
|
| 371 |
try:
|
| 372 |
+
if request.language is None or request.language == "":
|
| 373 |
+
# If no language is provided, return all languages
|
| 374 |
+
matching_languages = TTS_LANGUAGES
|
| 375 |
+
else:
|
| 376 |
+
matching_languages = [lang for lang in TTS_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
| 377 |
|
|
|
|
| 378 |
processing_time = time.time() - start_time
|
| 379 |
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
except Exception as e:
|
| 381 |
logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
|
| 382 |
error_details = {
|