File size: 4,437 Bytes
0db822c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Transcribe one or more audio files using a fine-tuned (or base) Whisper model.

Usage:
    # Use the fine-tuned model
    python scripts/transcribe.py audio.mp3

    # Transcribe multiple files
    python scripts/transcribe.py file1.mp3 file2.wav

    # Use a different model (HF model ID or local path)
    python scripts/transcribe.py --model openai/whisper-large-v3 audio.mp3
    python scripts/transcribe.py --model outputs/checkpoints/best_model audio.mp3

    # Save output to a file
    python scripts/transcribe.py audio.mp3 --output result.txt
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path
import os
from dotenv import load_dotenv

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.inference.transcribe import WhisperTranscriber

# Console logging for CLI runs: wall-clock time, padded level name, message.
logging.basicConfig(
    format="%(asctime)s  %(levelname)-8s  %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

# Fine-tuned checkpoint location, relative to the project root
# (the parent of this script's directory).
DEFAULT_MODEL = "outputs/checkpoints/best_model"


# Base Whisper checkpoint used when the default fine-tuned model is missing.
_FALLBACK_MODEL = "openai/whisper-large-v3"


def _build_parser(default_model: str) -> argparse.ArgumentParser:
    """Build the command-line parser.

    Args:
        default_model: Value used for ``--model`` when the option is omitted.

    Note:
        The ``--hf-token`` default is read from ``os.environ`` when this is
        called, so any ``.env`` file must already have been loaded.
    """
    parser = argparse.ArgumentParser(description="Transcribe Arabic audio with Whisper")
    parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
    parser.add_argument(
        "--model",
        default=default_model,
        help=f"Model path or HF model ID (default: {DEFAULT_MODEL})",
    )
    parser.add_argument("--output", default=None, help="Write transcription to this file")
    parser.add_argument("--device", default=None, help="cuda or cpu (auto-detect if omitted)")
    parser.add_argument("--diarize", action="store_true", help="Use Pyannote to diarize conversation turns")
    parser.add_argument("--hf-token", default=os.environ.get("HF_TOKEN"), help="HuggingFace token for Pyannote")
    parser.add_argument("--analyze", action="store_true", help="Post-process transcript with Gemini 2.5 Flash")
    parser.add_argument("--api-key", default=None, help="Gemini API Key (or set GEMINI_API_KEY env var)")
    return parser


def _resolve_model_path(requested: str, default_model: str) -> str:
    """Return the model path or Hugging Face model ID to load.

    Only the *default* fine-tuned checkpoint falls back to the base Whisper
    model when it does not exist yet. An explicitly requested model is
    returned unchanged, whether it is a local path or a hub ID from any
    namespace. A typo or a non-``openai/`` hub ID is therefore never
    silently swapped for a different model.
    """
    if requested == default_model and not Path(requested).exists():
        logger.warning(
            "Fine-tuned model not found at %s — falling back to %s",
            requested,
            _FALLBACK_MODEL,
        )
        return _FALLBACK_MODEL
    return requested


def _load_analyzer(api_key: str | None):
    """Construct the Gemini-backed ``CallAnalyzer``, exiting with status 1 on failure.

    The import is local so its dependencies are only loaded when
    ``--analyze`` is requested.
    """
    try:
        from src.inference.analyze_call import CallAnalyzer

        analyzer = CallAnalyzer(api_key=api_key)
    except Exception as e:
        logger.error("Failed to initialize CallAnalyzer: %s", e)
        sys.exit(1)
    logger.info("CallAnalyzer initialized with Gemini 2.5 Flash.")
    return analyzer


def _analyze_transcript(analyzer, text, audio_path: str) -> str | None:
    """Return ``analyzer``'s result for ``text`` as indented JSON, or None on failure.

    Analysis is best-effort. A failure is logged together with the
    offending file, and the raw transcript is still reported.
    """
    logger.info("Analyzing transcript for %s ...", audio_path)
    try:
        analysis = analyzer.analyze(text)
        return analysis.model_dump_json(indent=2)
    except Exception as e:
        logger.error("Failed to analyze transcript for %s: %s", audio_path, e)
        return None


def _write_output(out_path: Path, results) -> None:
    """Write all ``(audio_path, transcript, analysis_or_None)`` results to ``out_path`` as UTF-8."""
    # Create missing parent directories so a long run does not fail at the very end.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as fh:
        for path, text, analysis in results:
            fh.write(f"=== {Path(path).name} ===\nRaw Transcript:\n{text}\n\n")
            if analysis:
                fh.write(f"Analysis:\n{analysis}\n\n")
    logger.info("Output written to %s", out_path)


def main(argv: list[str] | None = None) -> None:
    """Transcribe each audio file given on the command line.

    For every file, the transcript (and optional analysis) is printed to
    stdout. With ``--output``, all results are also written to that file.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``.
    """
    root = Path(__file__).parent.parent
    # Must run before the parser is built: the --hf-token default reads HF_TOKEN.
    load_dotenv(root / ".env")

    default_model = str(root / DEFAULT_MODEL)
    args = _build_parser(default_model).parse_args(argv)

    # Validate option combinations before loading any model, so a missing
    # token is reported immediately rather than after the Whisper model loads.
    if args.diarize and not args.hf_token:
        logger.error("--hf-token or HF_TOKEN in .env is required for diarization")
        sys.exit(1)

    analyzer = _load_analyzer(args.api_key) if args.analyze else None

    model_path = _resolve_model_path(args.model, default_model)
    transcriber = WhisperTranscriber(model_path=model_path, device=args.device)

    results = []  # (audio_path, transcript, analysis JSON or None)
    for audio_path in args.audio:
        logger.info("Transcribing %s ...", audio_path)
        if args.diarize:
            text = transcriber.transcribe_with_diarization(audio_path, args.hf_token)
        else:
            text = transcriber.transcribe(audio_path)

        analysis_result = None
        if analyzer is not None:
            analysis_result = _analyze_transcript(analyzer, text, audio_path)

        results.append((audio_path, text, analysis_result))

        print(f"\n=== {Path(audio_path).name} ===")
        print(f"Raw Transcript:\n{text}")
        if analysis_result:
            print(f"\nAnalysis:\n{analysis_result}")

    if args.output:
        _write_output(Path(args.output), results)


if __name__ == "__main__":
    main()