# audio-to-text / main.py
# Uploaded by shivam98 (commit a617253 verified, "Update main.py")
import uuid
import shutil
import os
import logging
from fastapi import FastAPI, UploadFile, File, BackgroundTasks
from fastapi.responses import JSONResponse
from transformers import pipeline
import tempfile
from typing import Dict
# ---------------- Logging Configuration ----------------
# Timestamped "time | LEVEL | message" lines on the root logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s"
)
logger = logging.getLogger(__name__)
# ---------------- App Setup ----------------
app = FastAPI()
# In-memory task store: task_id -> {"status": ..., plus result/error fields}.
# NOTE(review): lives in process memory only — lost on restart and not shared
# across workers; confirm the service runs as a single process.
transcriptions: Dict[str, Dict] = {}
# ---------------- Whisper Pipeline ----------------
# Loaded once at import time so all requests share the same model instance;
# first start may block while the model downloads/initialises.
logger.info("Loading Whisper ASR pipeline...")
asr = pipeline("automatic-speech-recognition", model="openai/whisper-small", generate_kwargs={"language": "en"})
logger.info("Whisper pipeline loaded.")
# ---------------- Transcription Function ----------------
def transcribe_audio(file_path: str, task_id: str):
    """Run Whisper ASR on *file_path* and record the outcome in `transcriptions`.

    Executed as a FastAPI background task. On success the task entry gains
    the full text, the word-level segments, and an SRT rendering; on failure
    it records the error string. The uploaded file and its per-upload temp
    directory are always removed afterwards.

    Args:
        file_path: Path to the audio file saved by `upload_audio`.
        task_id: Key into the shared `transcriptions` store.
    """
    logger.info(f"[{task_id}] Starting transcription for {file_path}")
    try:
        result = asr(file_path, return_timestamps="word")
        text = result["text"]
        segments = result.get("chunks", [])

        # Build an SRT document. SRT cue numbers must stay consecutive even
        # when segments with missing timestamps are skipped, so keep a
        # dedicated counter instead of reusing the enumerate index (which
        # left gaps in the numbering whenever a segment was dropped).
        srt_lines = []
        cue_number = 0
        for i, seg in enumerate(segments):
            start, end = seg["timestamp"]
            content = seg["text"].strip()
            if start is None or end is None:
                logger.warning(f"[{task_id}] Skipping segment {i+1} due to missing timestamp")
                continue
            cue_number += 1
            srt_lines.append(f"{cue_number}")
            srt_lines.append(f"{format_timestamp(start)} --> {format_timestamp(end)}")
            srt_lines.append(content)
            srt_lines.append("")
        srt_result = "\n".join(srt_lines)

        transcriptions[task_id] = {
            "status": "completed",
            "text": text,
            "segments": segments,
            "srt": srt_result
        }
        logger.info(f"[{task_id}] Transcription completed successfully.")
    except Exception as e:
        transcriptions[task_id] = {
            "status": "failed",
            "error": str(e)
        }
        # logger.exception records the traceback as well as the message.
        logger.exception(f"[{task_id}] Transcription failed: {e}")
    finally:
        # upload_audio() creates a throwaway temp directory per upload;
        # remove the whole directory — the previous code removed only the
        # file, leaking one empty directory per request.
        temp_dir = os.path.dirname(file_path)
        shutil.rmtree(temp_dir, ignore_errors=True)
        logger.info(f"[{task_id}] Audio file deleted after processing")
# ---------------- Timestamp Formatter ----------------
def format_timestamp(seconds: float):
    """Render a time offset in seconds as an SRT timestamp (HH:MM:SS,mmm).

    A ``None`` offset maps to the zero timestamp so callers never crash on
    segments that lack timing information.
    """
    if seconds is None:
        return "00:00:00,000"
    whole = int(seconds)
    millis = int((seconds - whole) * 1000)
    minutes, secs = divmod(whole, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"
# ---------------- Routes ----------------
@app.post("/upload/")
async def upload_audio(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
    """Accept an audio upload and queue it for background transcription.

    Saves the upload into a fresh per-request temp directory, marks the task
    as "processing", and schedules `transcribe_audio`. Returns a task id the
    client polls via /status/{task_id} and /result/{task_id}.
    """
    task_id = str(uuid.uuid4())
    temp_dir = tempfile.mkdtemp()
    # Never trust a client-supplied filename: strip any directory components
    # so the upload cannot be written outside the fresh temp directory
    # (e.g. a filename of "../../evil").
    safe_name = os.path.basename(file.filename or "") or "upload.audio"
    file_path = os.path.join(temp_dir, safe_name)
    try:
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
    finally:
        # Release the spooled upload stream promptly.
        file.file.close()
    logger.info(f"[{task_id}] File '{file.filename}' uploaded and saved to {file_path}")
    transcriptions[task_id] = {"status": "processing"}
    background_tasks.add_task(transcribe_audio, file_path, task_id)
    return {"task_id": task_id}
@app.get("/ping")
async def ping():
    """Liveness probe: confirm the service is up and responding."""
    payload = {"status": "success", "message": "Dummy API is working"}
    return payload
@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """Report the current state of a transcription task.

    Returns 404 for unknown ids; otherwise the stored status
    ("processing", "completed", or "failed").
    """
    entry = transcriptions.get(task_id)
    if entry is None:
        logger.warning(f"[{task_id}] Status check failed: Invalid ID")
        return JSONResponse(content={"error": "Invalid ID"}, status_code=404)
    logger.info(f"[{task_id}] Status checked: {entry['status']}")
    return {"task_id": task_id, "status": entry["status"]}
@app.get("/result/{task_id}")
async def get_result(task_id: str):
    """Return the finished transcription and delete it (one-shot download).

    404 for unknown ids, 500 with the recorded error for failed jobs,
    400 while the job is still processing.
    """
    if task_id not in transcriptions:
        logger.warning(f"[{task_id}] Result request failed: Invalid ID")
        return JSONResponse(content={"error": "Invalid ID"}, status_code=404)
    entry = transcriptions[task_id]
    if entry["status"] == "failed":
        # Surface the failure instead of reporting "not ready" forever,
        # and drop the entry so failed jobs don't accumulate in memory.
        logger.info(f"[{task_id}] Result requested for failed task")
        del transcriptions[task_id]
        return JSONResponse(
            content={"error": entry.get("error", "Transcription failed")},
            status_code=500
        )
    if entry["status"] != "completed":
        logger.info(f"[{task_id}] Result requested but not ready (status: {entry['status']})")
        return JSONResponse(content={"error": "Transcription not ready"}, status_code=400)
    result = {
        "text": entry["text"],
        "segments": entry["segments"],
        "srt": entry["srt"]
    }
    # One-shot semantics: results are evicted once downloaded.
    logger.info(f"[{task_id}] Result downloaded. Cleaning up transcription data.")
    del transcriptions[task_id]
    return result