NitinBot001 commited on
Commit
4c8dce9
Β·
verified Β·
1 Parent(s): 3d1ea91

Create download_models.py

Browse files
Files changed (1) hide show
  1. download_models.py +57 -0
download_models.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model download script for Bark TTS and Whisper ASR
3
+ This script downloads and caches the models during Docker build
4
+ """
5
+ import torch
6
+ from transformers import AutoProcessor, BarkModel
7
+ import whisper
8
+ import os
9
+
10
+ def download_bark_model():
11
+ """Download Bark TTS model"""
12
+ print('Downloading Bark model...')
13
+ try:
14
+ processor = AutoProcessor.from_pretrained('suno/bark-small')
15
+ model = BarkModel.from_pretrained('suno/bark-small')
16
+ print('βœ… Bark model downloaded successfully')
17
+ return True
18
+ except Exception as e:
19
+ print(f'❌ Error downloading Bark model: {e}')
20
+ return False
21
+
22
+ def download_whisper_model():
23
+ """Download Whisper ASR model"""
24
+ print('Downloading Whisper model...')
25
+ try:
26
+ whisper_model = whisper.load_model('base')
27
+ print('βœ… Whisper model downloaded successfully')
28
+ return True
29
+ except Exception as e:
30
+ print(f'❌ Error downloading Whisper model: {e}')
31
+ return False
32
+
33
+ def main():
34
+ """Main function to download all models"""
35
+ print('πŸš€ Starting model download process...')
36
+ print('-' * 50)
37
+
38
+ bark_success = download_bark_model()
39
+ print()
40
+
41
+ whisper_success = download_whisper_model()
42
+ print()
43
+
44
+ if bark_success and whisper_success:
45
+ print('πŸŽ‰ All models downloaded successfully!')
46
+ else:
47
+ print('⚠️ Some models failed to download. Check logs above.')
48
+ if not bark_success:
49
+ print(' - Bark model download failed')
50
+ if not whisper_success:
51
+ print(' - Whisper model download failed')
52
+
53
+ print('-' * 50)
54
+ print('Model download process complete!')
55
+
56
+ if __name__ == "__main__":
57
+ main()