|
|
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Query |
|
|
from fastapi.responses import JSONResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
import whisper |
|
|
import torch |
|
|
import tempfile |
|
|
import os |
|
|
import uvicorn |
|
|
import logging |
|
|
import hashlib |
|
|
import json |
|
|
import sqlite3 |
|
|
from datetime import datetime, timedelta |
|
|
from typing import Optional, Dict, Any |
|
|
from contextlib import asynccontextmanager |
|
|
import asyncio |
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
|
|
|
|
|
# Root logging: INFO and above go both to app.log and to the console (stderr).
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_handlers = [logging.FileHandler("app.log"), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_log_handlers)

# Module logger used by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Shared runtime state. Everything except `device` stays None until lifespan()
# populates it at application startup.
whisper_model = None  # Whisper model instance, loaded in lifespan()

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

executor: Optional[ThreadPoolExecutor] = None  # thread pool for blocking transcribe() calls
processing_semaphore: Optional[asyncio.Semaphore] = None  # caps concurrent transcriptions
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load shared resources at startup and release them at shutdown.

    Startup loads the Whisper "large-v3" model onto ``device``. It then creates
    the thread pool used for the blocking transcription calls, sizes the
    processing semaphore (5 slots on CUDA, 3 on CPU) and initialises the cache
    database.

    Shutdown runs in a ``finally`` block. This shuts down the executor and
    empties the CUDA cache even when the application exits because of an error.

    Raises:
        RuntimeError: if the Whisper model cannot be loaded. The original
            exception is chained as ``__cause__``.
    """
    global whisper_model, executor, processing_semaphore

    try:
        logger.info(f"Loading Whisper model on {device}")
        whisper_model = whisper.load_model("large-v3", device=device)
        logger.info("Whisper model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load Whisper model: {e}")
        raise RuntimeError(f"Whisper model loading failed: {e}") from e

    # Thread pool that runs the blocking whisper_model.transcribe() calls.
    max_workers = min(4, (os.cpu_count() or 1))
    executor = ThreadPoolExecutor(max_workers=max_workers)
    logger.info(f"ThreadPoolExecutor initialized with {max_workers} workers")

    try:
        # NOTE: on CUDA the semaphore (5) can exceed max_workers (<= 4), so an
        # admitted job may still wait in the executor's queue.
        max_concurrent = 5 if device == "cuda" else 3
        processing_semaphore = asyncio.Semaphore(max_concurrent)
        logger.info(f"Processing semaphore set to {max_concurrent} concurrent operations")

        init_cache_db()

        yield
    finally:
        # Without try/finally, an exception thrown into the generator at
        # `yield` would skip this cleanup entirely.
        logger.info("Shutting down application...")
        if executor:
            executor.shutdown(wait=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|
|
|
# ASGI application. Startup and shutdown are handled by the lifespan() context manager.
app = FastAPI(
    lifespan=lifespan,
    title="Whisper Transcription API",
    version="2.0.0",
    description="Scalable API for audio transcription using OpenAI Whisper",
)
|
|
|
|
|
|
|
|
# Cross-origin policy: any origin may call GET/POST/DELETE. Preflight responses
# may be cached by the browser for an hour.
# NOTE(review): Starlette's CORSMiddleware docs say allow_origins/allow_headers
# should not be ["*"] when allow_credentials=True. No cookie or session auth is
# visible in this file, so allow_credentials=False is probably sufficient.
# Confirm with the frontend before changing it.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "DELETE"],
    "allow_headers": ["*"],
    "max_age": 3600,
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
|
|
class _ClosingConnection(sqlite3.Connection):
    """SQLite connection whose ``with`` block also closes it.

    The stock ``sqlite3.Connection`` context manager only commits (or rolls
    back on error); it never closes the connection. As a result, every
    ``with sqlite3.connect(...) as conn:`` leaks a handle. This subclass keeps
    the commit/rollback behaviour and then closes the connection.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            return super().__exit__(exc_type, exc_value, traceback)
        finally:
            self.close()


class DatabaseManager:
    """Owns the SQLite cache database and hands out short-lived connections.

    There is no pooling. Each get_connection() call opens a fresh connection,
    and that connection is committed (or rolled back) and closed when its
    ``with`` block exits.
    """

    def __init__(self, db_path: str = 'transcription_cache.db'):
        """Remember *db_path* and make sure the schema exists."""
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the cache and processing_status tables and their indexes.

        Every statement uses IF NOT EXISTS, so repeated calls are harmless.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()

            # Finished transcriptions, keyed by calculate_file_hash().
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS cache (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_hash TEXT UNIQUE,
                    filename TEXT,
                    file_size INTEGER,
                    transcription TEXT,
                    language TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_accessed TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Progress of uploads handed to background processing.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS processing_status (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_hash TEXT UNIQUE,
                    filename TEXT,
                    file_size INTEGER,
                    status TEXT DEFAULT 'processing',
                    progress INTEGER DEFAULT 0,
                    estimated_time INTEGER DEFAULT 0,
                    started_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # NOTE: the UNIQUE constraints already give file_hash an implicit
            # index; the explicit hash indexes are kept for existing databases.
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_cache_hash ON cache(file_hash)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_cache_created ON cache(created_at)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_status_hash ON processing_status(file_hash)')

            conn.commit()

    def get_connection(self) -> sqlite3.Connection:
        """Open a new connection to the database.

        The object is a regular ``sqlite3.Connection``. When used as
        ``with mgr.get_connection() as conn:`` it commits (or rolls back on
        error) and is then closed.
        """
        return sqlite3.connect(self.db_path, factory=_ClosingConnection)
|
|
|
|
|
|
|
|
# Module-wide database handle. Constructing it creates the schema immediately,
# because DatabaseManager.__init__ calls _init_db().
db_manager = DatabaseManager(db_path='transcription_cache.db')
|
|
|
|
|
def init_cache_db():
    """Ensure the cache tables and indexes exist.

    Called from lifespan() at startup. The schema statements run by
    DatabaseManager._init_db() all use IF NOT EXISTS, so this is safe to
    repeat. It also recreates the schema if the database file was removed
    after the module was imported.
    """
    db_manager._init_db()
|
|
|
|
|
def calculate_file_hash(content: bytes, filename: str, file_size: int) -> str:
    """Return a hex digest that identifies an upload for caching and status lookups.

    The digest covers the *entire* file content plus the filename and sizes.
    Two uploads therefore share a cache entry only when they are byte-identical
    and have the same name. Hashing only the first and last KiB would let
    different recordings with identical headers and trailing silence collide,
    so each would be served the other's transcription.

    Args:
        content: raw upload bytes.
        filename: client-supplied file name.
        file_size: size of the upload in bytes.

    Returns:
        A lowercase hexadecimal SHA-256 digest.
    """
    digest = hashlib.sha256()
    digest.update(content)
    # NUL-separate the metadata so adjacent fields cannot run together.
    digest.update(f"\0{filename}\0{file_size}\0{len(content)}".encode())
    return digest.hexdigest()
|
|
|
|
|
def estimate_processing_time(file_size_mb: float) -> int:
    """Return a rough processing-time estimate, in whole minutes (at least 1).

    The model is 0.8 s per MB plus a load allowance of 0.1 s per MB, capped at
    2 s. That total is converted to minutes and truncated.

    NOTE(review): with these constants, any input below ~147.5 MB yields 1.
    """
    seconds = file_size_mb * 0.8 + min(2, file_size_mb * 0.1)
    minutes = int(seconds / 60)
    return minutes if minutes > 1 else 1
|
|
|
|
|
async def get_from_cache(file_hash: str) -> Optional[Dict[str, Any]]:
    """Return the cached transcription payload for *file_hash*, or None on a miss.

    On a hit, the row's ``last_accessed`` timestamp is refreshed. Rows are
    normally JSON objects written by save_to_cache(). Any value that is not a
    JSON object (for example a legacy plain-text row) is wrapped as
    ``{"text": value}``, so callers can always treat a hit as a dict. An empty
    stored value yields ``{}``.

    Database errors are logged and reported as a miss (None).
    """
    try:
        with db_manager.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                'SELECT transcription FROM cache WHERE file_hash = ?',
                (file_hash,)
            )
            row = cursor.fetchone()
            if row is None:
                return None
            cursor.execute(
                'UPDATE cache SET last_accessed = CURRENT_TIMESTAMP WHERE file_hash = ?',
                (file_hash,)
            )
            conn.commit()
    except Exception as e:
        logger.error(f"Error getting from cache: {e}")
        return None

    raw = row[0]
    if not raw:
        return {}
    try:
        data = json.loads(raw)
    except (ValueError, TypeError):
        # Not JSON (json.JSONDecodeError is a ValueError) or not str/bytes.
        return {"text": raw}
    return data if isinstance(data, dict) else {"text": raw}
|
|
|
|
|
async def save_to_cache(file_hash: str, filename: str, file_size: int, transcription: str, language: str = None):
    """Store *transcription* under *file_hash*, overwriting any existing row.

    INSERT OR REPLACE resets ``created_at``/``last_accessed`` to their defaults.
    Failures are logged and not raised.
    """
    row = (file_hash, filename, file_size, transcription, language)
    try:
        with db_manager.get_connection() as conn:
            conn.execute(
                '''INSERT OR REPLACE INTO cache
                (file_hash, filename, file_size, transcription, language)
                VALUES (?, ?, ?, ?, ?)''',
                row,
            )
            conn.commit()
    except Exception as e:
        logger.error(f"Error saving to cache: {e}")
|
|
|
|
|
async def get_processing_status(file_hash: str) -> Optional[Dict[str, Any]]:
    """Look up the processing_status row for *file_hash*.

    Returns a dict with ``status``, ``progress``, ``estimated_time`` and
    ``elapsed_minutes``. The last is the time since ``started_at``, truncated
    to whole minutes. Returns None when there is no row or when the query fails
    (failures are logged).
    """
    query = '''SELECT status, progress, estimated_time,
    (julianday('now') - julianday(started_at)) * 24 * 60 as elapsed_minutes
    FROM processing_status WHERE file_hash = ?'''
    try:
        with db_manager.get_connection() as conn:
            row = conn.execute(query, (file_hash,)).fetchone()
            if not row:
                return None
            status, progress, estimated_time, elapsed = row
            return {
                'status': status,
                'progress': progress,
                'estimated_time': estimated_time,
                'elapsed_minutes': int(elapsed or 0),
            }
    except Exception as e:
        logger.error(f"Error getting processing status: {e}")
        return None
|
|
|
|
|
async def update_processing_status(file_hash: str, status: str = None, progress: int = None, estimated_time: int = None):
    """Patch the processing_status row for *file_hash*.

    Only the supplied fields are changed: ``status`` when truthy, and
    ``progress`` / ``estimated_time`` when not None. ``updated_at`` is always
    refreshed. Column names come from a fixed list; values are bound
    parameters. Failures are logged and not raised.
    """
    try:
        candidates = (
            ("status", status, bool(status)),
            ("progress", progress, progress is not None),
            ("estimated_time", estimated_time, estimated_time is not None),
        )
        assignments = [f"{column} = ?" for column, _, present in candidates if present]
        values = [value for _, value, present in candidates if present]

        assignments.append("updated_at = CURRENT_TIMESTAMP")
        values.append(file_hash)

        query = f"UPDATE processing_status SET {', '.join(assignments)} WHERE file_hash = ?"

        with db_manager.get_connection() as conn:
            conn.execute(query, values)
            conn.commit()
    except Exception as e:
        logger.error(f"Error updating status: {e}")
|
|
|
|
|
async def add_processing_status(file_hash: str, filename: str, file_size: int, estimated_time: int):
    """Create, or reset, the processing_status row for *file_hash*.

    INSERT OR REPLACE replaces any existing row. As a result, status and
    progress go back to 'processing' and 0, and ``started_at`` takes its
    default again. Failures are logged and not raised.
    """
    params = (file_hash, filename, file_size, estimated_time)
    try:
        with db_manager.get_connection() as conn:
            conn.execute(
                '''INSERT OR REPLACE INTO processing_status
                (file_hash, filename, file_size, status, progress, estimated_time)
                VALUES (?, ?, ?, 'processing', 0, ?)''',
                params,
            )
            conn.commit()
    except Exception as e:
        logger.error(f"Error adding processing status: {e}")
|
|
|
|
|
async def remove_processing_status(file_hash: str):
    """Delete the processing_status row for *file_hash*, if present.

    Failures are logged and not raised.
    """
    try:
        with db_manager.get_connection() as conn:
            conn.execute('DELETE FROM processing_status WHERE file_hash = ?', (file_hash,))
            conn.commit()
    except Exception as e:
        logger.error(f"Error removing processing status: {e}")
|
|
|
|
|
async def background_transcription(file_path: str, file_hash: str, filename: str, file_size: int,
                                   language: Optional[str] = None):
    """Transcribe a queued upload in the background and cache the result.

    Steps:
        1. Wait for a slot on ``processing_semaphore``.
        2. Run Whisper in ``executor`` and store the JSON payload with
           save_to_cache().
        3. Move the processing_status row through progress 10 -> 90 -> 100
           ('completed'), or to 'error' if any step raises.

    The temp file at *file_path* belongs to this task and is deleted in every case.

    Args:
        file_path: temp file written by the request handler.
        file_hash: cache and status key for the upload.
        filename: original upload name, stored with the cache row.
        file_size: upload size in bytes.
        language: optional Whisper language code. The default None (what
            existing callers get) lets Whisper auto-detect.
    """
    async with processing_semaphore:
        try:
            logger.info(f"Starting background transcription for {filename}")
            await update_processing_status(file_hash, status='processing', progress=10)

            # transcribe() is blocking; run it off the event loop.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                executor,
                lambda: whisper_model.transcribe(
                    file_path,
                    fp16=(device != "cpu"),
                    language=language,
                    task="transcribe",
                    verbose=False,
                    word_timestamps=False
                )
            )

            await update_processing_status(file_hash, progress=90)

            text = result["text"].strip() or "No text detected"
            detected_language = result.get("language", "unknown")
            response_data = {
                "text": text,
                "language": detected_language,
                "from_cache": False
            }

            await save_to_cache(
                file_hash, filename, file_size,
                json.dumps(response_data), detected_language
            )

            await update_processing_status(file_hash, status='completed', progress=100)
            logger.info(f"Background transcription completed for {filename}")

        except Exception as e:
            # logger.exception keeps the traceback, which is needed to diagnose model failures.
            logger.exception(f"Error in background transcription: {e}")
            await update_processing_status(file_hash, status='error', progress=0)

        finally:
            try:
                if os.path.exists(file_path):
                    os.unlink(file_path)
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                logger.error(f"Error in cleanup: {e}")
|
|
|
|
|
@app.get("/")
async def root():
    """Report the device in use, CUDA availability and cache/queue counts.

    On a database error this logs and returns an ``error`` payload (still HTTP 200).
    """
    try:
        with db_manager.get_connection() as conn:
            cursor = conn.cursor()

            cursor.execute('SELECT COUNT(*) FROM cache')
            cache_count = cursor.fetchone()[0] or 0

            # Bind the literal instead of writing "processing" in double quotes.
            # Double quotes denote an identifier in SQL and only work through
            # SQLite's legacy DQS fallback, which some builds disable.
            cursor.execute('SELECT COUNT(*) FROM processing_status WHERE status = ?', ('processing',))
            processing_count = cursor.fetchone()[0] or 0

        return {
            "message": "Whisper API is running",
            "device": device,
            "cuda_available": torch.cuda.is_available(),
            "cached_files": cache_count,
            "currently_processing": processing_count
        }
    except Exception as e:
        logger.error(f"Error in root endpoint: {e}")
        return {"error": "Unable to retrieve system information"}
|
|
|
|
|
def _write_temp_file(contents: bytes, filename: str) -> str:
    """Write *contents* to a new temp file that keeps the upload's extension; return its path."""
    file_ext = os.path.splitext(filename)[1].lower() or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
        tmp_file.write(contents)
        return tmp_file.name


async def _transcribe_now(file_path: str, language: Optional[str]) -> Dict[str, Any]:
    """Run Whisper on *file_path* while holding a semaphore slot; return the response payload."""
    async with processing_semaphore:
        try:
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                executor,
                lambda: whisper_model.transcribe(
                    file_path,
                    fp16=(device != "cpu"),
                    language=language,
                    task="transcribe",
                    verbose=False,
                    word_timestamps=False
                )
            )
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return {
        "text": result["text"].strip() or "No text detected",
        "language": result.get("language", "unknown"),
        "from_cache": False
    }


@app.post("/transcribe")
async def transcribe_audio(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    language: Optional[str] = Query(None, description="Specify language code for better accuracy")
):
    """Transcribe an uploaded audio/video file, using the cache when possible.

    Flow:
        1. Validate the upload (non-empty, at most 50 MB).
        2. Return the cached result if one exists, or the current status if the
           same file is already being processed.
        3. Otherwise:
           - Server at capacity: queue a background job.
           - File under 5 MB: transcribe inline and return the text.
           - Larger file: start a background job and return its file_hash for
             polling /status/{file_hash}.

    NOTE(review): ``language`` is applied only on the inline (<5 MB) path, and
    it is not part of the cache key. A cached result is returned regardless of
    the requested language.

    Raises:
        HTTPException: 400 for a missing, empty or unreadable file; 413 when
            the file is too large; 500 when transcription or processing fails.
    """
    tmp_file_path = None
    # Set once a background task owns tmp_file_path and is responsible for deleting it.
    handed_off = False

    try:
        if not file or not file.filename:
            raise HTTPException(status_code=400, detail="No valid file provided")

        # Only warn on unexpected content types; clients often mislabel uploads.
        if not file.content_type or not (
            file.content_type.startswith('audio/') or
            file.content_type.startswith('video/') or
            file.content_type == 'application/octet-stream'
        ):
            logger.warning(f"Suspicious file type: {file.content_type}")

        logger.info(f"Received file: {file.filename}, size: {file.size}, type: {file.content_type}")

        max_size = 50 * 1024 * 1024
        try:
            # Read at most one byte past the limit so oversized uploads are
            # rejected without loading them fully into memory.
            contents = await file.read(max_size + 1)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error reading file: {str(e)}") from e

        file_size = len(contents)
        file_size_mb = file_size / (1024 * 1024)
        logger.info(f"File size: {file_size} bytes ({file_size_mb:.1f} MB)")

        if file_size > max_size:
            raise HTTPException(status_code=413, detail=f"File too large (max {max_size//1024//1024}MB)")
        if file_size == 0:
            raise HTTPException(status_code=400, detail="Empty file")

        file_hash = calculate_file_hash(contents, file.filename, file_size)
        logger.info(f"File hash: {file_hash}")

        cached_result = await get_from_cache(file_hash)
        if cached_result:
            logger.info("Cache hit - returning cached result")
            await remove_processing_status(file_hash)
            # Stored payloads were written with from_cache=False; this response does come from the cache.
            if isinstance(cached_result, dict):
                cached_result["from_cache"] = True
            return JSONResponse(cached_result)

        # Only an in-flight job short-circuits. A row left in 'error' (or
        # 'completed' without a cache entry) falls through and is reprocessed;
        # add_processing_status() replaces the old row.
        processing_status = await get_processing_status(file_hash)
        if processing_status and processing_status['status'] == 'processing':
            logger.info("File is currently being processed")
            remaining = max(0, processing_status['estimated_time'] - processing_status['elapsed_minutes'])
            return JSONResponse({
                "status": "processing",
                "progress": processing_status['progress'],
                "estimated_time": processing_status['estimated_time'],
                "elapsed_minutes": processing_status['elapsed_minutes'],
                "message": f"File is being processed. Estimated time remaining: {remaining} minutes"
            })

        # Must match max_concurrent in lifespan(); used only for the load figures reported to clients.
        capacity = 5 if device == "cuda" else 3
        # asyncio.Semaphore exposes no public counter; _value is the number of free slots.
        available_slots = processing_semaphore._value
        in_use = capacity - available_slots

        if available_slots == 0:
            logger.info("Server at capacity - queueing for background processing")
            estimated_time = estimate_processing_time(file_size_mb) + 2

            tmp_file_path = _write_temp_file(contents, file.filename)
            await add_processing_status(file_hash, file.filename, file_size, estimated_time)
            background_tasks.add_task(
                background_transcription,
                tmp_file_path, file_hash, file.filename, file_size
            )
            handed_off = True

            return JSONResponse({
                "status": "queued",
                "estimated_time": estimated_time,
                "file_hash": file_hash,
                "message": f"Server is busy. Your file has been queued. Estimated time: {estimated_time} minutes.",
                "queue_position": f"Processing capacity: {in_use}/{capacity}"
            })

        logger.info(f"Starting new processing... Available slots: {available_slots}")

        tmp_file_path = _write_temp_file(contents, file.filename)
        logger.info(f"Created temp file: {tmp_file_path}")

        estimated_time = estimate_processing_time(file_size_mb)

        if file_size_mb < 5:
            # Small files: transcribe inline and answer with the text directly.
            try:
                response_data = await _transcribe_now(tmp_file_path, language)
                await save_to_cache(
                    file_hash, file.filename, file_size,
                    json.dumps(response_data), response_data["language"]
                )
                return JSONResponse(response_data)
            except Exception as e:
                logger.exception(f"Error in immediate transcription: {e}")
                raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") from e

        # Large files: hand off to a background task; the client polls /status/{file_hash}.
        await add_processing_status(file_hash, file.filename, file_size, estimated_time)
        background_tasks.add_task(
            background_transcription,
            tmp_file_path, file_hash, file.filename, file_size
        )
        handed_off = True

        return JSONResponse({
            "status": "processing_started",
            "estimated_time": estimated_time,
            "file_hash": file_hash,
            "message": f"Processing started. Estimated time: {estimated_time} minutes.",
            "server_load": f"Processing slots: {in_use}/{capacity}"
        })

    except HTTPException:
        raise

    except Exception as e:
        logger.exception(f"Error in transcription endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e

    finally:
        # Background tasks run after the response is sent, so a handed-off file
        # must survive this request. Otherwise, delete it here.
        if tmp_file_path and not handed_off and os.path.exists(tmp_file_path):
            try:
                os.unlink(tmp_file_path)
            except OSError as e:
                logger.error(f"Error deleting temp file: {e}")
|
|
|
|
@app.get("/status/{file_hash}")
async def check_status(file_hash: str):
    """Report the state of a file previously submitted to /transcribe.

    Lookup order:
        1. Cached transcription: returned with ``status: completed``; any stale
           status row is dropped.
        2. processing_status row: returned with a remaining-time estimate and a
           message that matches the row's status.
        3. Neither: HTTP 404.
    """
    cached_result = await get_from_cache(file_hash)
    if cached_result:
        await remove_processing_status(file_hash)
        if not isinstance(cached_result, dict):
            cached_result = {"text": cached_result}
        cached_result.update({
            "status": "completed",
            "from_cache": True,
            "message": "Processing completed and result is ready"
        })
        return JSONResponse(cached_result)

    processing_status = await get_processing_status(file_hash)
    if processing_status:
        status = processing_status['status']
        remaining_time = max(0, processing_status['estimated_time'] - processing_status['elapsed_minutes'])
        if status == 'error':
            # background_transcription() records 'error' when transcription raises.
            message = "Processing failed"
        elif status == 'completed':
            # save_to_cache() logs and swallows errors, so 'completed' can exist without a cache row.
            message = "Processing completed but the result is not available"
        else:
            message = f"Processing... about {remaining_time} minutes remaining"
        return JSONResponse({
            "status": status,
            "progress": processing_status['progress'],
            "elapsed_minutes": processing_status['elapsed_minutes'],
            "estimated_time": processing_status['estimated_time'],
            "remaining_time": remaining_time,
            "message": message
        })

    return JSONResponse({
        "status": "not_found",
        "message": "File not found in cache or processing queue"
    }, status_code=404)
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: current time, compute device and whether the model is loaded."""
    payload = {}
    payload["status"] = "healthy"
    payload["timestamp"] = datetime.now().isoformat()
    payload["device"] = device
    payload["cuda_available"] = torch.cuda.is_available()
    payload["whisper_loaded"] = whisper_model is not None
    return payload
|
|
|
|
|
if __name__ == "__main__":
    # Run the API directly with uvicorn on port 7860. log_config=None keeps
    # uvicorn from applying its own logging config, so records flow through the
    # logging.basicConfig handlers set up at import time.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "timeout_keep_alive": 300,
        "limit_concurrency": 100,
        "limit_max_requests": 1000,
        "log_config": None,
        "access_log": False,
    }
    uvicorn.run(app, **server_options)