rjzevallos commited on
Commit
88e4729
·
1 Parent(s): bcc3253

Feat(server): support WHISPER_MODEL_NAME/WHISPER_MODEL_SIZE (e.g. tiny); prefer local <name>.pt; improve error guidance

Browse files
Files changed (1) hide show
  1. server_wrapper.py +44 -24
server_wrapper.py CHANGED
@@ -15,29 +15,42 @@ _online = None
15
 
16
 
17
  def _get_model_path():
18
- """Get the path to the Whisper model. Download if needed."""
 
 
 
 
 
 
 
19
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  model_dir = os.path.expanduser('~/.cache/whisper')
21
-
22
- # Create model directory if it doesn't exist
23
- if not os.path.exists(model_dir):
24
- os.makedirs(model_dir, exist_ok=True)
25
-
26
- model_path = os.path.join(model_dir, 'large-v3.pt')
27
-
28
- if not os.path.exists(model_path):
29
- print(f"Model not found at {model_path}. Downloading...")
30
- try:
31
- import whisper
32
- whisper.load_model('large-v3')
33
- print(f"Model downloaded to {model_path}")
34
- except Exception as e:
35
- print(f"Warning: Could not download model: {e}")
36
- # Fallback to trying current directory
37
- if os.path.exists('./large-v3.pt'):
38
- return './large-v3.pt'
39
-
40
- return model_path
41
 
42
  def _make_args():
43
  # Minimal args required by simul_asr_factory
@@ -68,9 +81,16 @@ def init_model():
68
  with _lock:
69
  if _initialized:
70
  return
71
- args = _make_args()
72
- _asr, _online = simul_asr_factory(args)
73
- _initialized = True
 
 
 
 
 
 
 
74
 
75
 
76
  def reset():
 
15
 
16
 
17
  def _get_model_path():
18
+ """Get the path to the Whisper model.
19
+
20
+ Behavior:
21
+ - Prefer `WHISPER_MODEL_PATH` env var if provided.
22
+ - Otherwise prefer `./large-v3.pt` (repo-local file) or cached `~/.cache/whisper/large-v3.pt`.
23
+ - Do NOT attempt to download the model automatically (downloading at runtime can hang Spaces).
24
+ - If not found, raise FileNotFoundError with guidance.
25
+ """
26
  import os
27
+
28
+ # allow user to override with env var path
29
+ env_path = os.environ.get('WHISPER_MODEL_PATH') or os.environ.get('MODEL_PATH')
30
+ if env_path:
31
+ if os.path.exists(env_path):
32
+ return env_path
33
+ else:
34
+ raise FileNotFoundError(f"WHISPER_MODEL_PATH is set but file not found: {env_path}")
35
+
36
+ # allow user to request a model name/size (e.g. 'tiny', 'base', 'large-v3')
37
+ model_name = os.environ.get('WHISPER_MODEL_NAME') or os.environ.get('WHISPER_MODEL_SIZE') or 'large-v3'
38
+
39
+ # check local repo file first (e.g. ./tiny.pt or ./large-v3.pt)
40
+ local_path = f'./{model_name}.pt'
41
+ if os.path.exists(local_path):
42
+ return local_path
43
+
44
+ # check cache path (pre-downloaded by build or other process)
45
  model_dir = os.path.expanduser('~/.cache/whisper')
46
+ model_path = os.path.join(model_dir, f'{model_name}.pt')
47
+ if os.path.exists(model_path):
48
+ return model_path
49
+
50
+ # Do not attempt to download automatically in runtime.
51
+ raise FileNotFoundError(
52
+ 'Whisper model not found. Set WHISPER_MODEL_PATH to a local model file, or set WHISPER_MODEL_NAME to a model name (e.g. tiny) and pre-download the corresponding "<name>.pt" file into the repo or ~/.cache/whisper/.'
53
+ )
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def _make_args():
56
  # Minimal args required by simul_asr_factory
 
81
  with _lock:
82
  if _initialized:
83
  return
84
+ try:
85
+ args = _make_args()
86
+ _asr, _online = simul_asr_factory(args)
87
+ _initialized = True
88
+ except FileNotFoundError as e:
89
+ print(f"Model initialization aborted: {e}")
90
+ # leave _initialized False so callers know model not ready
91
+ except Exception as e:
92
+ print(f"Unexpected error initializing model: {e}")
93
+ # don't raise here; allow the app to continue running without model
94
 
95
 
96
  def reset():