LRU1 committed on
Commit
9c13b61
·
1 Parent(s): c50abfa

add non-local whisper api

Browse files
lec2note/ingestion/whisper_runner.py CHANGED
@@ -8,9 +8,9 @@ logger = logging.getLogger(__name__)
8
 
9
  from typing import List, Dict, Optional, Any
10
 
11
- import torch
12
  from whisper import load_model # type: ignore
13
- import json
14
 
15
  __all__ = ["WhisperRunner"]
16
 
@@ -31,13 +31,29 @@ class WhisperRunner: # noqa: D101
31
  if not audio_path.exists():
32
  raise FileNotFoundError(audio_path)
33
 
34
- device = "cuda" if torch.cuda.is_available() else "cpu"
35
- logger.info("[Whisper] loading model %s on %s", cls.model_name, device)
36
- model = load_model(cls.model_name, device=device)
37
 
38
- logger.info("[Whisper] transcribing %s", audio_path.name)
39
- result = model.transcribe(str(audio_path), language=lang)
40
- segments = result.get("segments", [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # convert to our schema
43
  logger.info("[Whisper] got %d segments", len(segments))
 
8
 
9
  from typing import List, Dict, Optional, Any
10
 
11
+ import torch, json, os
12
  from whisper import load_model # type: ignore
13
+ from openai import OpenAI
14
 
15
  __all__ = ["WhisperRunner"]
16
 
 
31
  if not audio_path.exists():
32
  raise FileNotFoundError(audio_path)
33
 
34
+ use_local = os.getenv("AUDIO2TEXT_LOCAL", "true").lower() != "false"
 
 
35
 
36
+ if use_local:
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ logger.info("[Whisper] loading model %s on %s", cls.model_name, device)
39
+ model = load_model(cls.model_name, device=device)
40
+
41
+ logger.info("[Whisper] transcribing %s (local)", audio_path.name)
42
+ result = model.transcribe(str(audio_path), language=lang)
43
+ else:
44
+ # remote API mode
45
+ api_base = os.getenv("AIHUB_API_BASE")
46
+ api_key = os.getenv("AIHUB_API_KEY")
47
+ if not api_key:
48
+ raise EnvironmentError("AIHUB_API_KEY not set")
49
+
50
+ client = OpenAI(api_key=api_key, base_url=api_base)
51
+ logger.info("[Whisper] uploading %s to API (whisper-large-v3)", audio_path.name)
52
+ with audio_path.open("rb") as f:
53
+ resp = client.audio.transcriptions.create(model="whisper-large-v3", file=f, language=lang)
54
+ # resp.text contains full text, but we need segments; assume API returns segments list if 'json' format
55
+ segments = resp.segments if hasattr(resp, "segments") else [{"start": 0.0, "end": 0.0, "text": resp.text}]
56
+ result = {"segments": segments}
57
 
58
  # convert to our schema
59
  logger.info("[Whisper] got %d segments", len(segments))