Rajhuggingface4253 commited on
Commit
a3982b2
·
verified ·
1 Parent(s): 4203773

Create preload_neutts.py

Browse files
Files changed (1) hide show
  1. preload_neutts.py +36 -0
preload_neutts.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+ # Set cache environment variables before any other imports
5
+ CACHE_DIR = "/app/cache"
6
+ os.environ['HF_HOME'] = CACHE_DIR
7
+ os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR
8
+
9
+ # Now import the model class
10
+ from neuttsair.neutts import NeuTTSAir
11
+
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ def preload_model():
16
+ """
17
+ Downloads and caches the NeuTTSAir model and its dependencies
18
+ to the directory specified by HF_HOME.
19
+ """
20
+ logger.info(f"Pre-loading NeuTTS Air model to cache: {CACHE_DIR}...")
21
+ try:
22
+ # Instantiating the class triggers the download from Hugging Face Hub
23
+ NeuTTSAir(
24
+ backbone_repo="neuphonic/neutts-air",
25
+ backbone_device="cpu",
26
+ codec_repo="neuphonic/neucodec",
27
+ codec_device="cpu"
28
+ )
29
+ logger.info("✅ NeuTTS Air model pre-loading completed successfully!")
30
+ except Exception as e:
31
+ logger.error(f"❌ Error during model pre-loading: {e}")
32
+ # We raise the exception to fail the build if the model can't be downloaded
33
+ raise e
34
+
35
+ if __name__ == "__main__":
36
+ preload_model()