|
|
|
|
|
""" |
|
|
Trans-for-Doctors CLI |
|
|
|
|
|
Runs the end-to-end pipeline: STT → Knowledge Base → LLM Correction → (optional) DOCX report. |
|
|
|
|
|
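Usage examples (LLM correction is enabled by default and uses the OpenAI API key from --openai-key or the OPENAI_API_KEY environment variable; pass --no-llm to skip it):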
|
|
uv run transmed --audio path/to.wav --model . --llm --generate-report |
|
|
uv run transmed --audio path/to.wav --model . --no-llm |
|
|
""" |
|
|
|
|
|
import argparse
import logging
import os
from pathlib import Path
from typing import Optional, Sequence

from pipeline import MedicalTranscriptionPipeline, PipelineConfig
|
|
|
|
|
|
|
|
def setup_logging(level: str = "INFO") -> None:
    """Configure root logging via ``logging.basicConfig``.

    Args:
        level: A level name, case-insensitive (``"debug"``, ``"WARN"``, ...),
            or a non-negative integer written as a string (``"15"``). Any value
            that does not resolve to an integer level falls back to ``INFO``.

    Note:
        ``logging.basicConfig`` does nothing when the root logger already has
        handlers, so only the first effective call in a process applies.
    """
    name = level.strip().upper()
    if name.isdecimal():
        # Numeric levels such as "10" or "25" are passed through as-is.
        resolved = int(name)
    else:
        candidate = getattr(logging, name, logging.INFO)
        # getattr() on the logging module can also hit non-level constants
        # (e.g. "basic_format" -> logging.BASIC_FORMAT, a str). That would make
        # basicConfig raise, so only genuine integer levels are accepted.
        if isinstance(candidate, int) and not isinstance(candidate, bool):
            resolved = candidate
        else:
            resolved = logging.INFO

    logging.basicConfig(
        level=resolved,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
|
|
|
|
|
|
|
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the ``transmed`` argument parser and parse the command line.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly as before. Passing
            an explicit list lets tests and other Python callers reuse the
            parser without touching ``sys.argv``.

    Returns:
        The parsed arguments as an ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Run medical transcription pipeline (STT + LLM Corrector + KB)",
    )

    # Input audio and speech-to-text model settings.
    parser.add_argument("--audio", required=True, type=str, help="Path to audio .wav file")
    parser.add_argument("--model", type=str, default=".", help="Path to Whisper model directory")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu", "mps"], help="Inference device")
    parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"], help="Torch dtype")
    parser.add_argument("--language", type=str, default="russian", help="Transcription language")

    # Knowledge base.
    parser.add_argument("--terms", type=str, default="medical_terms.txt", help="Path to medical terms file")

    # LLM correction is on by default. --no-llm turns it off; the last flag wins.
    parser.add_argument("--llm", dest="llm", action="store_true", help="Enable LLM correction (default)")
    parser.add_argument("--no-llm", dest="llm", action="store_false", help="Disable LLM correction")
    parser.set_defaults(llm=True)
    parser.add_argument("--openai-model", type=str, default="gpt-4o", help="OpenAI model name")
    # NOTE: the environment variable is read when the parser is built. Prefer it
    # over --openai-key, because command-line arguments can be visible to other
    # local users (e.g. in the process list) and end up in shell history.
    parser.add_argument("--openai-key", type=str, default=os.getenv("OPENAI_API_KEY"), help="OpenAI API key (defaults to env OPENAI_API_KEY)")

    # Output artifacts and their locations.
    parser.add_argument("--save-original", action="store_true", help="Save original transcription JSON")
    parser.add_argument("--save-corrected", action="store_true", help="Save corrected transcription JSON")
    parser.add_argument("--generate-report", action="store_true", help="Generate DOCX report")
    parser.add_argument("--results-dir", type=str, default="results", help="Directory to store results")
    parser.add_argument("--logs-dir", type=str, default="logs", help="Directory to store logs")

    parser.add_argument("--log-level", type=str, default="INFO", help="Logging level (e.g. DEBUG, INFO, WARNING)")

    # Optional patient/study metadata. main() only forwards these to the
    # pipeline when --generate-report is given.
    parser.add_argument("--patient-name", type=str, default=None, help="Patient name for the report")
    parser.add_argument("--patient-id", type=str, default=None, help="Patient identifier for the report")
    parser.add_argument("--study-date", type=str, default=None, help="Study date for the report")
    parser.add_argument("--modality", type=str, default=None, help="Study modality for the report")
    parser.add_argument("--body-part", type=str, default=None, help="Examined body part for the report")

    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
def _preview(text: str, limit: int = 200) -> str:
    """Return the first *limit* characters of *text*, adding "..." only when it was cut."""
    if len(text) <= limit:
        return text
    return f"{text[:limit]}..."


def main() -> None:
    """Run the transcription pipeline from the command line.

    Steps:
        1. Parse arguments and configure logging.
        2. Validate the input paths.
        3. Build a ``PipelineConfig``.
        4. Run ``MedicalTranscriptionPipeline.process_audio_file`` on
           ``--audio``.
        5. Log a short summary of the result.

    Raises:
        SystemExit: code 1 when the audio file or the model path is missing,
            code 2 when the pipeline result's ``status`` is not ``"success"``.
    """
    args = parse_args()
    setup_logging(args.log_level)
    logger = logging.getLogger("transmed")

    audio_path = Path(args.audio)
    model_path = Path(args.model)
    terms_path = Path(args.terms)
    results_dir = Path(args.results_dir)
    logs_dir = Path(args.logs_dir)

    # is_file() rather than exists(): a directory given as --audio must be
    # rejected here instead of being handed to the pipeline as an audio file.
    if not audio_path.is_file():
        logger.error(f"Audio file not found: {audio_path}")
        raise SystemExit(1)
    if not model_path.exists():
        logger.error(f"Model path not found: {model_path}")
        raise SystemExit(1)
    if not terms_path.exists():
        logger.warning(f"Terms file not found: {terms_path} — proceeding without extra terms")
    if args.llm and not args.openai_key:
        # Surface the likely misconfiguration early. How a missing key is
        # actually handled is up to PipelineConfig / the pipeline.
        logger.warning(
            "LLM correction is enabled but no OpenAI API key was given "
            "(--openai-key or OPENAI_API_KEY); use --no-llm to skip correction"
        )

    config = PipelineConfig(
        model_path=model_path,
        device=args.device,
        dtype=args.dtype,
        language=args.language,
        medical_terms_file=terms_path,
        openai_api_key=args.openai_key,
        openai_model=args.openai_model,
        correction_enabled=args.llm,
        save_original=args.save_original,
        save_corrected=args.save_corrected,
        save_diff=True,
        generate_report=args.generate_report,
        results_dir=results_dir,
        reports_dir=results_dir / "reports",
        logs_dir=logs_dir,
    )

    logger.info("Creating medical transcription pipeline...")
    pipeline = MedicalTranscriptionPipeline(config)

    # Patient/study metadata is only forwarded when a report is requested.
    patient_metadata = None
    if args.generate_report:
        patient_metadata = {
            "patient_name": args.patient_name,
            "patient_id": args.patient_id,
            "study_date": args.study_date,
            "modality": args.modality,
            "body_part": args.body_part,
        }

    logger.info(f"Processing audio: {audio_path.name}")
    result = pipeline.process_audio_file(audio_path=audio_path, patient_metadata=patient_metadata)

    if result.get("status") != "success":
        logger.error(f"Pipeline failed: {result.get('error')}")
        raise SystemExit(2)

    # Treat keys that are present but None like missing keys, so that len()
    # and slicing below cannot fail. An explicitly empty corrected text is kept.
    orig = result.get("original_transcription") or ""
    corr = result.get("corrected_transcription")
    if corr is None:
        corr = orig
    corrections = result.get("corrections") or []

    # NOTE(review): these lines write transcript text to the log. For medical
    # dictation this may include patient data — confirm this is acceptable.
    logger.info(f"Original ({len(orig)} chars): {_preview(orig)}")
    if config.correction_enabled:
        logger.info(f"Corrected ({len(corr)} chars): {_preview(corr)}")
        logger.info(f"Corrections: {len(corrections)}")
    if result.get("report_path"):
        logger.info(f"Report: {result['report_path']}")
|
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|
|