File size: 7,211 Bytes
c25fb8f
 
 
 
 
 
 
4806882
c25fb8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4806882
c25fb8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4806882
c25fb8f
4806882
c25fb8f
8b08d3c
4806882
c25fb8f
4806882
8b08d3c
4806882
8b08d3c
 
 
 
c25fb8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b08d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4806882
8b08d3c
 
 
 
 
 
 
 
 
 
4806882
8b08d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4806882
8b08d3c
 
 
c25fb8f
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""
Vertex AI client for TTS synthesis using Google Cloud AI Platform.
"""
import os
import json
import logging
import requests
from typing import Optional, Dict, Any, Tuple
from google.cloud import aiplatform
from google.oauth2 import service_account
from dotenv import load_dotenv

# Load environment variables from .env file (for local development).
# This runs at import time, so the values are in place before
# VertexAIClient._load_credentials() reads "auth_string" from os.environ.
load_dotenv()

# Configure logging: basicConfig sets up the root logger at INFO level
# (it has no effect if the root logger already has handlers), and `logger`
# is the module-scoped logger used by everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class VertexAIClient:
    """Client for interacting with a Vertex AI TTS endpoint.

    Constructing the client performs no I/O. Credentials are loaded and the
    endpoint is looked up on the first call to ``initialize()``, which
    ``synthesize()`` triggers automatically when needed.

    Public methods report failure through their return value (``False`` /
    ``None``) and log the cause; they do not raise.
    """

    # Defaults keep the behavior of the original hard-coded deployment.
    DEFAULT_PROJECT = "desivocalprod01"
    DEFAULT_LOCATION = "asia-south1"
    DEFAULT_ENDPOINT_NAME = "zipvoice_base_distill"
    DEFAULT_MODEL_TYPE = "distill"

    def __init__(
        self,
        project: str = DEFAULT_PROJECT,
        location: str = DEFAULT_LOCATION,
        endpoint_display_name: str = DEFAULT_ENDPOINT_NAME,
        model_type: str = DEFAULT_MODEL_TYPE,
    ):
        """
        Initialize the Vertex AI client (no network access happens here).

        Args:
            project: Google Cloud project passed to ``aiplatform.init``
            location: Region passed to ``aiplatform.init``
            endpoint_display_name: Display name of the endpoint to look up
            model_type: Value sent as ``model_type`` in every synthesis request
        """
        self.project = project
        self.location = location
        self.endpoint_display_name = endpoint_display_name
        self.model_type = model_type
        self.endpoint = None
        self.credentials = None
        self.initialized = False

    def _load_credentials(self) -> "Optional[service_account.Credentials]":
        """
        Load credentials from auth_string environment variable.

        The variable is expected to hold service-account info as a JSON string.

        Returns:
            Credentials object or None if failed
        """
        auth_string = os.environ.get("auth_string")
        if not auth_string:
            logger.warning("auth_string environment variable not found")
            return None

        # Only the exception message is logged here (no traceback) because
        # this code path handles credential material.
        try:
            credentials_dict = json.loads(auth_string)
            credentials = service_account.Credentials.from_service_account_info(
                credentials_dict
            )
        except json.JSONDecodeError as e:
            logger.error("Failed to parse auth_string JSON: %s", e)
            return None
        except Exception as e:
            logger.error("Failed to load credentials: %s", e)
            return None

        logger.info("Successfully loaded credentials from auth_string")
        return credentials

    def initialize(self) -> bool:
        """
        Initialize Vertex AI and find the configured endpoint.

        Safe to call repeatedly: once initialization has succeeded, later
        calls return immediately. A failed attempt leaves the client
        uninitialized so the next call retries.

        Returns:
            True if initialization successful, False otherwise
        """
        if self.initialized:
            return True

        try:
            self.credentials = self._load_credentials()
            if not self.credentials:
                logger.error("Cannot initialize without credentials")
                return False

            aiplatform.init(
                project=self.project,
                location=self.location,
                credentials=self.credentials,
            )
            logger.info("Vertex AI initialized for project %s", self.project)

            # First endpoint whose display name matches wins.
            self.endpoint = next(
                (
                    candidate
                    for candidate in aiplatform.Endpoint.list()
                    if candidate.display_name == self.endpoint_display_name
                ),
                None,
            )
            if self.endpoint is None:
                logger.error(
                    "%s endpoint not found in Vertex AI", self.endpoint_display_name
                )
                return False

            logger.info(
                "Found %s endpoint: %s",
                self.endpoint_display_name,
                self.endpoint.resource_name,
            )
            self.initialized = True
            return True

        except Exception as e:
            logger.exception("Failed to initialize Vertex AI: %s", e)
            return False

    def get_voices(self) -> Tuple[bool, Optional[Dict[str, Any]]]:
        """
        Get available voices from local configuration file.

        Note: Vertex AI endpoint doesn't have a separate /voices API.
        Voices are configured in voices_config.json, located next to this
        module. A missing file is not an error: an empty voice map is
        returned instead.

        Returns:
            Tuple of (success, voices_dict)
            voices_dict format: {"voices": {"voice_id": {"name": "...", "gender": "..."}}}
            On failure (e.g. unreadable or malformed file): (False, None)
        """
        try:
            import pathlib

            config_path = pathlib.Path(__file__).parent / "voices_config.json"

            if not config_path.exists():
                logger.warning("voices_config.json not found at %s", config_path)
                return True, {"voices": {}}

            logger.info("Loading voices from %s", config_path)
            # Explicit UTF-8: the platform default encoding may not be able
            # to decode non-ASCII voice names.
            with open(config_path, "r", encoding="utf-8") as f:
                voices_data = json.load(f)
            logger.info(
                "Successfully loaded %s voices from config",
                len(voices_data.get("voices", {})),
            )
            return True, voices_data

        except Exception as e:
            logger.exception("Failed to load voices config: %s", e)
            return False, None

    def _download_audio(self, audio_url: str, timeout: int) -> Optional[bytes]:
        """Fetch the synthesized audio; return its bytes, or None on a non-200 reply."""
        logger.info("Downloading audio from: %s", audio_url)
        audio_response = requests.get(audio_url, timeout=timeout)

        if audio_response.status_code != 200:
            logger.error(
                "Failed to download audio: HTTP %s", audio_response.status_code
            )
            return None

        audio_data = audio_response.content
        logger.info("Successfully downloaded audio (%s bytes)", len(audio_data))
        return audio_data

    def synthesize(self, text: str, voice_id: str, timeout: int = 60) -> Tuple[bool, Optional[bytes], Optional[Dict[str, Any]]]:
        """
        Synthesize speech from text using the Vertex AI endpoint.

        The endpoint is asked to synthesize the text and is expected to reply
        with JSON containing ``success``, ``audio_url`` and optionally
        ``metrics`` / ``message``; the audio is then downloaded from
        ``audio_url``.

        Args:
            text: Text to synthesize
            voice_id: Voice ID to use
            timeout: Timeout in seconds for the audio download. It is not
                applied to the prediction request itself.

        Returns:
            Tuple of (success, audio_bytes, metrics); (False, None, None) on
            any failure
        """
        if not self.initialized and not self.initialize():
            return False, None, None

        try:
            logger.info(
                "Synthesizing text (length: %s) with voice %s using %s model",
                len(text),
                voice_id,
                self.model_type,
            )
            response = self.endpoint.raw_predict(
                body=json.dumps({
                    "text": text,
                    "voice_id": voice_id,
                    "model_type": self.model_type,
                }),
                headers={"Content-Type": "application/json"},
            )

            # Responses exposing `.text` carry a JSON string; anything else
            # is used as the already-decoded result.
            result = json.loads(response.text) if hasattr(response, "text") else response
            logger.info("Vertex AI response: %s", result)

            if not result.get("success"):
                logger.error(
                    "Synthesis failed: %s", result.get("message", "Unknown error")
                )
                return False, None, None

            audio_url = result.get("audio_url")
            if not audio_url:
                logger.error("No audio_url in successful response")
                return False, None, None

            audio_data = self._download_audio(audio_url, timeout)
            if audio_data is None:
                return False, None, None

            return True, audio_data, result.get("metrics")

        except Exception as e:
            logger.exception("Failed to synthesize speech with Vertex AI: %s", e)
            return False, None, None



# Module-wide client, built lazily by get_vertex_client()
_vertex_client = None


def get_vertex_client() -> VertexAIClient:
    """Return the shared Vertex AI client, constructing it on first use."""
    global _vertex_client
    if _vertex_client is not None:
        return _vertex_client
    _vertex_client = VertexAIClient()
    return _vertex_client