# AudioTransDiar / merged.py — uploaded by prthm11 (commit 0ba7c0e, verified)
# merged.py (production-ready for Docker / Hugging Face Spaces)
import os
import time
import threading
import queue
import pathlib
from pathlib import Path
from flask import Flask, request, jsonify, send_from_directory, Response, stream_with_context, render_template
from werkzeug.utils import secure_filename
# Try to import rec_transcribe_extension; we still rely on its utilities
try:
    import rec_transcribe_extension as rte
    from rec_transcribe_extension import Transcriber, diarization_hook, run_recording
except Exception as e:
    # If the module import fails, keep rte=None and catch later to provide friendly error messages
    # (every endpoint below guards on `rte`/`Transcriber` being non-None before use).
    rte = None
    Transcriber = None
    diarization_hook = None
    run_recording = None
    print("Warning: failed to import rec_transcribe_extension:", e)
# ---- Environment-driven directories & config ----
# Output artifacts (chunk WAVs, transcript.txt) live here; overridable via OUTPUT_DIR.
DEFAULT_OUTPUT = os.environ.get("OUTPUT_DIR", "/app/output_transcript_diarization")
OUTPUT_DIR = Path(DEFAULT_OUTPUT)
try:
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
except Exception as ex:
    # fallback to /tmp if creation in the requested location fails (common in some runtimes)
    OUTPUT_DIR = Path("/tmp/output_transcript_diarization")
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# transcript file path used by SSE endpoint
TRANSCRIPT_FILE = OUTPUT_DIR / "transcript.txt"
# Ensure uploads dir exists (web uploads)
UPLOAD_FOLDER = Path(os.environ.get("UPLOAD_FOLDER", "/app/uploads"))
try:
    UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)
except Exception:
    # Read-only /app (locked-down container): fall back to a writable /tmp path.
    UPLOAD_FOLDER = Path("/tmp/uploads")
    UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)
# Audio container formats the upload endpoint will accept.
ALLOWED_EXT = {'.mp3', '.wav', '.m4a', '.aac', '.ogg'}

def allowed_file(filename: str) -> bool:
    """Return True when *filename* carries one of the accepted audio extensions."""
    suffix = pathlib.Path(filename).suffix
    return suffix.lower() in ALLOWED_EXT
# ---- Try to import pyaudio lazily and detect if host audio devices are accessible ----
LIVE_RECORDING_SUPPORTED = False
_pyaudio = None  # module handle; stays None when pyaudio cannot be imported
try:
    import importlib
    _pyaudio = importlib.import_module("pyaudio")
    # attempt to instantiate PyAudio to confirm it's functional
    try:
        pa = _pyaudio.PyAudio()
        # if there is at least one input device, consider live recording possible
        has_input = any(pa.get_device_info_by_index(i).get("maxInputChannels", 0) > 0
                        for i in range(pa.get_device_count()))
        pa.terminate()
        LIVE_RECORDING_SUPPORTED = bool(has_input)
    except Exception as e:
        # PortAudio present but no usable backend (typical in containers).
        LIVE_RECORDING_SUPPORTED = False
        print("PyAudio imported but couldn't initialize audio devices:", e)
except Exception:
    # pyaudio not available
    LIVE_RECORDING_SUPPORTED = False
# ---- Flask app ----
app = Flask(__name__, static_folder=None)
app.config['UPLOAD_FOLDER'] = str(UPLOAD_FOLDER)
# ---- Shared state ----
recording_thread = None  # background capture/transcription thread, set on start
recording_lock = threading.Lock()
recording_status = {"recording": False, "live_segments": []}
# ---- Frontend routes ----
@app.route("/")
def landing():
    """Serve the landing page."""
    return render_template("landing.html")

@app.route("/live")
def live_page():
    """Serve the live-recording UI."""
    return render_template("index2.html")

@app.route("/upload")
def upload_page():
    """Serve the file-upload UI."""
    return render_template("index2_upload.html")
# ---- Device listing (only if supported) ----
@app.route("/api/devices", methods=["GET"])
def api_devices():
    """Enumerate input-capable audio devices as JSON.

    Returns {"devices": []} with an explanatory error (HTTP 200) when live
    recording is unsupported in this environment, or HTTP 500 with the error
    text if PyAudio fails at runtime.
    """
    if not LIVE_RECORDING_SUPPORTED:
        return jsonify({"devices": [], "error": "Live recording not supported in this environment."}), 200
    try:
        pa = _pyaudio.PyAudio()
        try:
            devices = []
            for i in range(pa.get_device_count()):
                dev = pa.get_device_info_by_index(i)
                if dev.get("maxInputChannels", 0) > 0:
                    devices.append({"index": dev["index"], "name": dev["name"]})
        finally:
            # BUG FIX: terminate even when a device query raises, so the
            # PortAudio handle is never leaked.
            pa.terminate()
        return jsonify({"devices": devices})
    except Exception as e:
        return jsonify({"devices": [], "error": str(e)}), 500
# ---- Start recording endpoint (guards if pyaudio unavailable) ----
@app.route("/api/start-recording", methods=["POST"])
def api_start_recording():
    """Validate requested capture devices and launch the recording thread.

    JSON body: mic (required input-device index), sys (optional system-audio
    device index), chunk_secs (default 5), model (default "medium"),
    no_transcribe (default False).  Returns {"ok": True} on success or a JSON
    error with 400/500 status.
    """
    global recording_thread
    if not LIVE_RECORDING_SUPPORTED or _pyaudio is None:
        return jsonify({"error": "Live recording is not supported in this environment."}), 400
    data = request.json or {}
    try:
        mic = int(data.get("mic"))
    except Exception:
        return jsonify({"error": "Missing or invalid 'mic' parameter"}), 400
    # Renamed from 'sys' to avoid shadowing the stdlib module name.
    sys_idx = None
    if data.get("sys") not in (None, "", "null"):
        try:
            sys_idx = int(data.get("sys"))
        except Exception:
            return jsonify({"error": "Invalid 'sys' parameter"}), 400
    chunk_secs = int(data.get("chunk_secs", 5))
    model = data.get("model", "medium")
    no_transcribe = bool(data.get("no_transcribe", False))
    if recording_status["recording"]:
        return jsonify({"error": "Already recording"}), 400
    # validate devices using pyaudio
    try:
        pa = _pyaudio.PyAudio()
    except Exception as e:
        return jsonify({"error": f"PyAudio initialization failed: {e}"}), 500

    def device_is_valid(device_index):
        # A usable capture device must expose at least one input channel.
        try:
            dev = pa.get_device_info_by_index(device_index)
            return dev.get("maxInputChannels", 0) > 0
        except Exception:
            return False

    # BUG FIX: terminate the PyAudio handle on every exit path (was duplicated
    # per-branch and leaked if a validation call raised).
    try:
        if not device_is_valid(mic):
            return jsonify({"error": f"Microphone device index {mic} not found or has no input channels"}), 400
        if sys_idx is not None and not device_is_valid(sys_idx):
            return jsonify({"error": f"System device index {sys_idx} not found or has no input channels"}), 400
    finally:
        pa.terminate()
    # ready recording state
    recording_status["recording"] = True
    recording_status["live_segments"] = []
    stop_event = threading.Event()

    def run():
        """Thread target: run rte.run_recording with a patched chunk worker."""
        try:
            if rte and hasattr(rte, "chunk_writer_and_transcribe_worker"):
                # monkey-patch worker if module supports it
                import rec_transcribe_extension as rte_local
                orig_worker = rte_local.chunk_writer_and_transcribe_worker

                def patched_worker(in_queue, final_frames_list, transcriber, single_channel_label="mic"):
                    # Drain the chunk queue: save WAV, diarize, transcribe,
                    # and publish segments to recording_status.
                    while True:
                        try:
                            filename, frames = in_queue.get(timeout=1.0)
                        except queue.Empty:
                            if stop_event.is_set() and in_queue.empty():
                                break
                            continue
                        try:
                            rte_local.save_wav_from_frames(filename, frames, nchannels=rte_local.CHANNELS)
                        except Exception:
                            # best-effort; continue
                            pass
                        # diarization and transcription
                        diar_segments = []
                        try:
                            diar_segments = (rte_local.diarization_hook(str(filename)) or [])
                        except Exception:
                            diar_segments = []
                        if transcriber and getattr(transcriber, "model", None):
                            try:
                                segments, info = transcriber.model.transcribe(str(filename), beam_size=5)
                                for seg in segments:
                                    seg_start = float(getattr(seg, "start", 0.0))
                                    seg_end = float(getattr(seg, "end", 0.0))
                                    seg_text = getattr(seg, "text", "").strip()
                                    speaker = "Unknown"
                                    # Label with the first diarization turn that overlaps.
                                    for d_start, d_end, d_speaker in diar_segments:
                                        if (seg_start < d_end) and (seg_end > d_start):
                                            speaker = d_speaker
                                            break
                                    recording_status["live_segments"].append({
                                        "start": seg_start,
                                        "end": seg_end,
                                        "speaker": str(speaker),
                                        "text": seg_text
                                    })
                                    # write to persistent transcript file
                                    try:
                                        with open(TRANSCRIPT_FILE, "a", encoding="utf-8") as tf:
                                            tf.write(f"[{pathlib.Path(filename).name}] {seg_start:.2f}-{seg_end:.2f} Speaker {speaker}: {seg_text}\n")
                                    except Exception:
                                        pass
                            except Exception as e:
                                print("Transcription error:", e)
                    # patched worker exit

                rte_local.chunk_writer_and_transcribe_worker = patched_worker
                try:
                    rte_local.stop_event = stop_event
                    rte_local.run_recording(mic_index=mic, sys_index=sys_idx, chunk_secs=chunk_secs,
                                            model_name=model, no_transcribe=no_transcribe)
                finally:
                    rte_local.chunk_writer_and_transcribe_worker = orig_worker
            else:
                # fallback: call run_recording if available without monkey patch
                try:
                    if rte and hasattr(rte, "stop_event"):
                        rte.stop_event = stop_event
                    if rte and hasattr(rte, "run_recording"):
                        rte.run_recording(mic_index=mic, sys_index=sys_idx, chunk_secs=chunk_secs,
                                          model_name=model, no_transcribe=no_transcribe)
                except Exception as e:
                    print("run_recording error:", e)
        finally:
            # BUG FIX: always clear the busy flag — previously it stayed True
            # forever if run_recording raised (or in the patched branch at all),
            # blocking every subsequent start request.
            recording_status["recording"] = False

    recording_thread_local = threading.Thread(target=run, daemon=True)
    recording_thread_local.start()
    # store reference globally so stop logic can use it
    recording_thread = recording_thread_local
    return jsonify({"ok": True})
# Stop recording
@app.route("/api/stop-recording", methods=["POST"])
def api_stop_recording():
    """Ask the recording module to stop; always reports {"ok": True}."""
    stop_evt = getattr(rte, "stop_event", None) if rte else None
    if stop_evt:
        try:
            stop_evt.set()
        except Exception:
            # Best-effort: a broken event must not fail the request.
            pass
    return jsonify({"ok": True})
# recording status
@app.route("/api/recording-status")
def api_recording_status():
    """Report whether capture is active plus the segments gathered so far."""
    payload = {
        "recording": recording_status.get("recording", False),
        "live_segments": recording_status.get("live_segments", []),
    }
    return jsonify(payload)
# ---- Upload endpoint (works in Spaces) ----
@app.route("/api/upload", methods=["POST"])
def api_upload_file():
    """Accept a multipart audio upload and store it under a timestamped name.

    Validates the extension against ALLOWED_EXT, sanitizes the client-supplied
    name with secure_filename, and returns JSON with the saved filename and
    its /uploads URL.
    """
    if 'file' not in request.files:
        return jsonify(success=False, error="No file part"), 400
    f = request.files['file']
    if f.filename == '':
        return jsonify(success=False, error="Empty filename"), 400
    filename = secure_filename(f.filename)
    if not allowed_file(filename):
        return jsonify(success=False, error="Extension not allowed"), 400
    ts = int(time.time() * 1000)
    # BUG FIX: the sanitized name was discarded (every upload was saved as
    # "<ts>_(unknown)", losing the real name and extension); keep it.
    saved_name = f"{ts}_{filename}"
    save_path = Path(app.config['UPLOAD_FOLDER']) / saved_name
    try:
        f.save(str(save_path))
    except Exception as e:
        return jsonify(success=False, error=f"Failed to save file: {e}"), 500
    url = f"/uploads/{saved_name}"
    return jsonify(success=True, url=url, filename=saved_name)
# Serve uploaded files
@app.route("/uploads/<path:filename>")
def uploaded_file(filename):
    """Stream a previously uploaded file inline (not as an attachment)."""
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename, as_attachment=False)
# ---- Transcribe an uploaded file in a paced 'live' manner (works in Spaces) ----
@app.route("/api/start-transcribe-file", methods=["POST"])
def api_start_transcribe_file():
    """Kick off paced transcription of an already-saved file.

    JSON body: filename (bare name of a file in the uploads or output dir).
    Segments are appended to recording_status['live_segments'] at roughly the
    audio's own pace so the UI behaves like a live session; each segment is
    also appended to TRANSCRIPT_FILE for the SSE stream.
    """
    data = request.json or {}
    filename = data.get("filename")
    if not filename:
        return jsonify({"error": "Missing filename"}), 400
    # SECURITY FIX: the name came straight from the client and was joined onto
    # a directory, allowing path traversal ("../../etc/passwd"). Accept only a
    # bare file name with no directory components.
    if pathlib.Path(filename).name != filename:
        return jsonify({"error": "Invalid filename"}), 400
    file_path = OUTPUT_DIR / filename
    # if file was uploaded to uploads folder, prefer that path
    uploaded_path = Path(app.config['UPLOAD_FOLDER']) / filename
    if uploaded_path.exists():
        file_path = uploaded_path
    if not file_path.exists():
        return jsonify({"error": "File not found"}), 404
    if recording_status.get("recording"):
        return jsonify({"error": "Busy"}), 400
    # RACE FIX: claim the busy flag before the worker starts, not inside it,
    # so two quick requests cannot both pass the Busy check.
    recording_status["recording"] = True
    recording_status["live_segments"] = []

    def worker():
        try:
            transcriber = Transcriber() if Transcriber else None
            diar_segments = diarization_hook(str(file_path)) if diarization_hook else []
            if transcriber and getattr(transcriber, "model", None):
                segments, _ = transcriber.model.transcribe(str(file_path), beam_size=5)
                start_clock = time.time()
                for seg in segments:
                    # Pace emission to match the audio timeline.
                    wait_for = seg.start - (time.time() - start_clock)
                    if wait_for > 0:
                        time.sleep(wait_for)
                    speaker = "Unknown"
                    for d_start, d_end, d_label in (diar_segments or []):
                        if (seg.start < d_end) and (seg.end > d_start):
                            speaker = d_label
                            break
                    seg_obj = {"start": float(seg.start), "end": float(seg.end), "speaker": speaker, "text": seg.text.strip()}
                    recording_status["live_segments"].append(seg_obj)
                    # append to transcript file for SSE streaming
                    try:
                        with open(TRANSCRIPT_FILE, "a", encoding="utf-8") as tf:
                            tf.write(f"[{file_path.name}] {seg.start:.2f}-{seg.end:.2f} Speaker {speaker}: {seg.text.strip()}\n")
                    except Exception:
                        pass
        except Exception as e:
            print("Error in file transcription:", e)
        finally:
            # Always release the busy flag, even on failure.
            recording_status["recording"] = False

    threading.Thread(target=worker, daemon=True).start()
    return jsonify({"ok": True})
# Stop (generic)
@app.route("/stop", methods=["POST"])
def stop_recording():
    """Generic stop hook: forward the stop signal to the recording module."""
    evt = getattr(rte, 'stop_event', None) if rte else None
    if evt is not None:
        try:
            evt.set()
        except Exception:
            # Best-effort; the response is the same either way.
            pass
    return jsonify(success=True, message="Stop signal sent")
# SSE tailer
def tail_transcript_file(path, stop_cond_fn=None):
    """Yield SSE 'data:' events for each new non-blank line appended to *path*.

    Polls the file and remembers the last read offset between iterations.
    Stops when *stop_cond_fn* (if given) returns True, then emits one final
    "Transcription ended" event.
    """
    offset = 0
    announced = False
    while not (stop_cond_fn and stop_cond_fn()):
        if not os.path.exists(path):
            # Tell the client once that we are still waiting for the file.
            if not announced:
                yield "data: [info] Transcript file not yet created. Waiting...\n\n"
                announced = True
            time.sleep(0.5)
            continue
        with open(path, "r", encoding="utf-8", errors="ignore") as fh:
            fh.seek(offset)
            fresh = fh.readlines()
            if not fresh:
                time.sleep(0.25)
                continue
            for raw in fresh:
                line = raw.strip()
                if line:
                    yield f"data: {line}\n\n"
            offset = fh.tell()
            announced = True
    yield "data: [info] Transcription ended.\n\n"
@app.route("/events")
def events():
    """SSE endpoint streaming transcript lines as they are written."""
    transcript_path = str(TRANSCRIPT_FILE)

    def stop_fn():
        # The stream ends once a stop was requested AND the worker thread
        # is no longer alive.
        stopped = False
        try:
            evt = getattr(rte, 'stop_event', None) if rte else None
            stopped = bool(evt is not None and evt.is_set())
        except Exception:
            stopped = False
        alive = False
        try:
            alive = ('recording_thread' in globals()
                     and recording_thread is not None
                     and recording_thread.is_alive())
        except Exception:
            alive = False
        return stopped and not alive

    generator = tail_transcript_file(transcript_path, stop_cond_fn=stop_fn)
    return Response(stream_with_context(generator), mimetype="text/event-stream")
@app.route("/status")
def status():
    """Lightweight polling endpoint mirroring the recording flag."""
    try:
        running = recording_status.get("recording", False)
    except Exception:
        # Defensive default if shared state is somehow unavailable.
        running = False
    return jsonify(running=running)
# Final-files listing (for UI)
@app.route("/api/final-files")
def api_final_files():
    """List output artifacts and uploaded media as {name, path, url} records."""
    entries = []

    def collect(directory, suffixes, url_prefix):
        # Best-effort listing; an unreadable/missing directory is skipped.
        try:
            for name in os.listdir(directory):
                if name.endswith(suffixes):
                    entries.append({"name": name,
                                    "path": f"{url_prefix}/{name}",
                                    "url": f"{url_prefix}/{name}"})
        except Exception:
            pass

    collect(OUTPUT_DIR, (".wav", ".txt"), "/static")
    collect(app.config['UPLOAD_FOLDER'], (".wav", ".mp3", ".txt"), "/uploads")
    return jsonify({"files": entries})
# Serve static final-files from OUTPUT_DIR (if you want to expose them at /static/<file>)
@app.route('/static/<path:filename>')
def static_files(filename):
    """Expose files from OUTPUT_DIR under /static/."""
    output_root = str(OUTPUT_DIR)
    return send_from_directory(output_root, filename)
# Run only when debugging locally; in production we use gunicorn
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, threaded=True)