Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uuid | |
| from . import processing | |
| from .models import TaskStatus | |
| import os | |
| import shutil | |
| from pathlib import Path | |
app = FastAPI()

# Configure CORS. Allowed origins are read from the comma-separated
# CORS_ALLOWED_ORIGINS environment variable; when it is unset or blank the
# default is "*" (any origin). In production, set it to your frontend's domain(s).
# NOTE(review): "*" together with allow_credentials=True lets any site make
# credentialed requests; restrict the origins before deploying.
_raw_origins = os.environ.get("CORS_ALLOWED_ORIGINS", "").strip() or "*"
origins = [origin.strip() for origin in _raw_origins.split(",") if origin.strip()]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory task registry: task_id -> {"status": ..., "result": ...}.
# It lives only in this process and entries are never removed here.
tasks = {}

# Uploaded audio is written here before the background pipeline processes it.
TEMP_DIR = Path("/tmp/temp_audio")
TEMP_DIR.mkdir(parents=True, exist_ok=True)

# File extensions (lower-cased, with leading dot) accepted by the upload endpoint.
ALLOWED_EXTENSIONS = {".wav", ".mp3", ".m4a", ".ogg", ".flac"}
# Maximum accepted upload size in bytes (200 MiB).
MAX_FILE_SIZE = 200 * 1024 * 1024
async def upload_audio_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    """
    Receive an audio file, stream it to a temporary file on disk, and start
    the processing pipeline in the background.

    The upload is validated by extension (must be in ALLOWED_EXTENSIONS) and
    by size (must not exceed MAX_FILE_SIZE bytes). The size is checked while
    streaming in fixed-size chunks, so the upload is never read into memory
    all at once.

    Returns a dict containing the new ``task_id``.

    Raises:
        HTTPException: 400 for a missing or unsupported file extension,
            413 when the upload exceeds MAX_FILE_SIZE.
    """
    # UploadFile.filename can be None; treat that as "no extension" -> 400
    # instead of letting Path(None) raise a TypeError (500).
    file_ext = Path(file.filename or "").suffix.lower()
    if file_ext not in ALLOWED_EXTENSIONS:
        # sorted() keeps the message stable; set iteration order is not.
        allowed = ", ".join(sorted(ALLOWED_EXTENSIONS))
        raise HTTPException(
            status_code=400,
            detail=f"Invalid audio format. Allowed formats: {allowed}",
        )

    task_id = str(uuid.uuid4())
    file_path = TEMP_DIR / f"{task_id}{file_ext}"
    chunk_size = 1024 * 1024  # 1 MiB per read keeps memory use bounded
    limit_mb = MAX_FILE_SIZE // (1024 * 1024)
    bytes_written = 0
    try:
        with open(file_path, "wb") as buffer:
            while chunk := await file.read(chunk_size):
                bytes_written += len(chunk)
                if bytes_written > MAX_FILE_SIZE:
                    raise HTTPException(
                        status_code=413,
                        detail=f"File too large. Limit is {limit_mb}MB.",
                    )
                buffer.write(chunk)
    except BaseException:
        # Remove the partial/oversized file on any failure (size limit,
        # I/O error, or cancellation) so nothing is left in TEMP_DIR.
        file_path.unlink(missing_ok=True)
        raise
    finally:
        await file.close()

    # Register the task only after the file is fully on disk, so a failed
    # upload never leaves a task stuck in "processing".
    tasks[task_id] = {"status": "processing", "result": None}
    background_tasks.add_task(processing.run_pipeline, task_id, file_path, tasks)
    return {"task_id": task_id}
async def get_task_status(task_id: str):
    """
    Look up a task by id and report its current state as a ``TaskStatus``.

    Raises:
        HTTPException: 404 when no task is registered under ``task_id``.
    """
    entry = tasks.get(task_id)
    if entry:
        return TaskStatus(**entry)
    raise HTTPException(status_code=404, detail="Task not found")
async def get_task_result(task_id: str):
    """
    Return the stored result of a task whose status is ``"complete"``.

    Raises:
        HTTPException: 404 if ``task_id`` is unknown or the task has not
            reached the ``"complete"`` status.
    """
    entry = tasks.get(task_id)
    is_ready = bool(entry) and entry.get("status") == "complete"
    if not is_ready:
        raise HTTPException(status_code=404, detail="Result not yet available or task not found")
    return entry.get("result")