staraks committed on
Commit
9a6938c
·
verified ·
1 Parent(s): ace3539

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -128
app.py CHANGED
@@ -1,62 +1,36 @@
1
  #!/usr/bin/env python3
2
  """
3
- Medical transcription service (whisper) with NumPy compatibility fixes.
4
 
5
- Behavior:
6
- - If invoked with an audio path argument, transcribes that file (CLI mode).
7
- - If run with no args, starts a small Flask server with POST /transcribe for uploads.
8
- This avoids the container failing immediately when no CLI args are provided.
9
- """
10
 
11
- import sys
 
12
  import os
13
- import json
 
 
 
14
  import logging
15
  from pathlib import Path
16
- import tempfile
 
 
17
 
18
- # Basic logging
19
  logging.basicConfig(level=logging.INFO)
20
- logger = logging.getLogger("med-asr")
21
-
22
- # Check numpy version early
23
- try:
24
- import numpy as np
25
- nv = tuple(map(int, np.__version__.split(".")[:2]))
26
- if nv[0] >= 2:
27
- logger.warning("NumPy version is %s. Some extensions compiled against NumPy 1.x may fail. "
28
- "If you see runtime errors, pin numpy<2 in requirements and rebuild the image.", np.__version__)
29
- except Exception:
30
- # keep going; the Dockerfile installs numpy<2 so this should not happen.
31
- pass
32
-
33
- # Import heavy deps and provide helpful message if missing
34
- try:
35
- import torch
36
- except ModuleNotFoundError:
37
- sys.stderr.write("Missing dependency: PyTorch is not installed. For CPU install, add to requirements.txt:\n"
38
- " --extra-index-url https://download.pytorch.org/whl/cpu\n"
39
- " torch==2.2.0+cpu\n")
40
- sys.exit(1)
41
-
42
- try:
43
- import whisper
44
- import librosa
45
- import soundfile as sf
46
- from rapidfuzz import process, fuzz
47
- except Exception as e:
48
- logger.exception("Failed to import dependencies: %s", e)
49
- raise
50
-
51
- from flask import Flask, request, jsonify
52
-
53
- # Config (tune these)
54
- MODEL_NAME = os.environ.get("WHISPER_MODEL", "large-v2")
55
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
56
- SAMPLE_RATE = 16000
57
- BEAM_SIZE = 5
58
- TEMPERATURE = 0.0
59
  MED_VOCAB_PATH = Path("medical_vocab.txt")
 
 
60
  ABBREV_MAP = {
61
  "h/o": "history of",
62
  "c/o": "complains of",
@@ -67,6 +41,12 @@ ABBREV_MAP = {
67
  "bp": "blood pressure",
68
  }
69
 
 
 
 
 
 
 
70
  def load_med_vocab(path=MED_VOCAB_PATH):
71
  if not path.exists():
72
  return []
@@ -75,22 +55,17 @@ def load_med_vocab(path=MED_VOCAB_PATH):
75
 
76
  MED_VOCAB = load_med_vocab()
77
 
78
- def whisper_transcribe(model, audio_path, beam_size=BEAM_SIZE, temperature=TEMPERATURE):
79
- return model.transcribe(
80
- audio_path,
81
- language="en",
82
- task="transcribe",
83
- beam_size=beam_size,
84
- temperature=temperature,
85
- word_timestamps=False,
86
- )
87
-
88
- def expand_abbrev(text):
89
  for k, v in ABBREV_MAP.items():
90
  text = text.replace(k, v).replace(k.upper(), v)
91
  return text
92
 
93
- def medical_lexicon_correct(text, score_cutoff=70):
 
 
 
 
94
  if not MED_VOCAB:
95
  return text
96
  words = text.split()
@@ -100,7 +75,7 @@ def medical_lexicon_correct(text, score_cutoff=70):
100
  matched = False
101
  for n in (4, 3, 2):
102
  if i + n <= len(words):
103
- phrase = " ".join(words[i:i+n])
104
  res = process.extractOne(phrase, MED_VOCAB, scorer=fuzz.token_sort_ratio, score_cutoff=score_cutoff)
105
  if res:
106
  cand, score = res[0], res[1]
@@ -119,7 +94,7 @@ def medical_lexicon_correct(text, score_cutoff=70):
119
  i += 1
120
  return " ".join(out_words)
121
 
122
- def apply_postprocessing(text):
123
  text = text.strip()
124
  text = expand_abbrev(text)
125
  text = medical_lexicon_correct(text)
@@ -129,89 +104,161 @@ def apply_postprocessing(text):
129
  text = text + "."
130
  return text
131
 
132
- def redact_phi(text):
133
  import re
134
  text = re.sub(r"\b(\d{3}-\d{2}-\d{4})\b", "[REDACTED_SSN]", text)
135
  text = re.sub(r"\b(\d{2}\/\d{2}\/\d{4})\b", "[REDACTED_DATE]", text)
136
  text = re.sub(r"\b(patient|pt)\s+\d+\b", "[REDACTED_PATIENT_ID]", text, flags=re.IGNORECASE)
137
  return text
138
 
139
- def transcribe_file(model, audio_path, redact=True):
140
- logger.info("Transcribing %s on device=%s model=%s", audio_path, DEVICE, MODEL_NAME)
 
 
 
141
  raw = whisper_transcribe(model, audio_path)
142
  text = raw.get("text", "").strip()
143
  proc = apply_postprocessing(text)
144
  if redact:
145
  proc = redact_phi(proc)
146
- return {"raw_text": text, "postprocessed_text": proc, "segments": raw.get("segments", []), "language": raw.get("language", None)}
147
-
148
- # Load model once (when server starts or when CLI invoked)
149
- def load_model(name=MODEL_NAME):
150
- logger.info("Loading model %s on %s (this may take a while)...", name, DEVICE)
151
- m = whisper.load_model(name, device=DEVICE)
152
- logger.info("Model loaded.")
153
- return m
154
-
155
- # CLI mode: if audio argument provided, transcribe and exit
156
- def cli_mode(model, argv):
157
- import argparse
158
- parser = argparse.ArgumentParser(description="Medical transcription using Whisper")
159
- parser.add_argument("audio", help="Path to audio file to transcribe (wav, mp3, m4a...)")
160
- parser.add_argument("--no-redact", action="store_true", help="Disable automatic PHI redaction")
161
- parser.add_argument("--output", "-o", default="transcript.json", help="Output JSON file")
162
- args = parser.parse_args(argv)
163
-
164
- audio_path = args.audio
165
- if not Path(audio_path).exists():
166
- logger.error("Audio file not found: %s", audio_path)
167
- sys.exit(2)
168
-
169
- out = transcribe_file(model, audio_path, redact=not args.no_redact)
170
- with open(args.output, "w", encoding="utf-8") as f:
171
- json.dump(out, f, indent=2)
172
- logger.info("Saved transcript to %s", args.output)
173
- print(out["postprocessed_text"])
174
-
175
- # Server mode: small Flask app for uploads
176
- app = Flask(__name__)
177
- MODEL = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
- @app.route("/health", methods=["GET"])
180
  def health():
181
- return jsonify({"ok": True, "model": MODEL_NAME, "device": DEVICE})
 
 
 
 
182
 
183
  @app.route("/transcribe", methods=["POST"])
184
  def transcribe_endpoint():
185
- if "file" not in request.files:
186
- return jsonify({"error": "missing file in form-data (name=file)"}), 400
187
- f = request.files["file"]
188
- if f.filename == "":
189
- return jsonify({"error": "empty filename"}), 400
190
- fd, tmp = tempfile.mkstemp(suffix=Path(f.filename).suffix)
191
- os.close(fd)
192
- f.save(tmp)
 
 
 
 
193
  try:
194
- out = transcribe_file(MODEL, tmp, redact=True)
195
- return jsonify(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  finally:
197
- try:
198
- os.remove(tmp)
199
- except Exception:
200
- pass
201
-
202
- def main():
203
- global MODEL
204
- # If arguments present -> CLI mode
205
- if len(sys.argv) > 1:
206
- MODEL = load_model()
207
- cli_mode(MODEL, sys.argv[1:])
208
- return
 
209
 
210
- # No args -> server mode
211
- MODEL = load_model()
212
- # start server
213
- logger.info("Starting server on 0.0.0.0:5000")
214
- app.run(host="0.0.0.0", port=5000, debug=False)
215
 
216
  if __name__ == "__main__":
217
- main()
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ Quick-start multi-file transcription -> merged DOCX using Whisper (small by default).
4
 
5
+ - GET / -> simple upload UI
6
+ - GET /health -> liveness (always 200)
7
+ - GET /ready -> readiness (503 until model loaded)
8
+ - POST /transcribe -> accept files field (multiple) and return merged docx
 
9
 
10
+ This app loads the Whisper model in a background thread so the server becomes responsive fast.
11
+ """
12
  import os
13
+ import sys
14
+ import tempfile
15
+ import threading
16
+ import time
17
  import logging
18
  from pathlib import Path
19
+ from typing import List, Dict, Any
20
+
21
+ from flask import Flask, request, jsonify, send_file, render_template
22
 
 
23
  logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger("med-asr-quick")
25
+
26
+ # Config
27
+ MODEL_NAME = os.environ.get("WHISPER_MODEL", "small")
28
+ PORT = int(os.environ.get("PORT", 5000))
29
+ MAX_FILES = 20
30
+ ALLOWED_EXT = {".wav", ".mp3", ".m4a", ".flac", ".aac", ".ogg"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  MED_VOCAB_PATH = Path("medical_vocab.txt")
32
+
33
+ # Postprocessing maps
34
  ABBREV_MAP = {
35
  "h/o": "history of",
36
  "c/o": "complains of",
 
41
  "bp": "blood pressure",
42
  }
43
 
44
# Readiness globals: READY flips to True once the background loader finishes;
# MODEL_LOCK guards the publish of the loaded model handle into MODEL.
READY = False
MODEL = None
MODEL_LOCK = threading.Lock()
48
+
49
+ # Lazy-load med vocab
50
  def load_med_vocab(path=MED_VOCAB_PATH):
51
  if not path.exists():
52
  return []
 
55
 
56
  MED_VOCAB = load_med_vocab()
57
 
58
+ # Postprocessing helpers
59
def expand_abbrev(text: str) -> str:
    """Expand known clinical abbreviations (e.g. "h/o" -> "history of").

    BUG FIX: the previous implementation used plain ``str.replace``, which
    matched inside longer words (e.g. "bp" inside "bps") and only handled
    exact-case or all-uppercase forms. Matching is now whole-token and
    case-insensitive, which is backward compatible for all previous matches.
    """
    import re
    for abbrev, expansion in ABBREV_MAP.items():
        # (?<!\w)/(?!\w) rather than \b so slash-containing keys like "h/o"
        # are still bounded correctly at both ends.
        pattern = r"(?<!\w)" + re.escape(abbrev) + r"(?!\w)"
        text = re.sub(pattern, expansion, text, flags=re.IGNORECASE)
    return text
63
 
64
+ def medical_lexicon_correct(text: str, score_cutoff: int = 70) -> str:
65
+ try:
66
+ from rapidfuzz import process, fuzz
67
+ except Exception:
68
+ return text
69
  if not MED_VOCAB:
70
  return text
71
  words = text.split()
 
75
  matched = False
76
  for n in (4, 3, 2):
77
  if i + n <= len(words):
78
+ phrase = " ".join(words[i : i + n])
79
  res = process.extractOne(phrase, MED_VOCAB, scorer=fuzz.token_sort_ratio, score_cutoff=score_cutoff)
80
  if res:
81
  cand, score = res[0], res[1]
 
94
  i += 1
95
  return " ".join(out_words)
96
 
97
+ def apply_postprocessing(text: str) -> str:
98
  text = text.strip()
99
  text = expand_abbrev(text)
100
  text = medical_lexicon_correct(text)
 
104
  text = text + "."
105
  return text
106
 
107
def redact_phi(text: str) -> str:
    """Scrub common PHI patterns (SSNs, slash dates, patient ids) from *text*."""
    import re
    rules = (
        (r"\b(\d{3}-\d{2}-\d{4})\b", "[REDACTED_SSN]", 0),
        (r"\b(\d{2}\/\d{2}\/\d{4})\b", "[REDACTED_DATE]", 0),
        (r"\b(patient|pt)\s+\d+\b", "[REDACTED_PATIENT_ID]", re.IGNORECASE),
    )
    for pattern, token, flags in rules:
        text = re.sub(pattern, token, text, flags=flags)
    return text
 
114
+ # Transcription using whisper
115
def whisper_transcribe(model, audio_path: str, beam_size: int = 5, temperature: float = 0.0) -> Dict[str, Any]:
    """Run Whisper on *audio_path* with fixed English/transcribe settings."""
    options = {
        "language": "en",
        "task": "transcribe",
        "beam_size": beam_size,
        "temperature": temperature,
    }
    return model.transcribe(audio_path, **options)
117
+
118
def transcribe_single(model, audio_path: str, redact: bool = True) -> Dict[str, Any]:
    """Transcribe one audio file and return raw + postprocessed text and segments."""
    raw = whisper_transcribe(model, audio_path)
    raw_text = raw.get("text", "").strip()
    cleaned = apply_postprocessing(raw_text)
    # PHI scrubbing is on by default; callers must opt out explicitly.
    final_text = redact_phi(cleaned) if redact else cleaned
    return {
        "raw_text": raw_text,
        "postprocessed_text": final_text,
        "segments": raw.get("segments", []),
    }
125
+
126
def make_docx(trans_results: List[Dict[str, Any]], out_path: str):
    """Write all transcription results into a single DOCX at *out_path*.

    Each result gets its own section: a filename heading, optional duration,
    the postprocessed transcript, the raw transcript, and per-segment
    timestamps, followed by a page break.
    """
    from docx import Document
    from docx.shared import Pt
    doc = Document()
    doc.styles["Normal"].font.name = "Arial"
    doc.styles["Normal"].font.size = Pt(11)
    doc.add_heading("Merged Transcripts", level=1)
    for r in trans_results:
        filename = r.get("filename", "unknown")
        # BUG FIX: the heading previously rendered the literal "(unknown)"
        # (a constant f-string) instead of the actual source filename.
        doc.add_heading(filename, level=2)
        meta = r.get("meta", {})
        if meta:
            doc.add_paragraph(f"Duration: {meta.get('duration', 'unknown')}s")
        doc.add_heading("Postprocessed Transcript", level=3)
        doc.add_paragraph(r.get("postprocessed_text", ""))
        doc.add_heading("Raw Transcript", level=3)
        doc.add_paragraph(r.get("raw_text", ""))
        segments = r.get("segments", [])
        if segments:
            doc.add_heading("Segments (timestamps)", level=4)
            for seg in segments:
                start = seg.get("start", 0)
                end = seg.get("end", 0)
                text = seg.get("text", "").strip()
                doc.add_paragraph(f"[{start:.2f} - {end:.2f}] {text}")
        doc.add_page_break()
    doc.save(out_path)
153
+
154
def validate_and_save_files(files_list) -> List[Path]:
    """Validate uploaded files and persist each to its own temp file.

    Returns the list of temp paths. Raises ValueError when there are too
    many files or an unsupported extension is seen.

    BUG FIX: previously a mid-batch validation failure leaked the temp files
    already written for earlier entries; they are now removed before the
    error propagates.
    """
    if not files_list:
        return []
    if len(files_list) > MAX_FILES:
        raise ValueError(f"Too many files (max {MAX_FILES}).")
    saved_paths: List[Path] = []
    try:
        for f in files_list:
            filename = f.filename
            if not filename:
                # Browsers submit an empty part when no file is selected; skip it.
                continue
            ext = Path(filename).suffix.lower()
            if ext not in ALLOWED_EXT:
                raise ValueError(f"Unsupported file extension: {ext}")
            fd, tmp = tempfile.mkstemp(suffix=ext)
            os.close(fd)
            f.save(tmp)
            saved_paths.append(Path(tmp))
    except ValueError:
        for p in saved_paths:
            try:
                os.remove(p)
            except OSError:
                pass
        raise
    return saved_paths
172
+
173
+ # Background model loader
174
def _load_model_background(name=MODEL_NAME):
    """Load the Whisper model off the request thread; flip READY on success.

    Imports the heavy libraries lazily so the server can start serving
    /health immediately. On any failure READY stays/becomes False.
    """
    global MODEL, READY
    try:
        import torch
        import whisper
    except Exception as e:
        logger.exception("Failed to import heavy libs: %s", e)
        READY = False
        return
    try:
        logger.info("Loading Whisper model %s in background...", name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        loaded = whisper.load_model(name, device=device)
        # Publish under the lock, then signal readiness.
        with MODEL_LOCK:
            MODEL = loaded
        READY = True
        logger.info("Model loaded.")
    except Exception:
        logger.exception("Background model loading failed.")
        READY = False
193
+
194
def start_background_loader():
    """Kick off model loading on a daemon thread so startup is not blocked."""
    threading.Thread(target=_load_model_background, daemon=True).start()
197
+
198
# Flask app
app = Flask(__name__, template_folder="templates", static_folder="static")

@app.route("/")
def index():
    # Upload UI; the template receives the per-request file cap for display.
    return render_template("index.html", max_files=MAX_FILES)
204
 
205
@app.route("/health")
def health():
    """Liveness probe: always 200 once the process is serving requests."""
    payload = {"ok": True}
    return jsonify(payload)
208
+
209
@app.route("/ready")
def ready():
    """Readiness probe: 503 until the background loader has set READY."""
    status = 200 if READY else 503
    return jsonify({"ready": READY, "model": MODEL_NAME}), status
212
 
213
@app.route("/transcribe", methods=["POST"])
def transcribe_endpoint():
    """Accept multipart uploads (form field 'files'), transcribe each, and
    return one merged DOCX. Responds 503 until the model finishes loading.

    BUG FIX: the generated .docx temp file was never deleted after
    ``send_file`` — each request leaked one file on disk. The document is
    now streamed from memory and the temp file removed immediately.
    """
    # NOTE: no ``global`` needed — READY and MODEL are only read here.
    if not READY or MODEL is None:
        return jsonify({"error": "model not ready"}), 503
    files = request.files.getlist("files")
    if not files:
        return jsonify({"error": "no files uploaded (use form field name 'files')"}), 400
    try:
        saved_paths = validate_and_save_files(files)
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    trans_results = []
    try:
        for p in saved_paths:
            logger.info("Transcribing %s", p)
            try:
                trans = transcribe_single(MODEL, str(p), redact=True)
                # Best-effort duration metadata; any failure leaves it None.
                try:
                    import soundfile as sf
                    info = sf.info(str(p))
                    duration = round(info.duration, 2) if info.duration else None
                except Exception:
                    duration = None
                trans_results.append({"filename": p.name, "meta": {"duration": duration}, **trans})
            except Exception:
                logger.exception("Error transcribing %s", p)
                trans_results.append({"filename": p.name, "error": "transcription failed"})
    finally:
        # Always remove the uploaded temp audio, even if transcription failed.
        for p in saved_paths:
            try:
                os.remove(p)
            except Exception:
                pass

    import io
    fd, out_tmp = tempfile.mkstemp(suffix=".docx")
    os.close(fd)
    try:
        make_docx(trans_results, out_tmp)
        with open(out_tmp, "rb") as fh:
            payload = io.BytesIO(fh.read())
    except Exception:
        logger.exception("Failed to create docx")
        return jsonify({"error": "failed to generate docx"}), 500
    finally:
        try:
            os.remove(out_tmp)
        except OSError:
            pass

    payload.seek(0)
    return send_file(
        payload,
        as_attachment=True,
        download_name="merged_transcripts.docx",
        mimetype="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    )
 
 
 
 
259
 
260
if __name__ == "__main__":
    # If CLI args provided, not supported here; run server
    # Model loads on a daemon thread so the server binds immediately;
    # /ready reports 503 until loading completes.
    start_background_loader()
    logger.info("Starting server on 0.0.0.0:%d (model loads in background)", PORT)
    app.run(host="0.0.0.0", port=PORT, debug=False)