fred1012 committed
Commit 224a8a5 · verified · 1 Parent(s): 6ab5729

Upload 2 files

Files changed (2)
  1. app.py +231 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,231 @@
+ import os
+ import json
+ import torch
+ import librosa
+ import numpy as np
+ from pathlib import Path
+ from typing import Dict, Optional, List
+ from datetime import datetime
+
+ from fastapi import FastAPI, File, UploadFile, HTTPException
+ from pydantic import BaseModel
+ from transformers import pipeline
+
+ # --- Configuration ---
+ WHISPER_MODEL = os.getenv("WHISPER_MODEL", "small")  # small, medium, large
+ WHISPER_PORT = int(os.getenv("WHISPER_PORT", 8000))
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
+
+ # Global model cache
+ _whisper_pipeline = None
+ _model_info = {
+     "model_name": WHISPER_MODEL,
+     "device": DEVICE,
+     "dtype": str(TORCH_DTYPE),
+     "cuda_available": torch.cuda.is_available()
+ }
+
+ # --- Models ---
+ class TranscriptionResponse(BaseModel):
+     text: str
+     language: str = "en"
+     confidence: Optional[float] = None
+     duration: float = 0.0
+     timestamp: str = ""
+
+ # --- Utility Functions ---
+ def get_whisper_pipeline():
+     """Get or initialize the Whisper pipeline (cached)."""
+     global _whisper_pipeline
+     if _whisper_pipeline is not None:
+         return _whisper_pipeline
+
+     print(f"🔄 Loading Whisper model: {WHISPER_MODEL} on {DEVICE} with dtype {TORCH_DTYPE}")
+
+     _whisper_pipeline = pipeline(
+         "automatic-speech-recognition",
+         model=f"openai/whisper-{WHISPER_MODEL}",
+         device=DEVICE,
+         torch_dtype=TORCH_DTYPE
+     )
+
+     print("✅ Whisper model loaded successfully")
+     return _whisper_pipeline
+
+ def load_and_resample_audio(audio_path: str, target_sr: int = 16000) -> tuple:
+     """Load an audio file and resample it to 16 kHz (required by Whisper)."""
+     try:
+         # Load audio file with librosa (no ffmpeg needed)
+         audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
+         duration = librosa.get_duration(y=audio, sr=sr)
+         print(f"📁 Loaded audio: {Path(audio_path).name} | Duration: {duration:.2f}s | SR: {sr}Hz")
+         return audio, sr, duration
+     except Exception as e:
+         print(f"❌ Error loading audio: {e}")
+         raise
+
+ async def transcribe_audio(audio_path: str) -> Dict:
+     """Transcribe an audio file using Whisper."""
+     try:
+         # Load audio
+         audio, sr, duration = load_and_resample_audio(audio_path)
+
+         # Get pipeline
+         pipeline_model = get_whisper_pipeline()
+
+         print(f"🎤 Transcribing {Path(audio_path).name}...")
+
+         # Transcribe
+         result = pipeline_model(
+             audio,
+             chunk_length_s=30,
+             stride_length_s=(4, 2),
+             batch_size=24 if torch.cuda.is_available() else 4
+         )
+
+         print("✅ Transcription complete")
+
+         return {
+             "text": result.get("text", "").strip(),
+             "language": "en",  # Whisper doesn't return language detection reliably
+             "confidence": None,  # Whisper doesn't provide per-segment confidence
+             "duration": duration,
+             "timestamp": datetime.now().isoformat()
+         }
+
+     except Exception as e:
+         print(f"❌ Transcription error: {e}")
+         raise
+
+ # --- FastAPI App ---
+ app = FastAPI(
+     title="Whisper Transcription Server",
+     description="FastAPI server for audio transcription using OpenAI Whisper",
+     version="1.0.0"
+ )
+
+ @app.on_event("startup")
+ async def startup():
+     print(f"🚀 Whisper Server starting on port {WHISPER_PORT}")
+     print("📊 Configuration:")
+     print(f"  - Model: {WHISPER_MODEL}")
+     print(f"  - Device: {DEVICE}")
+     print(f"  - CUDA Available: {torch.cuda.is_available()}")
+     print(f"  - Torch Dtype: {TORCH_DTYPE}")
+
+     # Pre-load the model so the first request doesn't pay the load cost
+     get_whisper_pipeline()
+
+ @app.get("/health")
+ async def health_check():
+     """Check server health and model status."""
+     return {
+         "status": "healthy",
+         "model_info": _model_info,
+         "cuda_available": torch.cuda.is_available(),
+         "device": DEVICE
+     }
+
+ @app.get("/")
+ async def root():
+     """Root endpoint with server info."""
+     return {
+         "server": "Whisper Transcription Backend",
+         "model": WHISPER_MODEL,
+         "device": DEVICE,
+         "endpoints": {
+             "/health": "Server health check",
+             "/transcribe": "POST - Transcribe audio file",
+             "/transcribe_file": "POST - Alternative transcribe endpoint"
+         }
+     }
+
+ @app.post("/transcribe")
+ async def transcribe(file: UploadFile = File(...)):
+     """
+     Transcribe an uploaded audio file.
+     Accepts: mp3, wav, m4a, flac, ogg, aac
+     """
+     if not file.filename:
+         raise HTTPException(status_code=400, detail="No file provided")
+
+     # Check the file extension
+     allowed_extensions = {'.mp3', '.wav', '.m4a', '.flac', '.ogg', '.aac'}
+     file_ext = Path(file.filename).suffix.lower()
+
+     if file_ext not in allowed_extensions:
+         raise HTTPException(
+             status_code=400,
+             detail=f"Unsupported file format: {file_ext}. Allowed: {allowed_extensions}"
+         )
+
+     temp_file = None
+     try:
+         # Save the uploaded file temporarily
+         temp_path = Path(f"temp_{file.filename}")
+         with open(temp_path, 'wb') as f:
+             content = await file.read()
+             f.write(content)
+
+         temp_file = temp_path
+
+         print(f"📤 Processing uploaded file: {file.filename} ({len(content)} bytes)")
+
+         # Transcribe
+         result = await transcribe_audio(str(temp_path))
+
+         return {
+             "audio_file": file.filename,
+             "text": result["text"],
+             "language": result["language"],
+             "duration": result["duration"],
+             "timestamp": result["timestamp"]
+         }
+
+     except Exception as e:
+         print(f"❌ Transcription failed: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+     finally:
+         # Cleanup
+         if temp_file and temp_file.exists():
+             temp_file.unlink()
+             print(f"🧹 Cleaned up temp file: {temp_file}")
+
+ @app.post("/transcribe_file")
+ async def transcribe_file(file: UploadFile = File(...)):
+     """Alternative endpoint name for transcription."""
+     return await transcribe(file)
+
+ @app.post("/transcribe_batch")
+ async def transcribe_batch(files: List[UploadFile] = File(...)):
+     """
+     Transcribe multiple audio files one after another.
+     """
+     if not files:
+         raise HTTPException(status_code=400, detail="No files provided")
+
+     results = []
+     for file in files:
+         try:
+             result = await transcribe(file)
+             results.append({
+                 "status": "success",
+                 "data": result
+             })
+         except Exception as e:
+             results.append({
+                 "status": "error",
+                 "filename": file.filename,
+                 "error": str(e)
+             })
+
+     return {
+         "total": len(files),
+         "results": results
+     }
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=WHISPER_PORT)
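
Once the server is up, the /transcribe endpoint can be exercised with a minimal client sketch like the one below. It assumes the server is reachable on localhost at the default port (WHISPER_PORT=8000); "sample.wav" is a hypothetical placeholder for any file in the accepted formats.

# Minimal client sketch for POST /transcribe (assumes localhost:8000 and a
# placeholder file "sample.wav"; the multipart field must be named "file").
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/transcribe",
        files={"file": ("sample.wav", f, "audio/wav")},
    )
resp.raise_for_status()
body = resp.json()
print(body["text"], f"({body['duration']:.1f}s)")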
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ transformers==4.48.0
+ timm
+ einops
+ pillow
+ hf_transfer
+ torch
+ # Runtime dependencies imported by app.py (multipart uploads need python-multipart)
+ fastapi
+ uvicorn
+ python-multipart
+ librosa
+ numpy
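
For multiple files, the /transcribe_batch endpoint accepts several uploads in one request. A sketch under the same assumptions (default port, hypothetical placeholder filenames) could look like this; each entry in "results" carries either the transcription data or the per-file error.

# Minimal client sketch for POST /transcribe_batch; the repeated multipart
# fields must all be named "files". Filenames below are placeholders.
import requests

paths = ["first.wav", "second.mp3"]
uploads = [("files", (p, open(p, "rb"))) for p in paths]
try:
    resp = requests.post("http://localhost:8000/transcribe_batch", files=uploads)
    resp.raise_for_status()
finally:
    for _, (_, fh) in uploads:
        fh.close()

for entry in resp.json()["results"]:
    if entry["status"] == "success":
        print(entry["data"]["audio_file"], "->", entry["data"]["text"])
    else:
        print(entry["filename"], "failed:", entry["error"])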