drrobot9 commited on
Commit
580e60e
·
verified ·
1 Parent(s): 0b0e60e

Initial commit

Browse files
Files changed (4) hide show
  1. Dockerfile +24 -0
  2. app/config.json +15 -0
  3. app/main.py +413 -0
  4. requirements.txt +13 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

# Install comprehensive audio/video libraries
# (ffmpeg/sox/libsndfile back the pydub and torchaudio decoding paths used by main.py;
#  libav*-dev and libgl1 look like build/runtime extras — confirm they are actually needed)
RUN apt-get update && apt-get install -y \
    ffmpeg \
    sox \
    libsndfile1 \
    libavcodec-extra \
    libavformat-dev \
    libavutil-dev \
    libavdevice-dev \
    libgl1 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy requirements first so the pip layer is cached across code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 7860

# NOTE(review): this repo keeps the code at app/main.py, but "main:app" is resolved
# from /app — confirm the build context actually places main.py at /app/main.py.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
app/config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "name": "facebook/mms-1b-all",
4
+ "device": "cuda"
5
+
6
+ },
7
+ "api": {
8
+ "text_model_url": "https://remostart-farmlingua-ai-conversational.hf.space/ask",
9
+ "timeout_sec": 3000
10
+ },
11
+ "server": {
12
+ "host": "0.0.0.0",
13
+ "port": 7860
14
+ }
15
+ }
app/main.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ import json
3
+ import torch
4
+ import torchaudio
5
+ import requests
6
+ import numpy as np
7
+ import tempfile
8
+ import os
9
+ import logging
10
+ from typing import Optional, Dict, Any
11
+
12
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Request
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from fastapi.responses import JSONResponse
15
+ from transformers import AutoProcessor, AutoModelForCTC
16
+ from pydantic import BaseModel
17
+ from pydantic_settings import BaseSettings
18
+ from pydub import AudioSegment
19
+ from contextlib import asynccontextmanager
20
+
21
+
22
# Root logging configuration: timestamped records for this and all child loggers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
27
+
28
+
29
+
30
class Settings(BaseSettings):
    """Runtime configuration, overridable via STT_-prefixed environment variables or a .env file.

    NOTE(review): app/config.json in this repo mirrors several of these values
    but is never loaded by this module — confirm which source is authoritative.
    """

    # Hugging Face model id for the CTC speech model.
    model_name: str = "facebook/mms-1b-all"
    # Evaluated once at import time; falls back to CPU when CUDA is unavailable.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Downstream conversational text model that transcripts are forwarded to.
    text_model_url: str = "https://remostart-farmlingua-ai-conversational.hf.space/ask"
    # NOTE(review): 3000 seconds (~50 min) — confirm this isn't meant to be 30.
    timeout_sec: int = 3000

    # Audio processing settings
    sample_rate: int = 16000          # model's expected sampling rate (Hz)
    max_audio_seconds: int = 300      # hard cap on accepted clip length
    chunk_seconds: int = 20           # transcription window length
    overlap_seconds: int = 2          # overlap between consecutive windows

    # Server settings
    host: str = "0.0.0.0"
    port: int = 7860
    workers: int = 1

    # CORS settings (defaults are wide open — tighten for production)
    cors_origins: list = ["*"]
    cors_methods: list = ["*"]
    cors_headers: list = ["*"]

    class Config:
        env_file = ".env"
        env_prefix = "STT_"
58
+
59
+
60
+
61
settings = Settings()

# Calculate derived constants (all expressed in samples at settings.sample_rate)
MAX_SAMPLES = settings.sample_rate * settings.max_audio_seconds  # reject clips longer than this
CHUNK_SIZE = settings.chunk_seconds * settings.sample_rate       # window fed to the model
OVERLAP = settings.overlap_seconds * settings.sample_rate        # samples shared by adjacent windows
STEP = CHUNK_SIZE - OVERLAP                                      # hop between window starts
68
+
69
+
70
+
71
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the CTC model/processor on startup, release them on shutdown."""
    # Startup
    logger.info(f"Starting STT service with device: {settings.device}")
    logger.info(f"Loading model: {settings.model_name}")

    try:
        # Initialize processor and model; stored on app.state so request handlers can reach them.
        app.state.processor = AutoProcessor.from_pretrained(settings.model_name)
        app.state.model = AutoModelForCTC.from_pretrained(settings.model_name).to(settings.device)
        app.state.model.eval()  # inference mode (disables dropout etc.)
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise  # fail fast — the service is useless without the model

    yield

    # Shutdown: drop the model reference and return cached CUDA memory to the driver.
    logger.info("Shutting down STT service")
    if hasattr(app.state, 'model'):
        del app.state.model
        torch.cuda.empty_cache()
94
+
95
+
96
+
97
# FastAPI application; model loading/unloading is handled by the lifespan context above.
app = FastAPI(
    title="Universal Audio STT",
    version="1.5.0",
    description="Speech-to-Text service with support for multiple audio formats",
    lifespan=lifespan
)

# CORS from Settings. NOTE(review): allow_credentials=True combined with wildcard
# origins is disallowed by the CORS spec — confirm the intended origin policy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=settings.cors_methods,
    allow_headers=settings.cors_headers,
)
112
+
113
+
114
class STTResponse(BaseModel):
    """Response payload for the /stt endpoint."""

    transcript: str  # transcribed text; "" on failure
    downstream_response: Optional[Dict[str, Any]] = None  # JSON reply from the text model, if forwarding succeeded
    error: Optional[str] = None  # human-readable failure description; None on success
    processing_time_ms: Optional[float] = None  # wall-clock duration of the request
119
+
120
+
121
class HealthResponse(BaseModel):
    """Response payload for the /health endpoint."""

    status: str                # "healthy" when the service is up
    device: str                # inference device ("cuda" or "cpu")
    model: str                 # configured model name
    max_audio_seconds: int     # configured clip-length cap
    uptime: Optional[float] = None  # seconds since startup, when known
127
+
128
+
129
+
130
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log any uncaught exception and answer 500 in the STTResponse shape."""
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    payload = {
        "transcript": "",
        "downstream_response": None,
        "error": f"Internal server error: {str(exc)}"
    }
    return JSONResponse(status_code=500, content=payload)
141
+
142
+
143
+ def load_audio_safe(file_bytes: bytes) -> tuple[np.ndarray | None, str | None]:
144
+ """Load audio file using pydub (supports more formats) with torchaudio fallback."""
145
+ if not file_bytes:
146
+ return None, "Empty audio file"
147
+
148
+ if len(file_bytes) == 0:
149
+ return None, "Empty audio file"
150
+
151
+
152
+ with tempfile.NamedTemporaryFile(suffix='.audio', delete=False) as tmp:
153
+ tmp.write(file_bytes)
154
+ tmp_path = tmp.name
155
+
156
+ try:
157
+
158
+ try:
159
+ audio = AudioSegment.from_file(tmp_path)
160
+
161
+
162
+ if len(audio) == 0:
163
+ return None, "Audio file is empty"
164
+
165
+
166
+ if audio.channels > 1:
167
+ audio = audio.set_channels(1)
168
+
169
+
170
+ if audio.frame_rate != settings.sample_rate:
171
+ audio = audio.set_frame_rate(settings.sample_rate)
172
+
173
+
174
+ samples = np.array(audio.get_array_of_samples()).astype(np.float32)
175
+
176
+
177
+ if audio.sample_width == 1: # 8-bit
178
+ samples = samples / 127.5 - 1.0
179
+ elif audio.sample_width == 2: # 16-bit
180
+ samples = samples / 32768.0
181
+ elif audio.sample_width == 3: # 24-bit
182
+ samples = samples / 8388608.0
183
+ elif audio.sample_width == 4: # 32-bit
184
+ samples = samples / 2147483648.0
185
+ else:
186
+ if len(samples) > 0:
187
+ max_val = np.max(np.abs(samples))
188
+ if max_val > 0:
189
+ samples = samples / max_val
190
+
191
+ except Exception as pydub_error:
192
+ logger.warning(f"Pydub failed, trying torchaudio: {pydub_error}")
193
+
194
+ try:
195
+ waveform, sr = torchaudio.load(tmp_path)
196
+
197
+ if waveform.numel() == 0:
198
+ return None, "Audio contains no samples"
199
+
200
+ waveform = waveform.mean(dim=0)
201
+
202
+ if sr != settings.sample_rate:
203
+ waveform = torchaudio.functional.resample(
204
+ waveform, orig_freq=sr, new_freq=settings.sample_rate
205
+ )
206
+
207
+ samples = waveform.numpy()
208
+
209
+ except Exception as torchaudio_error:
210
+ logger.error(f"Both pydub and torchaudio failed: {torchaudio_error}")
211
+ return None, f"Unsupported audio format. Supported formats: MP3, WAV, M4A, FLAC, OGG, etc."
212
+
213
+ except Exception as e:
214
+ logger.error(f"Failed to load audio: {str(e)}")
215
+ return None, f"Failed to process audio file: {str(e)}"
216
+ finally:
217
+
218
+ try:
219
+ os.unlink(tmp_path)
220
+ except:
221
+ pass
222
+
223
+ if len(samples) == 0:
224
+ return None, "Audio contains no samples"
225
+
226
+ if len(samples) > MAX_SAMPLES:
227
+ return None, f"Audio exceeds {settings.max_audio_seconds // 60} minute limit ({settings.max_audio_seconds} seconds)"
228
+
229
+ return samples, None
230
+
231
+
232
+ def chunk_audio(audio: np.ndarray):
233
+ """Split audio into overlapping chunks for processing."""
234
+ for start in range(0, len(audio), STEP):
235
+ chunk = audio[start:start + CHUNK_SIZE]
236
+ if len(chunk) < settings.sample_rate: # Less than 1 second
237
+ break
238
+ yield chunk
239
+
240
+
241
def transcribe_chunk(chunk: np.ndarray, processor, model) -> str:
    """Run one audio window through the CTC model and return the decoded text."""
    features = processor(
        chunk,
        sampling_rate=settings.sample_rate,
        return_tensors="pt",
        padding=True,
    )
    input_values = features.input_values.to(settings.device)

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_values).logits

    token_ids = torch.argmax(logits, dim=-1)
    decoded = processor.batch_decode(token_ids)
    return decoded[0].strip()
255
+
256
+
257
def transcribe_long(audio: np.ndarray, processor, model) -> str:
    """Transcribe arbitrarily long audio by joining per-chunk transcripts with spaces."""
    pieces = (
        transcribe_chunk(window, processor, model)
        for window in chunk_audio(audio)
    )
    # Drop empty per-chunk results so we don't emit doubled spaces.
    return " ".join(piece for piece in pieces if piece)
265
+
266
+
267
def forward_to_text_model(text: str) -> Optional[Dict[str, Any]]:
    """POST the transcript to the downstream text model.

    Returns the model's JSON reply, or None for empty input or any
    network/HTTP failure (failures are logged, never raised).
    """
    # Nothing to forward for empty / whitespace-only transcripts.
    if not text or not text.strip():
        return None

    try:
        reply = requests.post(
            settings.text_model_url,
            json={"query": text},
            timeout=settings.timeout_sec,
        )
        reply.raise_for_status()
        return reply.json()
    except requests.exceptions.Timeout:
        logger.warning("Downstream text model timeout")
        return None
    except requests.exceptions.RequestException as e:
        logger.warning(f"Downstream text model error: {str(e)}")
        return None
286
+
287
+
288
@app.post("/stt", response_model=STTResponse)
async def stt(audio: UploadFile = File(...)):
    """
    Speech-to-Text endpoint.

    Accepts audio files in various formats (MP3, WAV, M4A, FLAC, OGG, etc.)
    and returns transcribed text. Failures are reported in the response body
    (transcript="" plus an error string) rather than as HTTP errors, except
    when the upload itself cannot be read (400).
    """
    import time
    start_time = time.time()

    # Validate file
    # NOTE(review): a suspicious content type is only logged, never rejected —
    # the decode step below is the real gate. Confirm this is intentional.
    if not audio.content_type or not audio.content_type.startswith('audio/'):
        logger.warning(f"Invalid content type: {audio.content_type}")

    try:
        audio_bytes = await audio.read()
        logger.info(f"Received audio file: {audio.filename}, size: {len(audio_bytes)} bytes")
    except Exception as e:
        logger.error(f"Failed to read audio file: {str(e)}")
        raise HTTPException(status_code=400, detail="Failed to read audio file")

    # Load audio (decode + mono + resample); error is a user-facing message.
    audio_data, error = load_audio_safe(audio_bytes)

    if error:
        logger.warning(f"Audio loading failed: {error}")
        return STTResponse(
            transcript="",
            downstream_response=None,
            error=error,
            processing_time_ms=(time.time() - start_time) * 1000
        )

    # Transcribe using the model/processor loaded by the lifespan handler.
    try:
        transcript = transcribe_long(
            audio_data,
            app.state.processor,
            app.state.model
        )
        logger.info(f"Transcription successful, length: {len(transcript)} chars")
    except Exception as e:
        logger.error(f"Transcription failed: {str(e)}")
        return STTResponse(
            transcript="",
            downstream_response=None,
            error=f"Transcription failed: {str(e)}",
            processing_time_ms=(time.time() - start_time) * 1000
        )

    # Best-effort forwarding to the downstream text model — a failure here
    # must not fail the request; the transcript is still returned.
    downstream = None
    if transcript and transcript.strip():
        try:
            downstream = forward_to_text_model(transcript)
        except Exception as e:
            logger.warning(f"Downstream processing failed: {str(e)}")

    processing_time_ms = (time.time() - start_time) * 1000
    logger.info(f"Request completed in {processing_time_ms:.2f}ms")

    return STTResponse(
        transcript=transcript,
        downstream_response=downstream,
        error=None,
        processing_time_ms=processing_time_ms
    )
356
+
357
+
358
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Returns service status and configuration.

    Fixes vs. previous version: removed the unused ``import psutil`` (it pulled
    a dependency for nothing), and guarded against ``app.state.start_time``
    being present but None — the old ``hasattr`` check let ``time.time() - None``
    raise a TypeError and turn /health into a 500.
    """
    import time

    # start_time is set by the startup hook; report None uptime if it's missing or unset.
    start = getattr(app.state, 'start_time', None)
    uptime = time.time() - start if start is not None else None

    return HealthResponse(
        status="healthy",
        device=settings.device,
        model=settings.model_name,
        max_audio_seconds=settings.max_audio_seconds,
        uptime=uptime
    )
378
+
379
+
380
@app.get("/config")
async def get_config():
    """
    Get current configuration (excluding sensitive data).
    """
    # Expose only non-sensitive, read-only settings.
    return dict(
        model=settings.model_name,
        device=settings.device,
        sample_rate=settings.sample_rate,
        max_audio_seconds=settings.max_audio_seconds,
        chunk_seconds=settings.chunk_seconds,
        overlap_seconds=settings.overlap_seconds,
        cors_enabled=True,
    )
394
+
395
+
396
+
397
@app.on_event("startup")
async def startup_event():
    """Record service start time so /health can report uptime."""
    import time  # local import matches the style used in the other handlers

    # Bug fix: the old `time.time() if 'time' in locals() else None` always
    # yielded None — `time` was never in the function's locals at that point —
    # so uptime could never be reported.
    app.state.start_time = time.time()
    logger.info("STT service started successfully")
401
+
402
+
403
if __name__ == "__main__":
    import uvicorn

    # Direct-run entry point; the Docker image launches uvicorn via CMD instead.
    logger.info(f"Starting server on {settings.host}:{settings.port}")
    uvicorn.run(
        "main:app",
        host=settings.host,
        port=settings.port,
        workers=settings.workers,
        log_level="info"
    )
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ torch
4
+ transformers
5
+ soundfile
6
+ requests
7
+ python-multipart
8
+ torchaudio
9
+ pydub
10
+ numpy
11
+ pydantic[email]
12
+ pydantic-settings
13
+ psutil