Add non-local Whisper API support (remote transcription via an OpenAI-compatible endpoint)
Browse files
lec2note/ingestion/whisper_runner.py
CHANGED
|
@@ -8,9 +8,9 @@ logger = logging.getLogger(__name__)
|
|
| 8 |
|
| 9 |
from typing import List, Dict, Optional, Any
|
| 10 |
|
| 11 |
-
import torch
|
| 12 |
from whisper import load_model # type: ignore
|
| 13 |
-
import
|
| 14 |
|
| 15 |
__all__ = ["WhisperRunner"]
|
| 16 |
|
|
@@ -31,13 +31,29 @@ class WhisperRunner: # noqa: D101
|
|
| 31 |
if not audio_path.exists():
|
| 32 |
raise FileNotFoundError(audio_path)
|
| 33 |
|
| 34 |
-
|
| 35 |
-
logger.info("[Whisper] loading model %s on %s", cls.model_name, device)
|
| 36 |
-
model = load_model(cls.model_name, device=device)
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
# convert to our schema
|
| 43 |
logger.info("[Whisper] got %d segments", len(segments))
|
|
|
|
| 8 |
|
| 9 |
from typing import List, Dict, Optional, Any
|
| 10 |
|
| 11 |
+
import torch, json, os
|
| 12 |
from whisper import load_model # type: ignore
|
| 13 |
+
from openai import OpenAI
|
| 14 |
|
| 15 |
__all__ = ["WhisperRunner"]
|
| 16 |
|
|
|
|
| 31 |
if not audio_path.exists():
|
| 32 |
raise FileNotFoundError(audio_path)
|
| 33 |
|
| 34 |
+
use_local = os.getenv("AUDIO2TEXT_LOCAL", "true").lower() != "false"
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
if use_local:
|
| 37 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 38 |
+
logger.info("[Whisper] loading model %s on %s", cls.model_name, device)
|
| 39 |
+
model = load_model(cls.model_name, device=device)
|
| 40 |
+
|
| 41 |
+
logger.info("[Whisper] transcribing %s (local)", audio_path.name)
|
| 42 |
+
result = model.transcribe(str(audio_path), language=lang)
|
| 43 |
+
else:
|
| 44 |
+
# remote API mode
|
| 45 |
+
api_base = os.getenv("AIHUB_API_BASE")
|
| 46 |
+
api_key = os.getenv("AIHUB_API_KEY")
|
| 47 |
+
if not api_key:
|
| 48 |
+
raise EnvironmentError("AIHUB_API_KEY not set")
|
| 49 |
+
|
| 50 |
+
client = OpenAI(api_key=api_key, base_url=api_base)
|
| 51 |
+
logger.info("[Whisper] uploading %s to API (whisper-large-v3)", audio_path.name)
|
| 52 |
+
with audio_path.open("rb") as f:
|
| 53 |
+
resp = client.audio.transcriptions.create(model="whisper-large-v3", file=f, language=lang)
|
| 54 |
+
# NOTE(review): resp.text is only the full transcript; resp.segments is present only when the
# request sets response_format="verbose_json" — this call does not, so the single untimed-segment
# fallback below will always fire and all timing info is lost. Confirm and pass response_format explicitly.
|
| 55 |
+
segments = resp.segments if hasattr(resp, "segments") else [{"start": 0.0, "end": 0.0, "text": resp.text}]
|
| 56 |
+
result = {"segments": segments}
|
| 57 |
|
| 58 |
# convert to our schema
|
| 59 |
logger.info("[Whisper] got %d segments", len(segments))
|