cheesecz commited on
Commit
ceb7caa
·
verified ·
1 Parent(s): 61c6e8a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -0
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import time
4
+ import asyncio
5
+ from typing import List, Dict, Any, Optional
6
+ from concurrent.futures import ThreadPoolExecutor
7
+
8
+ import torch
9
+ from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel
12
+ import uvicorn
13
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
14
+ import librosa
15
+ import numpy as np
16
+ from fastapi.responses import JSONResponse
17
+ import gc
18
+
19
# Initialize thread pool for background processing.
# Two workers only: transcription is model-bound, so more threads would
# mostly contend for the same GPU/CPU rather than add throughput.
thread_pool = ThreadPoolExecutor(max_workers=2)

# Environment and model configuration
MODEL_NAME = "nyrahealth/CrisperWhisper"  # Hugging Face hub id of the served model
BATCH_SIZE = 8  # NOTE(review): defined but never used in this file — dead config?
FILE_LIMIT_MB = 30  # upload size cap enforced by /transcribe
FILE_EXTENSIONS = [".mp3", ".wav", ".m4a", ".ogg", ".flac"]  # accepted upload suffixes
27
+
28
# Initialize FastAPI app
app = FastAPI(
    title="Speech to Text API",
    description="API for transcribing audio files using the CrisperWhisper model",
    version="1.0.0"
)

# Add CORS support.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# the wide-open configuration — acceptable for a public demo, but should be
# restricted to known origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
43
+
44
# Response models
class TranscriptionChunk(BaseModel):
    """A single transcribed word and its time span within the audio."""
    timestamp: List[float]  # [start_seconds, end_seconds], rounded to 2 decimals
    text: str  # the word, whitespace-stripped

class TranscriptionResponse(BaseModel):
    """Full transcript plus per-word timing chunks returned by /transcribe."""
    text: str  # complete transcript with special tokens removed
    chunks: List[TranscriptionChunk]  # one entry per decoded word
52
+
53
# Setup device and load model.
# Device string is used throughout the file to move model and inputs;
# prefers CUDA when a GPU is visible, otherwise falls back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
56
+
57
# Load model and processor once, at application startup
@app.on_event("startup")
async def load_model():
    """Download and initialize the CrisperWhisper processor and model.

    Populates the module-level ``processor`` and ``model`` globals used by
    process_audio_file, and moves the model onto the selected device.
    """
    global processor, model
    print("Loading model and processor...")
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    # nn.Module.to() returns the module itself, so loading and placement chain.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME).to(device)
    print("Model loaded successfully!")
66
+
67
def load_audio(file_path: str) -> tuple:
    """Decode an audio file into a mono float32 array at 16 kHz.

    Returns:
        (audio_array, sampling_rate) — sampling_rate is always 16000 after
        resampling, or the native rate when it is already 16 kHz.

    Raises:
        HTTPException(500): on any decode/resample failure.
    """
    target_sr = 16000
    try:
        # Decode at the file's native rate first, so the resampling cost is
        # only paid when the source is not already 16 kHz.
        samples, native_sr = librosa.load(file_path, sr=None, mono=True)

        if native_sr == target_sr:
            rate = native_sr
        else:
            samples = librosa.resample(samples, orig_sr=native_sr, target_sr=target_sr)
            rate = target_sr

        # The model expects float32 input; copy=False keeps this a no-op
        # when librosa already produced float32.
        samples = samples.astype(np.float32, copy=False)

        return samples, rate

    except Exception as e:
        print(f"Error loading audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error loading audio: {str(e)}")
90
+
91
def process_audio_file(file_path: str) -> Dict:
    """Process audio file and return transcription with timestamps.

    Runs synchronously (invoked from a worker thread by
    process_in_background). Reads the module-level ``processor``/``model``
    globals populated at startup.

    Returns:
        Dict with keys "text" (full transcript, special tokens stripped)
        and "chunks" (list of {"text", "timestamp": [start, end]} per word).

    Raises:
        HTTPException(500): on any load/inference/decode failure.
    """
    try:
        # Load audio file efficiently (mono float32, 16 kHz — see load_audio)
        audio_array, sampling_rate = load_audio(file_path)

        # Process with model
        inputs = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt")
        inputs = {key: value.to(device) for key, value in inputs.items()}

        # Generate transcription with word timestamps
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                return_timestamps=True,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=256 if len(audio_array) < 160000 else 512, # Adjust based on audio length (160000 samples = 10 s at 16 kHz)
                num_beams=1, # Use greedy decoding for speed
            )

        # Extract timestamps and words.
        # NOTE(review): output_word_offsets is a Wav2Vec2-style decode kwarg;
        # confirm this Whisper-family processor's decode() really returns an
        # object exposing .word_offsets — if it returns a plain string, the
        # loop below will fail.
        result = processor.decode(outputs.sequences[0], skip_special_tokens=False, output_word_offsets=True)
        words_with_timestamps = []

        for word in result.word_offsets:
            words_with_timestamps.append({
                "text": word["word"].strip(),
                "timestamp": [
                    # NOTE(review): assumes start/end offsets are raw sample
                    # indices (hence division by the sampling rate) — verify
                    # the unit against the processor's documentation.
                    round(word["start_offset"] / sampling_rate, 2),
                    round(word["end_offset"] / sampling_rate, 2)
                ]
            })

        # Create final response format
        # (second decode pass produces the clean transcript without special tokens)
        response_data = {
            "text": processor.decode(outputs.sequences[0], skip_special_tokens=True),
            "chunks": words_with_timestamps
        }

        # Manual garbage collection to free memory between requests
        del inputs, outputs, result
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

        return response_data

    except Exception as e:
        print(f"Error processing audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
142
+
143
async def process_in_background(file_path: str):
    """Run process_audio_file in the shared thread pool without blocking.

    Keeps the event loop responsive while the model works.

    Fix: use asyncio.get_running_loop() instead of asyncio.get_event_loop();
    the latter is deprecated inside coroutines since Python 3.10 and can
    return the wrong loop outside the main thread.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(thread_pool, process_audio_file, file_path)
147
+
148
@app.post("/transcribe", response_model=TranscriptionResponse)
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an audio file to text with timestamps for each word.

    Accepts .mp3, .wav, .m4a, .ogg or .flac files up to 30MB.

    Raises:
        HTTPException(400): unsupported extension or oversized upload.
        HTTPException(500): propagated from audio loading/processing.
    """
    start_time = time.time()

    # Validate file extension (filename may be None for some clients).
    file_ext = os.path.splitext(file.filename or "")[1].lower()
    if file_ext not in FILE_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format. Supported formats: {', '.join(FILE_EXTENSIONS)}"
        )

    # Read and size-check the upload BEFORE creating the temp file.
    # Fix: the original raised the 413-style error from inside the
    # NamedTemporaryFile(delete=False) block, leaking an orphaned temp file
    # on every rejected oversized upload.
    content = await file.read()
    if len(content) > FILE_LIMIT_MB * 1024 * 1024:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size: {FILE_LIMIT_MB}MB"
        )

    # Persist the validated upload so the decoder can read it by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
        temp_file.write(content)
        temp_file_path = temp_file.name

    try:
        # Process the audio file in a worker thread to keep the loop free.
        result = await process_in_background(temp_file_path)
        processing_time = time.time() - start_time
        print(f"Processing completed in {processing_time:.2f} seconds")

        return JSONResponse(content=result)

    finally:
        # Clean up the temp file even when processing raised.
        if os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except Exception as e:
                print(f"Error deleting temp file: {e}")
196
+
197
@app.get("/health")
async def health_check():
    """Liveness probe — always reports the service as healthy."""
    return dict(status="healthy")
201
+
202
# Simple root endpoint that shows API is running
@app.get("/")
async def root():
    """Describe the service: available endpoints, model id, and device."""
    endpoints = {
        "transcribe": "/transcribe (POST)",
        "health": "/health (GET)",
        "docs": "/docs (GET)",
    }
    return {
        "message": "Speech-to-Text API is running",
        "endpoints": endpoints,
        "model": MODEL_NAME,
        "device": device,
    }
215
+
216
if __name__ == "__main__":
    # Default 7860 matches the port Hugging Face Spaces exposes; $PORT wins.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port)