hivecorp committed on
Commit
5b0bfda
·
verified ·
1 Parent(s): 51a095c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -140
app.py CHANGED
@@ -2,19 +2,18 @@ from fastapi import FastAPI, HTTPException, BackgroundTasks
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import FileResponse, JSONResponse
4
  from pydantic import BaseModel
5
- import gradio as gr
6
- from typing import Optional, Dict, Any
7
  import time
8
  import uvicorn
9
  from datetime import datetime
10
  import psutil
11
  import asyncio
12
- from app import (
13
- generate_accurate_srt,
14
- voice_options,
15
- TTSError,
16
- FileManager
17
- )
18
 
19
  # Initialize FastAPI app
20
  app = FastAPI(
@@ -32,6 +31,27 @@ app.add_middleware(
32
  allow_headers=["*"],
33
  )
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # Global state management
36
  class ProcessingState:
37
  def __init__(self):
@@ -43,12 +63,9 @@ state = ProcessingState()
43
  # Pydantic models
44
  class TTSRequest(BaseModel):
45
  text: str
46
- voice: str = "Jenny Female"
47
  pitch: int = 0
48
  rate: int = 0
49
- words_per_line: int = 6
50
- lines_per_segment: int = 2
51
- parallel_processing: bool = True
52
 
53
  class HealthResponse(BaseModel):
54
  status: str
@@ -57,20 +74,38 @@ class HealthResponse(BaseModel):
57
  memory_usage: float
58
  active_jobs: int
59
 
60
- # Progress callback handler
61
- async def update_job_progress(job_id: str, progress: float, status: str):
62
- state.active_jobs[job_id].update({
63
- "progress": progress,
64
- "status": status,
65
- "last_update": datetime.now().isoformat()
66
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  # API endpoints
69
  @app.post("/api/v1/tts")
70
  async def create_tts(request: TTSRequest, background_tasks: BackgroundTasks):
71
  job_id = f"job_{int(time.time())}_{hash(request.text)}"
72
 
73
- # Initialize job status
74
  state.active_jobs[job_id] = {
75
  "id": job_id,
76
  "status": "queued",
@@ -82,26 +117,20 @@ async def create_tts(request: TTSRequest, background_tasks: BackgroundTasks):
82
 
83
  async def process_tts():
84
  try:
85
- # Format pitch and rate strings
86
  pitch_str = f"{request.pitch:+d}Hz"
87
  rate_str = f"{request.rate:+d}%"
88
 
89
- srt_path, audio_path = await generate_accurate_srt(
90
  request.text,
91
- voice_options[request.voice],
92
  rate_str,
93
- pitch_str,
94
- request.words_per_line,
95
- request.lines_per_segment,
96
- progress_callback=lambda p, s: update_job_progress(job_id, p, s),
97
- parallel=request.parallel_processing
98
  )
99
 
100
  state.active_jobs[job_id].update({
101
  "status": "completed",
102
  "progress": 1.0,
103
  "result": {
104
- "srt_path": srt_path,
105
  "audio_path": audio_path
106
  }
107
  })
@@ -112,7 +141,6 @@ async def create_tts(request: TTSRequest, background_tasks: BackgroundTasks):
112
  })
113
 
114
  background_tasks.add_task(process_tts)
115
-
116
  return {"job_id": job_id, "status": "queued"}
117
 
118
  @app.get("/api/v1/status/{job_id}")
@@ -121,8 +149,8 @@ async def get_job_status(job_id: str):
121
  raise HTTPException(status_code=404, detail="Job not found")
122
  return state.active_jobs[job_id]
123
 
124
- @app.get("/api/v1/download/{job_id}/{file_type}")
125
- async def download_file(job_id: str, file_type: str):
126
  if job_id not in state.active_jobs:
127
  raise HTTPException(status_code=404, detail="Job not found")
128
 
@@ -130,13 +158,10 @@ async def download_file(job_id: str, file_type: str):
130
  if job["status"] != "completed":
131
  raise HTTPException(status_code=400, detail="Job not completed")
132
 
133
- if file_type not in ["audio", "srt"]:
134
- raise HTTPException(status_code=400, detail="Invalid file type")
135
-
136
- file_path = job["result"][f"{file_type}_path"]
137
  return FileResponse(
138
  file_path,
139
- filename=f"tts_output.{file_type}"
140
  )
141
 
142
  @app.get("/api/v1/health", response_model=HealthResponse)
@@ -149,109 +174,9 @@ async def health_check():
149
  "active_jobs": len(state.active_jobs)
150
  }
151
 
152
- @app.delete("/api/v1/jobs/{job_id}")
153
- async def cancel_job(job_id: str):
154
- if job_id not in state.active_jobs:
155
- raise HTTPException(status_code=404, detail="Job not found")
156
-
157
- job = state.active_jobs[job_id]
158
- if job["status"] in ["completed", "failed"]:
159
- del state.active_jobs[job_id]
160
- return {"status": "deleted"}
161
-
162
- job["status"] = "cancelled"
163
- return {"status": "cancelled"}
164
-
165
- # Initialize Gradio interface
166
- with gr.Blocks() as gradio_app:
167
- gr.Markdown("# Advanced TTS with Configurable SRT Generation")
168
- gr.Markdown("Generate perfectly synchronized audio and subtitles with natural speech patterns.")
169
-
170
- with gr.Row():
171
- with gr.Column(scale=3):
172
- text_input = gr.Textbox(label="Enter Text", lines=10, placeholder="Enter your text here...")
173
-
174
- with gr.Column(scale=2):
175
- voice_dropdown = gr.Dropdown(
176
- label="Select Voice",
177
- choices=list(voice_options.keys()),
178
- value="Jenny Female"
179
- )
180
- pitch_slider = gr.Slider(
181
- label="Pitch Adjustment (Hz)",
182
- minimum=-10,
183
- maximum=10,
184
- value=0,
185
- step=1
186
- )
187
- rate_slider = gr.Slider(
188
- label="Rate Adjustment (%)",
189
- minimum=-25,
190
- maximum=25,
191
- value=0,
192
- step=1
193
- )
194
-
195
- with gr.Row():
196
- with gr.Column():
197
- words_per_line = gr.Slider(
198
- label="Words per Line",
199
- minimum=3,
200
- maximum=12,
201
- value=6,
202
- step=1,
203
- info="Controls how many words appear on each line of the subtitle"
204
- )
205
- with gr.Column():
206
- lines_per_segment = gr.Slider(
207
- label="Lines per Segment",
208
- minimum=1,
209
- maximum=4,
210
- value=2,
211
- step=1,
212
- info="Controls how many lines appear in each subtitle segment"
213
- )
214
- with gr.Column():
215
- parallel_processing = gr.Checkbox(
216
- label="Enable Parallel Processing",
217
- value=True,
218
- info="Process multiple segments simultaneously for faster conversion"
219
- )
220
-
221
- submit_btn = gr.Button("Generate Audio & Subtitles")
222
- error_output = gr.Textbox(label="Status", visible=False)
223
-
224
- with gr.Row():
225
- with gr.Column():
226
- audio_output = gr.Audio(label="Preview Audio")
227
- with gr.Column():
228
- srt_file = gr.File(label="Download SRT")
229
- audio_file = gr.File(label="Download Audio")
230
-
231
- submit_btn.click(
232
- fn=process_text_with_progress,
233
- inputs=[
234
- text_input,
235
- pitch_slider,
236
- rate_slider,
237
- voice_dropdown,
238
- words_per_line,
239
- lines_per_segment,
240
- parallel_processing
241
- ],
242
- outputs=[
243
- srt_file,
244
- audio_file,
245
- audio_output,
246
- error_output,
247
- error_output
248
- ],
249
- api_name="generate"
250
- )
251
-
252
- # Mount Gradio app to FastAPI
253
- app = gr.mount_gradio_app(app, gradio_app, path="/")
254
 
255
- # Start the FastAPI server
256
  if __name__ == "__main__":
257
  uvicorn.run("fastapi_app:app", host="0.0.0.0", port=8000, reload=True)
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import FileResponse, JSONResponse
4
  from pydantic import BaseModel
5
+ from typing import Optional, Dict, Any, List, Tuple
 
6
  import time
7
  import uvicorn
8
  from datetime import datetime
9
  import psutil
10
  import asyncio
11
+ import edge_tts
12
+ from pydub import AudioSegment
13
+ import os
14
+ import uuid
15
+ import tempfile
16
+ from concurrent.futures import ThreadPoolExecutor
17
 
18
  # Initialize FastAPI app
19
  app = FastAPI(
 
31
  allow_headers=["*"],
32
  )
33
 
34
+ # Core functionality (moved from app.py)
35
+ class TTSError(Exception):
36
+ pass
37
+
38
class FileManager:
    """Track generated output files inside a dedicated temporary directory.

    A fresh directory (prefix ``tts_api_``) is created per instance, and
    ``output_files`` records the generated paths in creation order so the
    oldest ones can be pruned.
    """

    def __init__(self):
        self.temp_dir = tempfile.mkdtemp(prefix="tts_api_")
        # Paths of generated files, oldest first.
        self.output_files = []

    def get_temp_path(self, prefix: str) -> str:
        """Return a unique path inside ``temp_dir``; the file is not created."""
        return os.path.join(self.temp_dir, f"{prefix}_{uuid.uuid4()}")

    def cleanup_old_files(self, keep: int = 5) -> None:
        """Delete all but the ``keep`` most recent tracked files.

        Args:
            keep: Number of newest entries of ``output_files`` to retain.
                Defaults to 5, the previously hard-coded value.

        Raises:
            ValueError: If ``keep`` is negative.
        """
        if keep < 0:
            raise ValueError("keep must be non-negative")

        # Compute the split index explicitly: a ``[:-keep]`` slice would
        # select nothing when keep == 0.
        cutoff = max(len(self.output_files) - keep, 0)
        stale = self.output_files[:cutoff]
        self.output_files = self.output_files[cutoff:]

        for path in stale:
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup: the file may already be gone or be
                # locked; either way it is no longer tracked.
                pass
54
+
55
  # Global state management
56
  class ProcessingState:
57
  def __init__(self):
 
63
  # Pydantic models
64
  class TTSRequest(BaseModel):
65
  text: str
66
+ voice: str = "en-US-JennyNeural"
67
  pitch: int = 0
68
  rate: int = 0
 
 
 
69
 
70
  class HealthResponse(BaseModel):
71
  status: str
 
74
  memory_usage: float
75
  active_jobs: int
76
 
77
+ # Voice options dictionary (simplified)
78
+ voice_options = {
79
+ "Jenny": "en-US-JennyNeural",
80
+ "Guy": "en-US-GuyNeural",
81
+ "Ana": "en-US-AnaNeural",
82
+ "Aria": "en-US-AriaNeural"
83
+ }
84
+
85
async def generate_tts(text: str, voice: str, rate: str, pitch: str) -> str:
    """Synthesize ``text`` to an MP3 file with edge-tts and return its path.

    Args:
        text: Text to speak.
        voice: Voice name passed straight to ``edge_tts.Communicate``
            (e.g. "en-US-JennyNeural").
        rate: Signed percentage string, e.g. "+0%".
        pitch: Signed Hz string, e.g. "+0Hz".

    Returns:
        Path of the generated ``.mp3`` file.

    Raises:
        TTSError: If synthesis raises or no audio file is produced. The
            original exception is attached as ``__cause__``.
    """
    try:
        # NOTE(review): assumes ``state.file_manager`` is a FileManager;
        # its creation is not visible in this view -- confirm.
        audio_path = state.file_manager.get_temp_path("audio") + ".mp3"

        tts = edge_tts.Communicate(text, voice, rate=rate, pitch=pitch)
        await tts.save(audio_path)

        if not os.path.exists(audio_path):
            raise TTSError("Failed to generate audio file")

        state.file_manager.output_files.append(audio_path)
        # Prunes older outputs, so the audio of earlier completed jobs may
        # no longer exist on disk when it is requested for download.
        state.file_manager.cleanup_old_files()

        return audio_path

    except Exception as e:
        # Chain the cause so the underlying traceback is not lost.
        raise TTSError(f"TTS generation failed: {str(e)}") from e
103
 
104
  # API endpoints
105
  @app.post("/api/v1/tts")
106
  async def create_tts(request: TTSRequest, background_tasks: BackgroundTasks):
107
  job_id = f"job_{int(time.time())}_{hash(request.text)}"
108
 
 
109
  state.active_jobs[job_id] = {
110
  "id": job_id,
111
  "status": "queued",
 
117
 
118
  async def process_tts():
119
  try:
 
120
  pitch_str = f"{request.pitch:+d}Hz"
121
  rate_str = f"{request.rate:+d}%"
122
 
123
+ audio_path = await generate_tts(
124
  request.text,
125
+ request.voice,
126
  rate_str,
127
+ pitch_str
 
 
 
 
128
  )
129
 
130
  state.active_jobs[job_id].update({
131
  "status": "completed",
132
  "progress": 1.0,
133
  "result": {
 
134
  "audio_path": audio_path
135
  }
136
  })
 
141
  })
142
 
143
  background_tasks.add_task(process_tts)
 
144
  return {"job_id": job_id, "status": "queued"}
145
 
146
  @app.get("/api/v1/status/{job_id}")
 
149
  raise HTTPException(status_code=404, detail="Job not found")
150
  return state.active_jobs[job_id]
151
 
152
+ @app.get("/api/v1/download/{job_id}")
153
+ async def download_file(job_id: str):
154
  if job_id not in state.active_jobs:
155
  raise HTTPException(status_code=404, detail="Job not found")
156
 
 
158
  if job["status"] != "completed":
159
  raise HTTPException(status_code=400, detail="Job not completed")
160
 
161
+ file_path = job["result"]["audio_path"]
 
 
 
162
  return FileResponse(
163
  file_path,
164
+ filename=f"tts_output.mp3"
165
  )
166
 
167
  @app.get("/api/v1/health", response_model=HealthResponse)
 
174
  "active_jobs": len(state.active_jobs)
175
  }
176
 
177
+ @app.get("/api/v1/voices")
178
+ async def list_voices():
179
+ return {"voices": voice_options}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
 
181
if __name__ == "__main__":
    # With reload=True uvicorn needs an import string; it must name this
    # module (app.py), otherwise a different module is served or the import
    # fails.
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)