Spaces:
Sleeping
Sleeping
| # merged.py (production-ready for Docker / Hugging Face Spaces) | |
| import os | |
| import time | |
| import threading | |
| import queue | |
| import pathlib | |
| from pathlib import Path | |
| from flask import Flask, request, jsonify, send_from_directory, Response, stream_with_context, render_template | |
| from werkzeug.utils import secure_filename | |
| # Try to import rec_transcribe_extension; we still rely on its utilities | |
| try: | |
| import rec_transcribe_extension as rte | |
| from rec_transcribe_extension import Transcriber, diarization_hook, run_recording | |
| except Exception as e: | |
| # If the module import fails, keep rte=None and catch later to provide friendly error messages | |
| rte = None | |
| Transcriber = None | |
| diarization_hook = None | |
| run_recording = None | |
| print("Warning: failed to import rec_transcribe_extension:", e) | |
| # ---- Environment-driven directories & config ---- | |
def _first_writable_dir(preferred, fallback):
    """Create *preferred* and return it; on any failure create and return *fallback*."""
    try:
        preferred.mkdir(parents=True, exist_ok=True)
    except Exception:
        # Creation in the requested location can fail in some runtimes;
        # use the fallback location instead.
        fallback.mkdir(parents=True, exist_ok=True)
        return fallback
    return preferred


# Directory that receives recordings and transcripts (overridable via OUTPUT_DIR).
DEFAULT_OUTPUT = os.environ.get("OUTPUT_DIR", "/app/output_transcript_diarization")
OUTPUT_DIR = _first_writable_dir(
    Path(DEFAULT_OUTPUT),
    Path("/tmp/output_transcript_diarization"),
)

# Transcript file path used by the SSE endpoint.
TRANSCRIPT_FILE = OUTPUT_DIR / "transcript.txt"

# Directory that receives web uploads (overridable via UPLOAD_FOLDER).
UPLOAD_FOLDER = _first_writable_dir(
    Path(os.environ.get("UPLOAD_FOLDER", "/app/uploads")),
    Path("/tmp/uploads"),
)
# Audio file suffixes (lower-case, with leading dot) accepted by allowed_file().
ALLOWED_EXT = {'.mp3', '.wav', '.m4a', '.aac', '.ogg'}


def allowed_file(filename: str) -> bool:
    """Return True when the suffix of *filename* is in ALLOWED_EXT.

    The comparison is case-insensitive; a name without a suffix is rejected.
    """
    suffix = pathlib.Path(filename).suffix
    return suffix.lower() in ALLOWED_EXT
| # ---- Try to import pyaudio lazily and detect if host audio devices are accessible ---- | |
| LIVE_RECORDING_SUPPORTED = False | |
| _pyaudio = None | |
| try: | |
| import importlib | |
| _pyaudio = importlib.import_module("pyaudio") | |
| # attempt to instantiate PyAudio to confirm it's functional | |
| try: | |
| pa = _pyaudio.PyAudio() | |
| # if there is at least one input device, consider live recording possible | |
| has_input = any(pa.get_device_info_by_index(i).get("maxInputChannels", 0) > 0 | |
| for i in range(pa.get_device_count())) | |
| pa.terminate() | |
| LIVE_RECORDING_SUPPORTED = bool(has_input) | |
| except Exception as e: | |
| LIVE_RECORDING_SUPPORTED = False | |
| print("PyAudio imported but couldn't initialize audio devices:", e) | |
| except Exception: | |
| # pyaudio not available | |
| LIVE_RECORDING_SUPPORTED = False | |
| # ---- Flask app ---- | |
# Flask application; the built-in static route is disabled (static_folder=None).
app = Flask(__name__, static_folder=None)
app.config.update(UPLOAD_FOLDER=str(UPLOAD_FOLDER))

# ---- Shared state ----
# Thread object of the most recently started live recording (None until one starts).
recording_thread = None
recording_lock = threading.Lock()
# "recording": True while a recording/transcription job runs.
# "live_segments": segment dicts produced by the current/last job.
recording_status = dict(recording=False, live_segments=[])
| # ---- Frontend routes ---- | |
def landing():
    """Render the ``landing.html`` template."""
    template_name = "landing.html"
    return render_template(template_name)
def live_page():
    """Render the ``index2.html`` template."""
    template_name = "index2.html"
    return render_template(template_name)
def upload_page():
    """Render the ``index2_upload.html`` template."""
    template_name = "index2_upload.html"
    return render_template(template_name)
| # ---- Device listing (only if supported) ---- | |
def api_devices():
    """List audio devices that expose at least one input channel.

    Returns:
        A JSON response ``{"devices": [{"index": ..., "name": ...}, ...]}``.
        When live recording is unsupported the list is empty, an ``error``
        message is included and the status is 200. If PyAudio fails, the list
        is empty, ``error`` holds the exception text and the status is 500.
    """
    if not LIVE_RECORDING_SUPPORTED:
        return jsonify({"devices": [], "error": "Live recording not supported in this environment."}), 200
    try:
        pa = _pyaudio.PyAudio()
        try:
            infos = [pa.get_device_info_by_index(i) for i in range(pa.get_device_count())]
        finally:
            # Release the PyAudio handle even if enumeration raised.
            pa.terminate()
        devices = [
            {"index": dev["index"], "name": dev["name"]}
            for dev in infos
            if dev.get("maxInputChannels", 0) > 0
        ]
        return jsonify({"devices": devices})
    except Exception as e:
        return jsonify({"devices": [], "error": str(e)}), 500
| # ---- Start recording endpoint (guards if pyaudio unavailable) ---- | |
def api_start_recording():
    """Start a live recording session in a background daemon thread.

    Reads a JSON body with:
        mic: required input-device index.
        sys: optional second input-device index (``None``, ``""`` and
            ``"null"`` mean "not given").
        chunk_secs: chunk length passed to ``run_recording`` (default 5).
        model: model name passed to ``run_recording`` (default "medium").
        no_transcribe: flag passed to ``run_recording`` (default False).

    Returns:
        ``{"ok": True}`` once the thread has been started, or a JSON
        ``{"error": ...}`` with status 400 (unsupported environment, bad
        parameters, unknown device, already recording) or 500 (PyAudio
        could not be initialised).
    """
    global recording_thread
    if not LIVE_RECORDING_SUPPORTED or _pyaudio is None:
        return jsonify({"error": "Live recording is not supported in this environment."}), 400

    # silent=True: a missing or malformed JSON body is treated as an empty
    # request and reported through the parameter checks below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    try:
        mic = int(data.get("mic"))
    except (TypeError, ValueError, OverflowError):
        return jsonify({"error": "Missing or invalid 'mic' parameter"}), 400

    sys_index = None
    if data.get("sys") not in (None, "", "null"):
        try:
            sys_index = int(data.get("sys"))
        except (TypeError, ValueError, OverflowError):
            return jsonify({"error": "Invalid 'sys' parameter"}), 400

    try:
        chunk_secs = int(data.get("chunk_secs", 5))
    except (TypeError, ValueError, OverflowError):
        return jsonify({"error": "Invalid 'chunk_secs' parameter"}), 400
    model = data.get("model", "medium")
    no_transcribe = bool(data.get("no_transcribe", False))

    # Cheap early rejection; the authoritative check happens under the lock below.
    if recording_status["recording"]:
        return jsonify({"error": "Already recording"}), 400

    # validate devices using pyaudio
    try:
        pa = _pyaudio.PyAudio()
    except Exception as e:
        return jsonify({"error": f"PyAudio initialization failed: {e}"}), 500

    def device_is_valid(device_index):
        """Return True if the device exists and reports input channels."""
        try:
            dev = pa.get_device_info_by_index(device_index)
            return dev.get("maxInputChannels", 0) > 0
        except Exception:
            return False

    try:
        if not device_is_valid(mic):
            return jsonify({"error": f"Microphone device index {mic} not found or has no input channels"}), 400
        if sys_index is not None and not device_is_valid(sys_index):
            return jsonify({"error": f"System device index {sys_index} not found or has no input channels"}), 400
    finally:
        # Single release point for the PyAudio handle on every path.
        pa.terminate()

    # Claim the recording slot atomically so two concurrent requests cannot
    # both start a session.
    with recording_lock:
        if recording_status["recording"]:
            return jsonify({"error": "Already recording"}), 400
        recording_status["recording"] = True
        recording_status["live_segments"] = []

    stop_event = threading.Event()

    def run():
        """Thread body: run the recording and always clear the recording flag."""
        try:
            # monkey-patch worker if module supports it
            if rte and hasattr(rte, "chunk_writer_and_transcribe_worker"):
                import rec_transcribe_extension as rte_local
                orig_worker = rte_local.chunk_writer_and_transcribe_worker

                def patched_worker(in_queue, final_frames_list, transcriber, single_channel_label="mic"):
                    """Save each queued chunk, transcribe it and publish the segments."""
                    while True:
                        try:
                            filename, frames = in_queue.get(timeout=1.0)
                        except queue.Empty:
                            # Exit only after stop was requested and the queue is drained.
                            if stop_event.is_set() and in_queue.empty():
                                break
                            continue

                        try:
                            rte_local.save_wav_from_frames(filename, frames, nchannels=rte_local.CHANNELS)
                        except Exception as e:
                            # best-effort; continue
                            print("Failed to save chunk:", e)

                        # diarization and transcription
                        try:
                            diar_segments = rte_local.diarization_hook(str(filename)) or []
                        except Exception:
                            diar_segments = []

                        if not (transcriber and getattr(transcriber, "model", None)):
                            continue
                        try:
                            segments, _ = transcriber.model.transcribe(str(filename), beam_size=5)
                            for seg in segments:
                                seg_start = float(getattr(seg, "start", 0.0))
                                seg_end = float(getattr(seg, "end", 0.0))
                                seg_text = getattr(seg, "text", "").strip()
                                # First diarization interval overlapping the segment wins.
                                speaker = "Unknown"
                                for d_start, d_end, d_speaker in diar_segments:
                                    if (seg_start < d_end) and (seg_end > d_start):
                                        speaker = d_speaker
                                        break
                                recording_status["live_segments"].append({
                                    "start": seg_start,
                                    "end": seg_end,
                                    "speaker": str(speaker),
                                    "text": seg_text,
                                })
                                # write to persistent transcript file (best-effort)
                                try:
                                    with open(TRANSCRIPT_FILE, "a", encoding="utf-8") as tf:
                                        tf.write(f"[{pathlib.Path(filename).name}] {seg_start:.2f}-{seg_end:.2f} Speaker {speaker}: {seg_text}\n")
                                except Exception:
                                    pass
                        except Exception as e:
                            print("Transcription error:", e)

                rte_local.chunk_writer_and_transcribe_worker = patched_worker
                try:
                    rte_local.stop_event = stop_event
                    rte_local.run_recording(mic_index=mic, sys_index=sys_index, chunk_secs=chunk_secs,
                                            model_name=model, no_transcribe=no_transcribe)
                finally:
                    # Always restore the module's own worker.
                    rte_local.chunk_writer_and_transcribe_worker = orig_worker
            else:
                # fallback: call run_recording if available without monkey patch
                if rte and hasattr(rte, "stop_event"):
                    rte.stop_event = stop_event
                if rte and hasattr(rte, "run_recording"):
                    rte.run_recording(mic_index=mic, sys_index=sys_index, chunk_secs=chunk_secs,
                                      model_name=model, no_transcribe=no_transcribe)
        except Exception as e:
            print("run_recording error:", e)
        finally:
            # Clear the flag on every exit path so a failed run does not
            # leave the service permanently "recording".
            recording_status["recording"] = False

    recording_thread_local = threading.Thread(target=run, daemon=True)
    recording_thread_local.start()
    # store reference globally so stop logic can use it
    recording_thread = recording_thread_local
    return jsonify({"ok": True})
| # Stop recording | |
def api_stop_recording():
    """Signal the recorder to stop by setting ``rte.stop_event`` when present.

    Always answers ``{"ok": True}``; errors while setting the event are ignored.
    """
    event = getattr(rte, "stop_event", None) if rte else None
    if event:
        try:
            event.set()
        except Exception:
            pass
    return jsonify({"ok": True})
| # recording status | |
def api_recording_status():
    """Return the ``recording`` flag and the collected ``live_segments`` as JSON."""
    is_recording = recording_status.get("recording", False)
    segments = recording_status.get("live_segments", [])
    return jsonify({"recording": is_recording, "live_segments": segments})
| # ---- Upload endpoint (works in Spaces) ---- | |
def api_upload_file():
    """Store the multipart ``file`` field in the upload folder.

    The stored name is ``<millisecond timestamp>_<sanitised original name>``.

    Returns:
        JSON with ``success``, ``url`` and ``filename`` on success; JSON with
        ``success=False`` and ``error`` plus status 400 (missing part, empty
        name, disallowed extension) or 500 (save failed) otherwise.
    """
    upload = request.files.get('file')
    if upload is None:
        return jsonify(success=False, error="No file part"), 400
    if upload.filename == '':
        return jsonify(success=False, error="Empty filename"), 400

    safe_name = secure_filename(upload.filename)
    if not allowed_file(safe_name):
        return jsonify(success=False, error="Extension not allowed"), 400

    stamp = int(time.time() * 1000)
    saved_name = "{}_{}".format(stamp, safe_name)
    destination = Path(app.config['UPLOAD_FOLDER']) / saved_name
    try:
        upload.save(str(destination))
    except Exception as e:
        return jsonify(success=False, error=f"Failed to save file: {e}"), 500

    return jsonify(success=True, url=f"/uploads/{saved_name}", filename=saved_name)
| # Serve uploaded files | |
def uploaded_file(filename):
    """Send *filename* from the configured upload folder, inline rather than as an attachment."""
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename, as_attachment=False)
| # ---- Transcribe an uploaded file in a paced 'live' manner (works in Spaces) ---- | |
def api_start_transcribe_file():
    """Transcribe a stored file in a background thread, paced like a live feed.

    Reads a JSON body with ``filename``. The file is looked up in the upload
    folder first and in ``OUTPUT_DIR`` second. Segments are appended to
    ``recording_status["live_segments"]`` and to ``TRANSCRIPT_FILE``; each
    one is held back until its start time has elapsed on the wall clock.

    Returns:
        ``{"ok": True}`` once the worker thread is started, or a JSON
        ``{"error": ...}`` with status 400 (missing/invalid filename, busy)
        or 404 (file not found).
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    filename = data.get("filename")
    if not filename:
        return jsonify({"error": "Missing filename"}), 400
    # The name comes from the client: accept only a bare file name so that
    # "../" segments or absolute paths cannot escape the two directories.
    if (not isinstance(filename, str) or filename in (".", "..")
            or Path(filename).name != filename):
        return jsonify({"error": "Invalid filename"}), 400

    # if file was uploaded to uploads folder, prefer that path
    uploaded_path = Path(app.config['UPLOAD_FOLDER']) / filename
    file_path = uploaded_path if uploaded_path.exists() else OUTPUT_DIR / filename
    if not file_path.is_file():
        return jsonify({"error": "File not found"}), 404

    # Claim the busy flag here, under the lock, rather than inside the worker:
    # otherwise two requests could both pass the check before either thread
    # had set it.
    with recording_lock:
        if recording_status.get("recording"):
            return jsonify({"error": "Busy"}), 400
        recording_status["recording"] = True
        recording_status["live_segments"] = []

    def worker():
        """Run diarization and transcription, publishing segments as they become due."""
        try:
            transcriber = Transcriber() if Transcriber else None
            diar_segments = (diarization_hook(str(file_path)) if diarization_hook else []) or []
            if transcriber and getattr(transcriber, "model", None):
                segments, _ = transcriber.model.transcribe(str(file_path), beam_size=5)
                start_clock = time.time()
                for seg in segments:
                    # Pace the output: wait until the segment's start offset
                    # has elapsed since transcription began.
                    wait_for = seg.start - (time.time() - start_clock)
                    if wait_for > 0:
                        time.sleep(wait_for)
                    # First diarization interval overlapping the segment wins.
                    speaker = "Unknown"
                    for d_start, d_end, d_label in diar_segments:
                        if (seg.start < d_end) and (seg.end > d_start):
                            speaker = d_label
                            break
                    text = seg.text.strip()
                    recording_status["live_segments"].append({
                        "start": float(seg.start),
                        "end": float(seg.end),
                        "speaker": speaker,
                        "text": text,
                    })
                    # append to transcript file for SSE streaming (best-effort)
                    try:
                        with open(TRANSCRIPT_FILE, "a", encoding="utf-8") as tf:
                            tf.write(f"[{file_path.name}] {seg.start:.2f}-{seg.end:.2f} Speaker {speaker}: {text}\n")
                    except Exception:
                        pass
        except Exception as e:
            print("Error in file transcription:", e)
        finally:
            recording_status["recording"] = False

    threading.Thread(target=worker, daemon=True).start()
    return jsonify({"ok": True})
| # Stop (generic) | |
def stop_recording():
    """Set ``rte.stop_event`` if the module provides one, then confirm.

    Always answers with ``success=True`` and the message "Stop signal sent";
    errors while setting the event are ignored.
    """
    event = getattr(rte, 'stop_event', None) if rte else None
    if event is not None:
        try:
            event.set()
        except Exception:
            pass
    return jsonify(success=True, message="Stop signal sent")
| # SSE tailer | |
def _read_new_transcript_lines(path, offset, flush_partial=False):
    """Read the text appended to *path* after byte *offset*.

    Returns ``(lines, new_offset)``. ``lines`` is ``None`` when the file does
    not exist, otherwise a list of stripped, non-empty lines. Unless
    *flush_partial* is true, a trailing fragment without a newline is left
    unread so it can be delivered whole on a later call.
    """
    try:
        with open(path, "rb") as fh:
            end = fh.seek(0, os.SEEK_END)
            if end < offset:
                # File was truncated or replaced: start over from the top.
                offset = 0
            fh.seek(offset)
            chunk = fh.read()
    except FileNotFoundError:
        return None, offset

    consumed = len(chunk) if flush_partial else chunk.rfind(b"\n") + 1
    decoded = (raw.decode("utf-8", errors="ignore").strip()
               for raw in chunk[:consumed].splitlines())
    return [text for text in decoded if text], offset + consumed


def tail_transcript_file(path, stop_cond_fn=None):
    """Yield Server-Sent-Event frames for every line appended to *path*.

    The file is read from the beginning, so existing content is replayed
    first. Only complete (newline-terminated) lines are emitted while
    tailing. When *stop_cond_fn* returns true the file is drained one last
    time, including any unterminated tail, before the closing info frame,
    so lines written just before the stop signal are not lost.

    Args:
        path: transcript file to follow; it may not exist yet.
        stop_cond_fn: optional zero-argument callable polled once per loop;
            a truthy result ends the stream.

    Yields:
        ``"data: <line>\\n\\n"`` strings, an info frame while the file is
        missing (once), and a final "Transcription ended" info frame.
    """
    offset = 0
    sent_initial = False
    while True:
        stopping = bool(stop_cond_fn and stop_cond_fn())
        lines, offset = _read_new_transcript_lines(path, offset, flush_partial=stopping)

        if lines is None:
            if not stopping and not sent_initial:
                yield "data: [info] Transcript file not yet created. Waiting...\n\n"
                sent_initial = True
        elif lines:
            for text in lines:
                yield f"data: {text}\n\n"
            sent_initial = True

        if stopping:
            break
        if lines is None:
            time.sleep(0.5)
        elif not lines:
            time.sleep(0.25)
    yield "data: [info] Transcription ended.\n\n"
def events():
    """Stream ``TRANSCRIPT_FILE`` to the client as ``text/event-stream``.

    The stream ends once ``rte.stop_event`` is set and the recording thread
    is no longer alive.
    """
    def stop_fn():
        """Return True when a stop was requested and no recording thread is running."""
        try:
            stop_requested = bool(
                rte
                and hasattr(rte, 'stop_event')
                and rte.stop_event is not None
                and rte.stop_event.is_set()
            )
        except Exception:
            stop_requested = False

        try:
            thread_alive = bool(
                'recording_thread' in globals()
                and recording_thread is not None
                and recording_thread.is_alive()
            )
        except Exception:
            thread_alive = False

        return stop_requested and not thread_alive

    stream = tail_transcript_file(str(TRANSCRIPT_FILE), stop_cond_fn=stop_fn)
    return Response(stream_with_context(stream), mimetype="text/event-stream")
def status():
    """Return ``{"running": <bool>}`` from the shared ``recording`` flag (False on any error)."""
    try:
        is_running = recording_status.get("recording", False)
    except Exception:
        is_running = False
    return jsonify(running=is_running)
| # Final-files listing (for UI) | |
def api_final_files():
    """List downloadable files from ``OUTPUT_DIR`` and the upload folder.

    ``OUTPUT_DIR`` contributes ``.wav``/``.txt`` files under ``/static/``;
    the upload folder contributes ``.wav``/``.mp3``/``.txt`` files under
    ``/uploads/``. A directory that cannot be listed is skipped.

    Returns:
        JSON ``{"files": [{"name": ..., "path": ..., "url": ...}, ...]}``.
    """
    listing = []

    def describe(entry, prefix):
        """Build the JSON record for one file served under *prefix*."""
        link = f"{prefix}{entry}"
        return {"name": entry, "path": link, "url": link}

    # list files from OUTPUT_DIR
    try:
        listing.extend(
            describe(entry, "/static/")
            for entry in os.listdir(OUTPUT_DIR)
            if entry.endswith((".wav", ".txt"))
        )
    except Exception:
        pass

    # also list uploaded files
    try:
        listing.extend(
            describe(entry, "/uploads/")
            for entry in os.listdir(app.config['UPLOAD_FOLDER'])
            if entry.endswith((".wav", ".mp3", ".txt"))
        )
    except Exception:
        pass

    return jsonify({"files": listing})
| # Serve static final-files from OUTPUT_DIR (if you want to expose them at /static/<file>) | |
def static_files(filename):
    """Send *filename* from ``OUTPUT_DIR``."""
    output_root = str(OUTPUT_DIR)
    return send_from_directory(output_root, filename)
| # Run only when debugging locally; in production we use gunicorn | |
if __name__ == "__main__":
    # Local debugging entry point; the port comes from PORT (default 7860).
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port, threaded=True)