Rivalcoder commited on
Commit
e2ce4a5
·
1 Parent(s): 5b04c03
Files changed (1) hide show
  1. app.py +22 -20
app.py CHANGED
@@ -4,6 +4,7 @@ import warnings
4
  import io
5
  import tempfile
6
  from pathlib import Path
 
7
 
8
  warnings.filterwarnings('ignore')
9
  os.environ['PYTHONWARNINGS'] = 'ignore'
@@ -68,13 +69,24 @@ app.add_middleware(
68
  allow_headers=["*"],
69
  )
70
 
71
- # Global variables for models
 
 
 
 
 
 
 
 
 
 
 
 
72
  pipeline = None
73
  whisper_model = None
74
 
75
  @app.on_event("startup")
76
  async def load_models():
77
- """Load models on startup"""
78
  global pipeline, whisper_model
79
 
80
  print(f"Using device: {device}")
@@ -90,10 +102,10 @@ async def load_models():
90
  print("Loading Whisper small model...")
91
  with SuppressStderr():
92
  whisper_model = whisper.load_model("small", device=device)
 
93
  print("Models loaded successfully!\n")
94
 
95
  def process_audio(audio_path):
96
- """Process audio file with diarization and transcription"""
97
  if not os.path.exists(audio_path):
98
  raise FileNotFoundError(f"Audio file not found: {audio_path}")
99
 
@@ -146,7 +158,6 @@ def process_audio(audio_path):
146
 
147
  @app.get("/")
148
  async def root():
149
- """Root endpoint with API information"""
150
  return {
151
  "message": "Speaker Diarization & Transcription API",
152
  "version": "1.0.0",
@@ -159,7 +170,6 @@ async def root():
159
 
160
  @app.get("/health")
161
  async def health_check():
162
- """Health check endpoint"""
163
  return {
164
  "status": "healthy",
165
  "device": str(device),
@@ -168,19 +178,9 @@ async def health_check():
168
 
169
  @app.post("/process")
170
  async def process_audio_endpoint(file: UploadFile = File(...)):
171
- """
172
- Process audio file for speaker diarization and transcription
173
-
174
- Args:
175
- file: Audio file (wav, mp3, etc.)
176
-
177
- Returns:
178
- JSON response with segments and full transcription
179
- """
180
  if pipeline is None or whisper_model is None:
181
  raise HTTPException(status_code=503, detail="Models are still loading. Please try again in a moment.")
182
 
183
- # Validate file type
184
  allowed_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.ogg', '.webm'}
185
  file_ext = Path(file.filename).suffix.lower()
186
 
@@ -190,15 +190,17 @@ async def process_audio_endpoint(file: UploadFile = File(...)):
190
  detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}"
191
  )
192
 
193
- # Save uploaded file temporarily
194
  with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
195
  try:
196
  content = await file.read()
197
  tmp_file.write(content)
198
  tmp_file_path = tmp_file.name
199
 
200
- # Process audio
201
- result = process_audio(tmp_file_path)
 
 
 
202
 
203
  return JSONResponse(content=result)
204
 
@@ -206,11 +208,11 @@ async def process_audio_endpoint(file: UploadFile = File(...)):
206
  raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
207
 
208
  finally:
209
- # Clean up temporary file
210
  if os.path.exists(tmp_file_path):
211
  os.unlink(tmp_file_path)
 
 
212
 
213
  if __name__ == "__main__":
214
  import uvicorn
215
  uvicorn.run(app, host="0.0.0.0", port=7860)
216
-
 
4
  import io
5
  import tempfile
6
  from pathlib import Path
7
+ import subprocess # <-- Added
8
 
9
  warnings.filterwarnings('ignore')
10
  os.environ['PYTHONWARNINGS'] = 'ignore'
 
69
  allow_headers=["*"],
70
  )
71
 
72
+ # Convert ANY audio file to WAV using FFmpeg
73
+ def convert_to_wav(input_path):
74
+ output_path = input_path + "_converted.wav"
75
+ command = [
76
+ "ffmpeg", "-y", "-i", input_path,
77
+ "-ac", "1",
78
+ "-ar", "16000",
79
+ output_path
80
+ ]
81
+ subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
82
+ return output_path
83
+
84
+ # Global variables
85
  pipeline = None
86
  whisper_model = None
87
 
88
  @app.on_event("startup")
89
  async def load_models():
 
90
  global pipeline, whisper_model
91
 
92
  print(f"Using device: {device}")
 
102
  print("Loading Whisper small model...")
103
  with SuppressStderr():
104
  whisper_model = whisper.load_model("small", device=device)
105
+
106
  print("Models loaded successfully!\n")
107
 
108
  def process_audio(audio_path):
 
109
  if not os.path.exists(audio_path):
110
  raise FileNotFoundError(f"Audio file not found: {audio_path}")
111
 
 
158
 
159
  @app.get("/")
160
  async def root():
 
161
  return {
162
  "message": "Speaker Diarization & Transcription API",
163
  "version": "1.0.0",
 
170
 
171
  @app.get("/health")
172
  async def health_check():
 
173
  return {
174
  "status": "healthy",
175
  "device": str(device),
 
178
 
179
  @app.post("/process")
180
  async def process_audio_endpoint(file: UploadFile = File(...)):
 
 
 
 
 
 
 
 
 
181
  if pipeline is None or whisper_model is None:
182
  raise HTTPException(status_code=503, detail="Models are still loading. Please try again in a moment.")
183
 
 
184
  allowed_extensions = {'.wav', '.mp3', '.m4a', '.flac', '.ogg', '.webm'}
185
  file_ext = Path(file.filename).suffix.lower()
186
 
 
190
  detail=f"Unsupported file type. Allowed: {', '.join(allowed_extensions)}"
191
  )
192
 
 
193
  with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
194
  try:
195
  content = await file.read()
196
  tmp_file.write(content)
197
  tmp_file_path = tmp_file.name
198
 
199
+ # Convert ANY format to WAV
200
+ wav_path = convert_to_wav(tmp_file_path)
201
+
202
+ # Process WAV only
203
+ result = process_audio(wav_path)
204
 
205
  return JSONResponse(content=result)
206
 
 
208
  raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
209
 
210
  finally:
 
211
  if os.path.exists(tmp_file_path):
212
  os.unlink(tmp_file_path)
213
+ if os.path.exists(tmp_file_path + "_converted.wav"):
214
+ os.unlink(tmp_file_path + "_converted.wav")
215
 
216
  if __name__ == "__main__":
217
  import uvicorn
218
  uvicorn.run(app, host="0.0.0.0", port=7860)