# AudioTransDiar / app.py
# Uploaded by prthm11 ("Upload 12 files", commit 4207399).
from flask import Flask, request, jsonify, send_from_directory, render_template
import threading
import time
import os
import queue
from pathlib import Path
import pyaudio
from werkzeug.utils import secure_filename
from rec_transcribe_extension import Transcriber, diarization_hook
from rec_transcribe_extension import (
list_input_devices,
run_recording,
OUTPUT_DIR,
CHUNKS_DIR,
FINAL_WAV,)
app = Flask(__name__)

# Thread running the active recording (assigned by api_start_recording).
recording_thread = None
# NOTE(review): not read anywhere in this file; kept in case other code imports it.
recording_running = False
# Event used to ask the active recording run to stop. None until the first run
# starts; defined here so /api/stop-recording can be called at any time.
stop_event = None
# Shared state polled by /api/recording-status:
#   "recording"     -- True while a recording run is active
#   "live_segments" -- transcript segments ({start, end, speaker, text}) of the current run
recording_status = {
    "recording": False,
    "live_segments": []
}
# ------ Device Listing API ------
@app.route("/api/devices", methods=["GET"])
def api_devices():
    """List the audio devices that can be recorded from.

    Returns JSON ``{"devices": [{"index": ..., "name": ...}, ...]}`` containing
    only devices whose ``maxInputChannels`` is greater than zero.
    """
    pa = pyaudio.PyAudio()
    try:
        infos = [pa.get_device_info_by_index(i) for i in range(pa.get_device_count())]
    finally:
        # Release PortAudio even if a device query raises.
        pa.terminate()
    devices = [
        {"index": dev["index"], "name": dev["name"]}
        for dev in infos
        if dev.get("maxInputChannels", 0) > 0
    ]
    return jsonify({"devices": devices})
# --- Start recording ---
def _parse_flag(value, default=False):
    """Interpret a JSON/form value as a boolean.

    ``bool("false")`` is True in Python, so strings are matched against common
    truthy spellings instead of being passed to ``bool()``. None yields *default*.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() in ("1", "true", "yes", "on")
    return bool(value)


def _device_has_input_channels(pa, device_index):
    """Return True if *device_index* exists on *pa* and reports input channels."""
    try:
        dev = pa.get_device_info_by_index(device_index)
        return dev.get("maxInputChannels", 0) > 0
    except Exception:
        return False


def _speaker_for_span(seg_start, seg_end, diar_segments):
    """Return the speaker of the first diarization segment overlapping the span.

    *diar_segments* is iterated as ``(start, end, speaker)`` triples. Returns
    ``"Unknown"`` when no segment overlaps ``(seg_start, seg_end)``.
    """
    for d_start, d_end, d_speaker in diar_segments:
        if seg_start < d_end and seg_end > d_start:
            return d_speaker
    return "Unknown"


@app.route("/api/start-recording", methods=["POST"])
def api_start_recording():
    """Validate the request and start a background recording/transcription run.

    JSON body fields:
      - ``mic`` (required): PyAudio input device index.
      - ``sys`` (optional): second input device index; missing, ``""`` or
        ``"null"`` means none.
      - ``chunk_secs`` (optional, default 5): positive chunk length in seconds.
      - ``model`` (optional, default ``"medium"``): forwarded to ``run_recording``.
      - ``no_transcribe`` (optional, default False): forwarded to ``run_recording``.

    Returns ``{"ok": True}`` once the worker thread has started. Returns 400 for
    invalid input, unusable devices, or when a run is already active. Returns
    500 if PyAudio cannot be initialised.
    """
    global recording_thread, stop_event, recording_status
    # silent=True yields None instead of raising on a missing/non-JSON body,
    # so the checks below can answer with a clean 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    try:
        mic = int(data.get("mic"))
    except (TypeError, ValueError, OverflowError):
        return jsonify({"error": "Missing or invalid 'mic' parameter"}), 400

    sys_index = None
    if data.get("sys") not in (None, "", "null"):
        try:
            sys_index = int(data.get("sys"))
        except (TypeError, ValueError, OverflowError):
            return jsonify({"error": "Invalid 'sys' parameter"}), 400

    try:
        chunk_secs = int(data.get("chunk_secs", 5))
    except (TypeError, ValueError, OverflowError):
        return jsonify({"error": "Invalid 'chunk_secs' parameter"}), 400
    if chunk_secs <= 0:
        return jsonify({"error": "'chunk_secs' must be a positive integer"}), 400

    model = data.get("model", "medium")
    no_transcribe = _parse_flag(data.get("no_transcribe"), default=False)

    # NOTE(review): this check-then-set is not atomic; two simultaneous requests
    # could both start a run. Consider a lock if that matters.
    if recording_status["recording"]:
        return jsonify({"error": "Already recording"}), 400

    # --- Validate that requested devices exist and have input channels ---
    try:
        pa = pyaudio.PyAudio()
    except Exception as e:
        return jsonify({"error": f"PyAudio initialization failed: {e}"}), 500
    try:
        if not _device_has_input_channels(pa, mic):
            return jsonify({"error": f"Microphone device index {mic} not found or has no input channels"}), 400
        if sys_index is not None and not _device_has_input_channels(pa, sys_index):
            return jsonify({"error": f"System device index {sys_index} not found or has no input channels"}), 400
    finally:
        pa.terminate()

    # Reset state for the new run.
    recording_status["recording"] = True
    recording_status["live_segments"] = []
    stop_event = threading.Event()

    def run():
        # Swap in a worker that also publishes each transcribed segment to
        # recording_status["live_segments"] for the frontend poller; the
        # original is restored in `finally`.
        # NOTE(review): presumably run_recording() resolves the worker through
        # the module attribute, which is what makes this patch effective.
        import rec_transcribe_extension as rte
        orig_worker = rte.chunk_writer_and_transcribe_worker

        def patched_worker(in_queue, final_frames_list, transcriber, single_channel_label="mic"):
            """Save, diarize and transcribe queued chunks until stopped and drained."""
            while True:
                try:
                    filename, frames = in_queue.get(timeout=1.0)
                except queue.Empty:
                    # Exit only once a stop was requested AND the queue is drained.
                    if stop_event.is_set() and in_queue.empty():
                        break
                    continue

                saved = True
                try:
                    rte.save_wav_from_frames(filename, frames, nchannels=rte.CHANNELS)
                except Exception as e:
                    saved = False
                    print(f"Failed to save chunk {filename}: {e}")
                # Keep the audio for the final recording even if the chunk file failed.
                final_frames_list.extend(frames)
                if not saved:
                    continue

                # Diarization is best-effort: an exception here must not end
                # this loop, or queued chunks would stop being consumed.
                try:
                    diar_segments = rte.diarization_hook(str(filename)) or []
                except Exception as e:
                    print(f"Diarization error for {filename.name}: {e}")
                    diar_segments = []

                if not (transcriber and transcriber.model):
                    continue
                try:
                    segments, _info = transcriber.model.transcribe(str(filename), beam_size=5)
                    for seg in segments:
                        seg_start = seg.start
                        seg_end = seg.end
                        seg_text = seg.text.strip()
                        speaker = _speaker_for_span(seg_start, seg_end, diar_segments)
                        # Segment times are as returned for this chunk file.
                        recording_status["live_segments"].append({
                            "start": float(seg_start),
                            "end": float(seg_end),
                            "speaker": str(speaker),
                            "text": seg_text
                        })
                        line = f"[{filename.name}] {seg_start:.2f}-{seg_end:.2f} Speaker {speaker}: {seg_text}\n"
                        with open(rte.TRANSCRIPT_FILE, "a", encoding="utf-8") as tf:
                            tf.write(line)
                except Exception as e:
                    print(f"Transcription error for {filename.name}: {e}")
            print("Chunk writer/transcriber worker exiting.")

        rte.chunk_writer_and_transcribe_worker = patched_worker
        try:
            rte.stop_event = stop_event
            run_recording(mic_index=mic, sys_index=sys_index, chunk_secs=chunk_secs,
                          model_name=model, no_transcribe=no_transcribe)
        finally:
            rte.chunk_writer_and_transcribe_worker = orig_worker
            recording_status["recording"] = False

    recording_thread = threading.Thread(target=run, daemon=True)
    recording_thread.start()
    return jsonify({"ok": True})
# --- Stop recording ---
@app.route("/api/stop-recording", methods=["POST"])
def api_stop_recording():
    """Signal the active recording run to stop and return ``{"ok": True}``.

    Reads the module-level ``stop_event`` name, which must exist (either None
    or the Event created by api_start_recording). Stopping is asynchronous:
    ``recording_status["recording"]`` flips to False only once the worker
    thread finishes.
    """
    event = stop_event
    if event is not None:
        event.set()
    return jsonify({"ok": True})
# --- Poll status ---
@app.route("/api/recording-status")
def api_recording_status():
    """Return the recording flag and the live transcript segments collected so far."""
    snapshot = dict(recording_status)
    return jsonify(snapshot)
# # serve saved uploads at /uploads/<filename>
# @app.route('/uploads/<path:filename>')
# def serve_uploaded(filename):
# return send_from_directory(str(OUTPUT_DIR), filename)
# # --- upload pre-recorded files ---
# @app.route("/api/upload", methods=["POST"])
# def api_upload_file():
# """
# Accept a single file (form-data 'file'), save it into OUTPUT_DIR and return json
# { ok: True, filename: "<saved_name>", url: "/static/<saved_name>" }.
# """
# if 'file' not in request.files:
# return jsonify({"error": "No file provided"}), 400
# f = request.files['file']
# if f.filename == '':
# return jsonify({"error": "Empty filename"}), 400
# safe_name = secure_filename(f.filename)
# # prefix timestamp to avoid collisions
# ts = int(time.time() * 1000)
# saved_name = f"{ts}_{safe_name}"
# saved_path = OUTPUT_DIR / saved_name
# try:
# f.save(str(saved_path))
# except Exception as e:
# return jsonify({"error": f"Failed to save file: {e}"}), 500
# return jsonify({"ok": True, "filename": saved_name, "url": f"/static/{saved_name}"})
# # --- Start server-side paced transcription for a saved WAV/MP3 file ---
# @app.route("/api/start-transcribe-file", methods=["POST"])
# def api_start_transcribe_file():
# """
# POST JSON { filename: "<saved_name>" }
# Spawns a background thread that transcribes the file using the Transcriber,
# and appends transcribed segments (with start/end/speaker/text) into
# recording_status["live_segments"] while setting recording_status["recording"]=True.
# The worker will pace segments to approximate 'live' streaming using seg.start timestamps.
# """
# global recording_status
# data = request.json or {}
# filename = data.get("filename")
# print("DEBUG: /api/start-transcribe-file called with:", filename, flush=True)
# if not filename:
# return jsonify({"error": "Missing 'filename'"}), 400
# file_path = OUTPUT_DIR / filename
# if not file_path.exists():
# return jsonify({"error": "File not found on server"}), 404
# # prevent concurrent transcription runs
# if recording_status.get("recording"):
# return jsonify({"error": "Another transcription/recording is already running"}), 400
# def worker():
# try:
# recording_status["recording"] = True
# recording_status["live_segments"] = []
# transcriber = Transcriber()
# if not transcriber.model:
# # model not loaded/available
# recording_status["recording"] = False
# print("Transcription model not available; cannot transcribe file.")
# return
# # perform diarization if available
# diar_segments = diarization_hook(str(file_path)) or []
# # get segments from model
# try:
# segments, info = transcriber.model.transcribe(str(file_path), beam_size=5)
# except Exception as e:
# print("Error during transcription:", e)
# recording_status["recording"] = False
# return
# # Stream the segments into recording_status with timing
# start_clock = time.time()
# for seg in segments:
# # seg.start is seconds into the audio
# wait_for = seg.start - (time.time() - start_clock)
# if wait_for > 0:
# time.sleep(wait_for)
# # map speaker using diarization segments (best-effort overlap)
# speaker = "Unknown"
# for d_start, d_end, d_label in diar_segments:
# if (seg.start < d_end) and (seg.end > d_start):
# speaker = d_label
# break
# seg_obj = {
# "start": float(seg.start),
# "end": float(seg.end),
# "speaker": str(speaker),
# "text": seg.text.strip()
# }
# # append to shared status for frontend polling
# recording_status.setdefault("live_segments", []).append(seg_obj)
# # also append to transcript file for persistence (optional)
# with open(rec_transcribe_extension.TRANSCRIPT_FILE, "a", encoding="utf-8") as tf:
# line = f"[{filename}] {seg.start:.2f}-{seg.end:.2f} Speaker {speaker}: {seg.text.strip()}\n"
# tf.write(line)
# # done streaming
# recording_status["recording"] = False
# except Exception as e:
# print("Error in transcription worker:", e)
# recording_status["recording"] = False
# t = threading.Thread(target=worker, daemon=True)
# t.start()
# return jsonify({"ok": True})
# --- List final files ---
@app.route("/api/final-files")
def api_final_files():
    """List the ``.wav`` and ``.txt`` files in OUTPUT_DIR.

    Returns ``{"files": [{"name", "path", "url"}, ...]}`` sorted by file name.
    ``path`` and ``url`` both point at ``/static/<name>``. If OUTPUT_DIR does
    not exist, the list is empty rather than an error.
    """
    out_dir = OUTPUT_DIR
    try:
        names = os.listdir(out_dir)
    except FileNotFoundError:
        # Output directory has not been created yet.
        names = []
    files = [
        {"name": fname, "path": f"/static/{fname}", "url": f"/static/{fname}"}
        for fname in sorted(names)
        # Skip directories that merely have a matching suffix.
        if fname.endswith((".wav", ".txt"))
        and os.path.isfile(os.path.join(out_dir, fname))
    ]
    return jsonify({"files": files})
# --- Serve static files (WAV, TXT) ---
@app.route('/static/<path:filename>')
def static_files(filename):
    """Serve a recorded/transcript file from OUTPUT_DIR.

    send_from_directory rejects paths that escape the directory.
    NOTE(review): Flask registers its own ``/static/<path:filename>`` route for
    the app's ``static`` folder by default; confirm this view is actually
    reached (e.g. OUTPUT_DIR is that folder, or static_url_path is changed).
    """
    directory = OUTPUT_DIR
    return send_from_directory(directory, filename)
# --- Serve the frontend ---
@app.route("/")
def index():
    """Render the single-page frontend."""
    template_name = "index2.html"
    return render_template(template_name)
if __name__ == "__main__":
    # Local development server; debug=True enables the Werkzeug debugger and
    # reloader, so do not expose this process publicly.
    app.run(debug=True, port=5000)