Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import httpx | |
| import asyncio | |
| import logging | |
| from typing import Optional, Dict, Any, Callable, List | |
| from dataclasses import dataclass | |
| logger = logging.getLogger(__name__) | |
@dataclass
class APIConfig:
    """Connection settings shared by the VoiceForge API clients.

    Declared as a dataclass so that callers can override individual
    settings by keyword, e.g. ``APIConfig(base_url="http://host:8001")``.

    Attributes:
        base_url: Root URL that request paths are resolved against.
        timeout: Default request timeout handed to the HTTP client.
        max_retries: Intended retry budget. NOTE(review): no retry logic is
            visible alongside this class -- confirm where it is consumed.
    """

    base_url: str = "http://localhost:8001"
    timeout: float = 30.0
    max_retries: int = 3
class APIClient:
    """Synchronous HTTP client for the VoiceForge API.

    The underlying ``httpx.Client`` is created lazily on first access of
    :attr:`client`. Every request method calls ``raise_for_status()`` on the
    response, so HTTP error statuses surface as ``httpx.HTTPStatusError``.

    The client may be used as a context manager to guarantee that the
    connection pool is closed::

        with APIClient() as api:
            api.health_check()
    """

    def __init__(self, config: Optional[APIConfig] = None):
        """Store the configuration; no network resources are opened yet.

        Args:
            config: Connection settings. A default ``APIConfig`` is used
                when omitted.
        """
        self.config = config or APIConfig()
        self._client: Optional[httpx.Client] = None
        self.token: Optional[str] = None

    def __enter__(self) -> "APIClient":
        """Return ``self`` so the client can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the underlying HTTP client when leaving a ``with`` block."""
        self.close()

    @property
    def client(self) -> httpx.Client:
        """Return the shared ``httpx.Client``, creating it on first use.

        Must be a property: every request method accesses it as an
        attribute (``self.client.get(...)``).
        """
        if self._client is None:
            self._client = httpx.Client(
                base_url=self.config.base_url,
                timeout=self.config.timeout,
            )
            # A token stored before the client existed still has to reach
            # the default headers; later changes go through set_token().
            if self.token:
                self._client.headers["Authorization"] = f"Bearer {self.token}"
        return self._client

    def close(self):
        """Close the HTTP client if one was created and forget it."""
        if self._client:
            self._client.close()
            self._client = None

    def set_token(self, token: str):
        """Remember ``token`` and send it as a Bearer credential from now on.

        Args:
            token: JWT access token.
        """
        self.token = token
        # Keep an already-created client in sync with the new token.
        if self._client:
            self._client.headers["Authorization"] = f"Bearer {token}"

    # Auth endpoints

    def login(self, email: str, password: str) -> Dict[str, Any]:
        """Log in and store the returned access token on this client.

        Args:
            email: Account e-mail, sent in the ``username`` form field.
            password: Account password.

        Returns:
            The decoded JSON response; its ``access_token`` entry is passed
            to :meth:`set_token`.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
            KeyError: If the response carries no ``access_token``.
        """
        # The login endpoint takes form data rather than a JSON body.
        response = self.client.post(
            "/api/v1/auth/login",
            data={"username": email, "password": password},
        )
        response.raise_for_status()
        data = response.json()
        self.set_token(data["access_token"])
        return data

    def register(
        self, email: str, password: str, name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Register a new user.

        Args:
            email: Account e-mail.
            password: Account password.
            name: Optional display name; omitted from the payload when falsy.

        Returns:
            The decoded JSON response.
        """
        payload = {"email": email, "password": password}
        if name:
            payload["name"] = name
        response = self.client.post("/api/v1/auth/register", json=payload)
        response.raise_for_status()
        return response.json()

    def get_me(self) -> Dict[str, Any]:
        """Return the decoded JSON of ``GET /api/v1/auth/me``."""
        response = self.client.get("/api/v1/auth/me")
        response.raise_for_status()
        return response.json()

    # Health endpoints

    def health_check(self) -> Dict[str, Any]:
        """Return the decoded JSON of ``GET /health``."""
        response = self.client.get("/health")
        response.raise_for_status()
        return response.json()

    # STT endpoints

    def get_languages(self) -> Dict[str, Any]:
        """Return the decoded JSON of ``GET /api/v1/stt/languages``."""
        response = self.client.get("/api/v1/stt/languages")
        response.raise_for_status()
        return response.json()

    def transcribe_file(
        self,
        file_content: bytes,
        filename: str,
        language: str = "en-US",
        enable_punctuation: bool = True,
        enable_timestamps: bool = True,
        enable_diarization: bool = False,
        speaker_count: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Upload an audio file and wait for its transcription.

        Args:
            file_content: Audio file bytes.
            filename: Original filename, sent with the multipart upload.
            language: Language code.
            enable_punctuation: Add automatic punctuation.
            enable_timestamps: Include word timestamps (sent as
                ``enable_word_timestamps``).
            enable_diarization: Identify speakers.
            speaker_count: Expected number of speakers; only sent when truthy.

        Returns:
            Transcription result dict.
        """
        files = {"file": (filename, file_content)}
        # Form fields are strings, so booleans are sent as "true"/"false".
        data = {
            "language": language,
            "enable_punctuation": str(enable_punctuation).lower(),
            "enable_word_timestamps": str(enable_timestamps).lower(),
            "enable_diarization": str(enable_diarization).lower(),
        }
        if speaker_count:
            data["speaker_count"] = str(speaker_count)
        response = self.client.post(
            "/api/v1/stt/upload",
            files=files,
            data=data,
            timeout=120.0,  # Longer timeout for transcription
        )
        response.raise_for_status()
        return response.json()

    def transcribe_file_async(
        self,
        file_content: bytes,
        filename: str,
        language: str = "en-US",
    ) -> Dict[str, Any]:
        """Upload an audio file to the ``async-upload`` endpoint.

        The HTTP call itself is blocking; "async" refers to the server-side
        endpoint. The response is presumably a task descriptor to poll with
        :meth:`get_task_status` -- confirm against the API.

        Args:
            file_content: Audio file bytes.
            filename: Original filename.
            language: Language code.

        Returns:
            The decoded JSON response.
        """
        files = {"file": (filename, file_content)}
        data = {"language": language}
        response = self.client.post(
            "/api/v1/stt/async-upload",
            files=files,
            data=data,
            timeout=30.0,
        )
        response.raise_for_status()
        return response.json()

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Return the decoded JSON of ``GET /api/v1/stt/tasks/{task_id}``."""
        response = self.client.get(f"/api/v1/stt/tasks/{task_id}")
        response.raise_for_status()
        return response.json()

    # TTS endpoints

    def get_voices(self, language: Optional[str] = None) -> Dict[str, Any]:
        """Get available TTS voices.

        Args:
            language: When given, it is appended to the voices path as an
                extra segment; otherwise the unfiltered path is requested.

        Returns:
            The decoded JSON response.
        """
        if language:
            url = f"/api/v1/tts/voices/{language}"
        else:
            url = "/api/v1/tts/voices"
        response = self.client.get(url)
        response.raise_for_status()
        return response.json()

    def synthesize_speech(
        self,
        text: str,
        language: str = "en-US",
        voice: Optional[str] = None,
        speaking_rate: float = 1.0,
        pitch: float = 0.0,
        audio_encoding: str = "MP3",
    ) -> Dict[str, Any]:
        """Synthesize text to speech.

        Args:
            text: Text to synthesize.
            language: Language code.
            voice: Voice name (optional; omitted from the payload when falsy).
            speaking_rate: Speaking rate (0.25 to 4.0).
            pitch: Voice pitch (-20 to 20).
            audio_encoding: Output format (MP3, LINEAR16, OGG_OPUS).

        Returns:
            Synthesis result with base64 audio.
        """
        payload: Dict[str, Any] = {
            "text": text,
            "language": language,
            "speaking_rate": speaking_rate,
            "pitch": pitch,
            "audio_encoding": audio_encoding,
        }
        if voice:
            payload["voice"] = voice
        response = self.client.post(
            "/api/v1/tts/synthesize",
            json=payload,
            timeout=60.0,
        )
        response.raise_for_status()
        return response.json()

    def preview_voice(self, voice: str, text: Optional[str] = None) -> bytes:
        """Get voice preview audio.

        Args:
            voice: Voice name.
            text: Optional preview text; omitted from the payload when falsy.

        Returns:
            The raw response body (audio bytes).
        """
        payload = {"voice": voice}
        if text:
            payload["text"] = text
        response = self.client.post("/api/v1/tts/preview", json=payload)
        response.raise_for_status()
        return response.content

    # Transcript endpoints

    def list_transcripts(self, skip: int = 0, limit: int = 100) -> List[Dict[str, Any]]:
        """List transcripts.

        Args:
            skip: Value of the ``skip`` query parameter.
            limit: Value of the ``limit`` query parameter.

        Returns:
            The decoded JSON response.
        """
        # Let httpx build and encode the query string.
        response = self.client.get(
            "/api/v1/transcripts",
            params={"skip": skip, "limit": limit},
        )
        response.raise_for_status()
        return response.json()

    def analyze_transcript(self, transcript_id: int) -> Dict[str, Any]:
        """Run NLP analysis on a transcript and return the decoded JSON."""
        response = self.client.post(f"/api/v1/transcripts/{transcript_id}/analyze")
        response.raise_for_status()
        return response.json()

    def export_transcript(self, transcript_id: int, format: str) -> bytes:
        """Export a transcript to a file.

        Args:
            transcript_id: Identifier of the transcript to export.
            format: Value of the ``format`` query parameter.

        Returns:
            The raw response body.
        """
        # Passing the value through ``params`` URL-encodes it, which plain
        # string interpolation into the path did not.
        response = self.client.get(
            f"/api/v1/transcripts/{transcript_id}/export",
            params={"format": format},
        )
        response.raise_for_status()
        return response.content
class AsyncAPIClient:
    """Asynchronous HTTP client for the VoiceForge API.

    The underlying ``httpx.AsyncClient`` is created lazily on first access
    of :attr:`client`. Every request method calls ``raise_for_status()`` on
    the response, so HTTP error statuses surface as ``httpx.HTTPStatusError``.

    The client may be used as an async context manager::

        async with AsyncAPIClient() as api:
            await api.health_check()
    """

    def __init__(self, config: Optional[APIConfig] = None):
        """Store the configuration; no network resources are opened yet.

        Args:
            config: Connection settings. A default ``APIConfig`` is used
                when omitted.
        """
        self.config = config or APIConfig()
        self._client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self) -> "AsyncAPIClient":
        """Return ``self`` so the client can be used with ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the underlying client when leaving an ``async with`` block."""
        await self.close()

    @property
    def client(self) -> httpx.AsyncClient:
        """Return the shared ``httpx.AsyncClient``, creating it on first use.

        Must be a property: the request methods access it as an attribute
        (``await self.client.get(...)``).
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                base_url=self.config.base_url,
                timeout=self.config.timeout,
            )
        return self._client

    async def close(self):
        """Close the HTTP client if one was created and forget it."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def health_check(self) -> Dict[str, Any]:
        """Return the decoded JSON of ``GET /health``."""
        response = await self.client.get("/health")
        response.raise_for_status()
        return response.json()

    async def transcribe_file(
        self,
        file_content: bytes,
        filename: str,
        language: str = "en-US",
        **options
    ) -> Dict[str, Any]:
        """Upload an audio file and wait for its transcription.

        Args:
            file_content: Audio file bytes.
            filename: Original filename, sent with the multipart upload.
            language: Language code.
            **options: Optional settings:
                ``enable_punctuation`` (default ``True``),
                ``enable_timestamps`` (default ``True``, sent as
                ``enable_word_timestamps``),
                ``enable_diarization`` (sent only when supplied),
                ``speaker_count`` (sent only when truthy).

        Returns:
            Transcription result dict.
        """
        files = {"file": (filename, file_content)}
        # Form fields are strings, so booleans are sent as "true"/"false".
        data = {
            "language": language,
            "enable_punctuation": str(options.get("enable_punctuation", True)).lower(),
            "enable_word_timestamps": str(options.get("enable_timestamps", True)).lower(),
        }
        # Forward the diarization settings instead of silently dropping them;
        # they are only added when the caller actually passed them.
        if "enable_diarization" in options:
            data["enable_diarization"] = str(options["enable_diarization"]).lower()
        if options.get("speaker_count"):
            data["speaker_count"] = str(options["speaker_count"])
        response = await self.client.post(
            "/api/v1/stt/upload",
            files=files,
            data=data,
            timeout=120.0,  # Longer timeout for transcription
        )
        response.raise_for_status()
        return response.json()

    async def synthesize_speech(
        self,
        text: str,
        language: str = "en-US",
        **options
    ) -> Dict[str, Any]:
        """Synthesize text to speech.

        Args:
            text: Text to synthesize.
            language: Language code.
            **options: Optional settings: ``voice`` (omitted from the
                payload when falsy), ``speaking_rate`` (default ``1.0``),
                ``pitch`` (default ``0.0``), ``audio_encoding``
                (default ``"MP3"``).

        Returns:
            The decoded JSON response.
        """
        payload: Dict[str, Any] = {
            "text": text,
            "language": language,
            "speaking_rate": options.get("speaking_rate", 1.0),
            "pitch": options.get("pitch", 0.0),
            "audio_encoding": options.get("audio_encoding", "MP3"),
        }
        if options.get("voice"):
            payload["voice"] = options["voice"]
        response = await self.client.post(
            "/api/v1/tts/synthesize",
            json=payload,
            timeout=60.0,
        )
        response.raise_for_status()
        return response.json()
# Convenience function for Streamlit
def get_api_client(base_url: str = "http://localhost:8001") -> APIClient:
    """Build a synchronous API client that targets ``base_url``.

    Args:
        base_url: Root URL of the API; every other setting keeps its
            ``APIConfig`` default.

    Returns:
        A new ``APIClient`` wrapping that configuration.
    """
    return APIClient(APIConfig(base_url=base_url))