vinzcyun committed on
Commit
747a950
·
verified ·
1 Parent(s): 8a7188c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +117 -36
main.py CHANGED
@@ -1,11 +1,13 @@
1
- from fastapi import FastAPI, HTTPException
2
  from fastapi.responses import StreamingResponse
3
- from pydantic import BaseModel, Field
4
- from typing import Literal, Optional
 
5
  from google import genai
6
  from google.genai import types
7
  import io, wave, base64
8
  import logging
 
9
 
10
  # Set up logging to see more details about errors
11
  logging.basicConfig(level=logging.INFO)
@@ -14,78 +16,160 @@ logger = logging.getLogger(__name__)
14
  app = FastAPI(title="OpenAI-compatible TTS (Gemini via google-genai)")
15
 
16
  class OpenAITTSRequest(BaseModel):
17
- # OpenAI-style (để tương thích client)
18
- model: str = Field(..., description="OpenAI-style model (chỉ để tương thích)")
19
- voice: str = Field(..., description="Tên giọng TTS (Gemini prebuilt voice)")
20
  input: str = Field(..., description="Văn bản cần đọc")
21
  response_format: Optional[Literal["wav", "pcm"]] = Field(default="wav", description="Định dạng output")
22
- # Alternative field name for compatibility
23
  format: Optional[Literal["wav", "pcm"]] = Field(default=None, description="Định dạng output (alternative)")
24
- # Thông tin Gemini do user cung cấp (bắt buộc)
25
- gemini_api_key: str = Field(..., description="Google API key cho Gemini")
26
- gemini_model: str = Field(..., description="Tên model Gemini TTS (vd: gemini-2.5-flash-preview-tts)")
27
-
28
- # Optional OpenAI compatibility fields
29
  speed: Optional[float] = Field(default=1.0, ge=0.25, le=4.0, description="Tốc độ giọng nói")
 
 
 
 
 
 
 
 
30
 
31
SR = 24000  # Gemini TTS returns PCM s16le, 24 kHz, mono


def pcm_to_wav_bytes(
    pcm: bytes,
    sr: int = SR,
    channels: int = 1,
    sample_width: int = 2,
) -> bytes:
    """Wrap raw PCM samples in a WAV container and return the file bytes.

    Args:
        pcm: Raw, headerless PCM sample data.
        sr: Sample rate in Hz written to the WAV header.
        channels: Number of audio channels (defaults to mono).
        sample_width: Bytes per sample (defaults to 2, i.e. 16-bit).

    Returns:
        The complete WAV file (header plus samples) as bytes.

    Raises:
        wave.Error: If the ``wave`` module rejects a header parameter.
    """
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(sr)
        wf.writeframes(pcm)
    # Read the buffer only after the writer is closed, so the header sizes
    # have been finalised.
    return buf.getvalue()
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  @app.post("/v1/audio/speech")
43
- async def audio_speech(body: OpenAITTSRequest):
44
- # Log the incoming request for debugging
45
- logger.info(f"Received TTS request: model={body.model}, voice={body.voice}")
 
 
 
 
46
 
47
- # Determine output format - check both fields for compatibility
48
- output_format = body.format or body.response_format or "wav"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  # Validate input text
51
- if not body.input or not body.input.strip():
52
  raise HTTPException(status_code=400, detail="Input text cannot be empty")
53
 
54
- # Khởi tạo client với API key do user cung cấp (không lưu trữ)
55
  try:
56
- client = genai.Client(api_key=body.gemini_api_key)
57
  except Exception as e:
58
  logger.error(f"Failed to initialize GenAI client: {e}")
59
  raise HTTPException(status_code=400, detail=f"Không khởi tạo được Google GenAI client: {e!s}")
60
 
61
- # Cấu hình TTS theo SDK chính thức
62
  config = types.GenerateContentConfig(
63
  response_modalities=["AUDIO"],
64
  speech_config=types.SpeechConfig(
65
  voice_config=types.VoiceConfig(
66
  prebuilt_voice_config=types.PrebuiltVoiceConfig(
67
- voice_name=body.voice
68
  )
69
  )
70
  )
71
  )
72
 
73
  try:
74
- logger.info(f"Calling Gemini TTS with model: {body.gemini_model}")
75
  resp = client.models.generate_content(
76
- model=body.gemini_model,
77
- contents=body.input,
78
  config=config
79
  )
80
  except Exception as e:
81
  logger.error(f"Gemini TTS API error: {e}")
82
- # Forward lỗi từ SDK/Upstream
83
  raise HTTPException(status_code=502, detail=f"Lỗi gọi Gemini TTS: {e!s}")
84
 
85
  # Lấy dữ liệu audio
86
  try:
87
  inline = resp.candidates[0].content.parts[0].inline_data
88
- data = inline.data # có thể là bytes hoặc base64 str (tuỳ version SDK)
89
  except (IndexError, AttributeError) as e:
90
  logger.error(f"Failed to extract audio data: {e}")
91
  raise HTTPException(status_code=500, detail="Không tìm thấy audio trong phản hồi Gemini")
@@ -93,7 +177,6 @@ async def audio_speech(body: OpenAITTSRequest):
93
  if isinstance(data, (bytes, bytearray)):
94
  pcm = bytes(data)
95
  else:
96
- # fallback: nếu SDK trả base64 string
97
  try:
98
  pcm = base64.b64decode(data)
99
  except Exception as e:
@@ -114,16 +197,14 @@ async def audio_speech(body: OpenAITTSRequest):
114
  headers={"Content-Disposition": 'inline; filename="speech.wav"'}
115
  )
116
 
117
@app.exception_handler(422)
async def validation_exception_handler(request, exc):
    """Log an error handled under status code 422 and answer with JSON.

    An exception handler has to return a response object. The previous code
    returned an ``HTTPException`` instance, which is an exception and not a
    response, so the 422 payload was never delivered to the client.

    NOTE(review): this handler is registered by status code. FastAPI's own
    request-validation failures are presumably dispatched by exception class
    (``RequestValidationError``) and may never reach it - confirm.

    Args:
        request: The incoming request (unused).
        exc: The exception being handled.

    Returns:
        A 422 JSON response whose ``detail`` carries the stringified error.
    """
    # Function-scope import keeps this block self-contained; the file's
    # import header only brings in StreamingResponse.
    from fastapi.responses import JSONResponse

    logger.error(f"Validation error: {exc}")
    return JSONResponse(
        status_code=422,
        content={"detail": f"Validation error: {exc}"},
    )
121
-
122
  @app.get("/")
123
  def root():
124
  return {
125
  "ok": True,
126
- "usage": "POST /v1/audio/speech với {model, voice, input, response_format(wav|pcm), gemini_api_key, gemini_model}",
 
 
 
127
  "example": {
128
  "model": "tts-1",
129
  "voice": "en-US-Journey-F",
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
  from fastapi.responses import StreamingResponse
3
+ from fastapi.exceptions import RequestValidationError
4
+ from pydantic import BaseModel, Field, ValidationError
5
+ from typing import Literal, Optional, Any
6
  from google import genai
7
  from google.genai import types
8
  import io, wave, base64
9
  import logging
10
+ import json
11
 
12
  # Set up logging to see more details about errors
13
  logging.basicConfig(level=logging.INFO)
 
16
  app = FastAPI(title="OpenAI-compatible TTS (Gemini via google-genai)")
17
 
18
  class OpenAITTSRequest(BaseModel):
19
+ # OpenAI-style (để tương thích client) - tất cả đều optional với default
20
+ model: Optional[str] = Field(default="tts-1", description="OpenAI-style model")
21
+ voice: Optional[str] = Field(default="en-US-Journey-F", description="Tên giọng TTS")
22
  input: str = Field(..., description="Văn bản cần đọc")
23
  response_format: Optional[Literal["wav", "pcm"]] = Field(default="wav", description="Định dạng output")
 
24
  format: Optional[Literal["wav", "pcm"]] = Field(default=None, description="Định dạng output (alternative)")
 
 
 
 
 
25
  speed: Optional[float] = Field(default=1.0, ge=0.25, le=4.0, description="Tốc độ giọng nói")
26
+
27
+ # Thông tin Gemini
28
+ gemini_api_key: str = Field(..., description="Google API key cho Gemini")
29
+ gemini_model: Optional[str] = Field(default="gemini-2.0-flash-exp", description="Tên model Gemini TTS")
30
+
31
+ class Config:
32
+ # Allow extra fields để tránh lỗi khi client gửi thêm field không mong đợi
33
+ extra = "allow"
34
 
35
SR = 24000  # default sample rate in Hz (Gemini TTS output: PCM s16le, 24 kHz, mono)


def pcm_to_wav_bytes(
    pcm: bytes,
    sr: int = SR,
    channels: int = 1,
    sample_width: int = 2,
) -> bytes:
    """Wrap raw PCM samples in a WAV container and return the file bytes.

    Args:
        pcm: Raw, headerless PCM sample data.
        sr: Sample rate in Hz written to the WAV header.
        channels: Number of audio channels (defaults to mono).
        sample_width: Bytes per sample (defaults to 2, i.e. 16-bit).

    Returns:
        The complete WAV file (header plus samples) as bytes.

    Raises:
        wave.Error: If the ``wave`` module rejects a header parameter.
    """
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(sr)
        wf.writeframes(pcm)
    # Read the buffer only after the writer is closed, so the header sizes
    # have been finalised.
    return buf.getvalue()
45
 
46
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log a request that failed validation and answer with a 422 JSON body.

    Fixes over the previous version:
    * An exception handler must return a response; returning an
      ``HTTPException`` instance does not produce the intended 422 payload.
    * The body is decoded once with ``errors="replace"``, so a non-UTF-8
      payload can no longer raise while the error response is being built.
    * The bare ``except:`` around JSON parsing is narrowed to ``ValueError``.
    * The payload goes through ``jsonable_encoder`` in case the validation
      error details are not directly JSON-serializable.

    NOTE(review): the raw body is logged and echoed back verbatim. The request
    schema carries ``gemini_api_key``, so secrets end up in the logs - consider
    redacting before enabling this in production.

    Args:
        request: The request whose body failed validation.
        exc: The validation error raised for that request.

    Returns:
        A 422 JSON response with the validation details and the received body
        under ``detail``.
    """
    # Function-scope imports keep this block self-contained; the file's
    # import header only brings in StreamingResponse.
    from fastapi.encoders import jsonable_encoder
    from fastapi.responses import JSONResponse

    body_text = None
    try:
        raw_body = await request.body()
        body_text = raw_body.decode("utf-8", errors="replace")
    except Exception as e:
        # Best effort only: failing to read the body must not mask the
        # validation error itself.
        logger.error(f"Could not read request body: {e}")
    else:
        logger.error(f"Raw request body: {body_text}")
        # Try to parse as JSON to see what was actually received.
        try:
            json_body = json.loads(body_text)
        except ValueError:
            logger.error("Body is not valid JSON")
        else:
            logger.error(f"Parsed JSON: {json_body}")

    errors = exc.errors()
    logger.error(f"Validation error details: {errors}")

    payload = {
        "detail": {
            "error": "Validation failed",
            "details": errors,
            # An empty or unreadable body is reported as None.
            "received_body": body_text or None,
        }
    }
    return JSONResponse(status_code=422, content=jsonable_encoder(payload))
74
+
75
+ # Alternative endpoint that accepts any JSON and logs it
76
@app.post("/v1/audio/speech/debug")
async def audio_speech_debug(request: Request):
    """Accept any request body, log it, and report whether it would validate.

    The body is decoded, parsed as JSON and fed to ``OpenAITTSRequest``; the
    first step that fails is reported as a structured error dict instead of
    an exception. No speech is generated.

    Fix over the previous version: the body used to be decoded outside any
    ``try`` block, so a non-UTF-8 payload raised ``UnicodeDecodeError`` out of
    the endpoint. It is now decoded once and reported as ``invalid_encoding``.

    NOTE(review): the raw body is logged verbatim and may contain
    ``gemini_api_key`` - consider redacting.

    Args:
        request: The incoming request; only its headers and body are read.

    Returns:
        A dict with ``status``/``message`` on success, or ``error``/``details``
        describing the first failure.
    """
    raw_body = await request.body()
    content_type = request.headers.get("content-type", "")
    logger.info(f"Debug endpoint - Content-Type: {content_type}")

    try:
        body_text = raw_body.decode("utf-8")
    except UnicodeDecodeError as e:
        logger.error(f"Debug endpoint - Body is not valid UTF-8: {e}")
        return {"error": "invalid_encoding", "details": str(e)}
    logger.info(f"Debug endpoint - Raw body: {body_text}")

    try:
        json_data = json.loads(body_text)
    except json.JSONDecodeError as e:
        logger.error(f"Debug endpoint - JSON decode error: {e}")
        return {"error": "invalid_json", "details": str(e)}
    logger.info(f"Debug endpoint - Parsed JSON: {json_data}")

    # Try to create the model manually, the same way the real endpoint does.
    try:
        request_model = OpenAITTSRequest(**json_data)
    except ValidationError as ve:
        logger.error(f"Debug endpoint - Validation error: {ve.errors()}")
        return {"error": "validation_failed", "details": ve.errors()}
    except Exception as e:
        # Also covers a JSON document that is not an object, where the
        # ``**`` unpacking itself fails.
        logger.error(f"Debug endpoint - Other error: {e}")
        return {"error": "unknown_error", "details": str(e)}

    logger.info(f"Debug endpoint - Model created successfully: {request_model}")
    return {"status": "success", "message": "Request would be processed normally"}
104
+
105
  @app.post("/v1/audio/speech")
106
+ async def audio_speech(request: Request):
107
+ # Log incoming request
108
+ logger.info(f"Headers: {dict(request.headers)}")
109
+
110
+ # Read raw body first
111
+ body = await request.body()
112
+ logger.info(f"Raw body: {body.decode('utf-8')}")
113
 
114
+ try:
115
+ # Parse JSON manually first
116
+ json_data = json.loads(body.decode('utf-8'))
117
+ logger.info(f"Parsed JSON keys: {list(json_data.keys())}")
118
+
119
+ # Create Pydantic model
120
+ body_model = OpenAITTSRequest(**json_data)
121
+
122
+ except json.JSONDecodeError as e:
123
+ logger.error(f"JSON decode error: {e}")
124
+ raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
125
+ except ValidationError as e:
126
+ logger.error(f"Pydantic validation error: {e.errors()}")
127
+ raise HTTPException(status_code=422, detail={"validation_errors": e.errors()})
128
+ except Exception as e:
129
+ logger.error(f"Unexpected error during parsing: {e}")
130
+ raise HTTPException(status_code=400, detail=f"Request parsing error: {e}")
131
+
132
+ # Determine output format
133
+ output_format = body_model.format or body_model.response_format or "wav"
134
 
135
  # Validate input text
136
+ if not body_model.input or not body_model.input.strip():
137
  raise HTTPException(status_code=400, detail="Input text cannot be empty")
138
 
139
+ # Khởi tạo client với API key
140
  try:
141
+ client = genai.Client(api_key=body_model.gemini_api_key)
142
  except Exception as e:
143
  logger.error(f"Failed to initialize GenAI client: {e}")
144
  raise HTTPException(status_code=400, detail=f"Không khởi tạo được Google GenAI client: {e!s}")
145
 
146
+ # Cấu hình TTS
147
  config = types.GenerateContentConfig(
148
  response_modalities=["AUDIO"],
149
  speech_config=types.SpeechConfig(
150
  voice_config=types.VoiceConfig(
151
  prebuilt_voice_config=types.PrebuiltVoiceConfig(
152
+ voice_name=body_model.voice
153
  )
154
  )
155
  )
156
  )
157
 
158
  try:
159
+ logger.info(f"Calling Gemini TTS with model: {body_model.gemini_model}")
160
  resp = client.models.generate_content(
161
+ model=body_model.gemini_model,
162
+ contents=body_model.input,
163
  config=config
164
  )
165
  except Exception as e:
166
  logger.error(f"Gemini TTS API error: {e}")
 
167
  raise HTTPException(status_code=502, detail=f"Lỗi gọi Gemini TTS: {e!s}")
168
 
169
  # Lấy dữ liệu audio
170
  try:
171
  inline = resp.candidates[0].content.parts[0].inline_data
172
+ data = inline.data
173
  except (IndexError, AttributeError) as e:
174
  logger.error(f"Failed to extract audio data: {e}")
175
  raise HTTPException(status_code=500, detail="Không tìm thấy audio trong phản hồi Gemini")
 
177
  if isinstance(data, (bytes, bytearray)):
178
  pcm = bytes(data)
179
  else:
 
180
  try:
181
  pcm = base64.b64decode(data)
182
  except Exception as e:
 
197
  headers={"Content-Disposition": 'inline; filename="speech.wav"'}
198
  )
199
 
 
 
 
 
 
200
  @app.get("/")
201
  def root():
202
  return {
203
  "ok": True,
204
+ "usage": "POST /v1/audio/speech",
205
+ "debug_endpoint": "/v1/audio/speech/debug (để test request format)",
206
+ "required_fields": ["input", "gemini_api_key"],
207
+ "optional_fields": ["model", "voice", "response_format", "gemini_model", "speed"],
208
  "example": {
209
  "model": "tts-1",
210
  "voice": "en-US-Journey-F",