minhpn's picture
fix: fix audio whisper
3ad9877
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, File, Form, Request, UploadFile
from groq import Groq, RateLimitError
from src.api.utils import (
check_user_rate_limit,
get_groq_keys,
track_key_usage,
user_request_tracker,
verify_user_license,
)
from src.common.logger import logger
from src.common.utils import response_error, response_success
# Router for this module; the /asr endpoint below is registered on it.
router = APIRouter()
def _build_segments_from_response(response) -> List[Dict[str, Any]]:
    """Build the segment list from a transcription response.

    The response is converted to a dict with ``response.model_dump()`` and its
    top-level ``words`` are distributed over its ``segments``. A word is
    attached to the first segment whose time span contains it and is then no
    longer considered for later segments.

    Args:
        response: Object exposing ``model_dump()`` that returns a dict with
            optional ``segments`` and ``words`` entries (lists of dicts).

    Returns:
        One dict per segment, in input order, carrying the segment fields
        plus a ``words`` list of ``{"word", "start", "end", "probability"}``
        dicts. Returns an empty list when there are no segments.
    """
    payload = response.model_dump()
    # ``.get(key, [])`` still returns None when the key is present but null,
    # so normalise both collections explicitly before iterating.
    segments = payload.get("segments") or []
    remaining_words = list(payload.get("words") or [])

    # Slack added to the segment end when testing containment
    # (presumably to absorb timestamp rounding -- TODO confirm).
    end_tolerance = 0.01

    result = []
    for segment in segments:
        seg_start = segment.get("start")
        seg_end = segment.get("end")
        seg_has_span = seg_start is not None and seg_end is not None

        # Partition the still-unassigned words into those contained in this
        # segment and those left for later segments, preserving order.
        segment_words = []
        unassigned = []
        for word in remaining_words:
            word_start = word.get("start")
            word_end = word.get("end")
            # A word belongs to the segment when it lies inside the segment's
            # span; entries with missing timestamps are never matched.
            if (
                seg_has_span
                and word_start is not None
                and word_end is not None
                and word_start >= seg_start
                and word_end <= seg_end + end_tolerance
            ):
                segment_words.append(
                    {
                        "word": word.get("word", ""),
                        "start": word_start,
                        "end": word_end,
                        # The response carries no per-word probability here,
                        # so a constant 1 is reported.
                        "probability": 1,
                    }
                )
            else:
                unassigned.append(word)
        remaining_words = unassigned

        # Build the segment with its embedded words.
        result.append(
            {
                "id": segment.get("id"),
                "seek": segment.get("seek"),
                "start": seg_start,
                "end": seg_end,
                "text": segment.get("text"),
                "tokens": segment.get("tokens"),
                "avg_logprob": segment.get("avg_logprob"),
                "compression_ratio": segment.get("compression_ratio"),
                "no_speech_prob": segment.get("no_speech_prob"),
                "temperature": segment.get("temperature"),
                "words": segment_words,
            }
        )
    return result
@router.post("/asr")
async def transcribe_audio(
    email: str = Form(...),
    license_key: str = Form(...),
    audio_file: UploadFile = File(...),
    language: Optional[str] = None,
    request: Request = None,
):
    """Transcribe an uploaded audio file with Groq Whisper.

    The caller's license is verified and the per-user rate limit is checked
    first. The audio is then sent to ``whisper-large-v3`` using each
    configured Groq API key in turn until one of them succeeds.

    Args:
        email: User e-mail (form field) used for the license and rate-limit
            checks.
        license_key: License key (form field) verified together with ``email``.
        audio_file: Uploaded audio to transcribe.
        language: Optional language forwarded to Groq when set.
            NOTE(review): declared without ``Form``, so FastAPI reads it as a
            query parameter rather than a form field -- confirm this is
            intended.
        request: Incoming request; not used by this handler.

    Returns:
        ``response_success`` with ``text``, ``language``, ``duration`` and
        ``segments`` on success; otherwise ``response_error`` with status 403
        (license check failed), 429 (user rate limit), 503 (no API keys
        configured) or 500 (every API key failed).
    """
    # Imported locally so this handler carries its own dependency.
    from fastapi.concurrency import run_in_threadpool

    _, error = verify_user_license(email, license_key)
    if error:
        return response_error(error, f"License verification failed: {error}", 403)

    if not check_user_rate_limit(email):
        remaining = user_request_tracker[email]["reset_at"] - datetime.now(timezone.utc)
        # Clamp so a reset time that has just passed never reports a
        # negative wait to the client.
        wait_seconds = max(0, int(remaining.total_seconds()))
        return response_error(
            "USER_RATE_LIMIT",
            f"Rate limit exceeded. Try again in {wait_seconds} seconds",
            429,
        )

    api_keys = get_groq_keys()
    if not api_keys:
        return response_error("NO_API_KEYS", "No Groq API keys configured", 503)

    audio_content = await audio_file.read()

    # The request parameters are the same for every key, so build them once.
    params = {
        "file": (audio_file.filename, audio_content),
        "model": "whisper-large-v3",
        "temperature": 0,
        "response_format": "verbose_json",
        "timestamp_granularities": ["word", "segment"],
    }
    if language:
        params["language"] = language

    last_error = None
    for api_key in api_keys:
        try:
            client = Groq(api_key=api_key)
            # The Groq client call is synchronous; run it in the thread pool
            # so the event loop is not blocked while the upstream request is
            # in flight.
            response = await run_in_threadpool(
                client.audio.transcriptions.create, **params
            )
            track_key_usage(api_key)

            # Build segments with embedded words.
            segments = _build_segments_from_response(response)
            return response_success(
                {
                    "text": response.text,
                    "language": getattr(response, "language", language),
                    "duration": getattr(response, "duration", None),
                    "segments": segments,
                }
            )
        except RateLimitError as e:
            # This key is rate-limited; fall through to the next one.
            logger.error("Key rate-limited: switching...")
            logger.error(str(e))
            last_error = str(e)
            continue
        except Exception as e:
            # Best effort: any other failure also moves on to the next key.
            logger.error(str(e))
            last_error = str(e)
            continue

    return response_error("ASR_FAILED", f"All API keys failed: {last_error}", 500)