Spaces:
Running
Running
Update preload_model.py
Browse files- preload_model.py +10 -15
preload_model.py
CHANGED
|
@@ -8,32 +8,27 @@ import torch
|
|
| 8 |
from kokoro import KModel, KPipeline
|
| 9 |
import logging
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
logging.basicConfig(level=logging.INFO)
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
def preload_model():
|
| 15 |
"""Pre-load the model to simple cache directory"""
|
| 16 |
-
|
| 17 |
-
os.makedirs(cache_dir, exist_ok=True)
|
| 18 |
-
|
| 19 |
-
# CRITICAL: Set cache environment variables FIRST
|
| 20 |
-
os.environ['HF_HOME'] = cache_dir
|
| 21 |
-
os.environ['HUGGINGFACE_HUB_CACHE'] = cache_dir
|
| 22 |
-
os.environ['HF_HUB_CACHE'] = cache_dir
|
| 23 |
-
|
| 24 |
-
logger.info(f"Pre-loading Kokoro-82M model to {cache_dir}...")
|
| 25 |
|
| 26 |
try:
|
| 27 |
-
# Initialize model
|
| 28 |
logger.info("Downloading model...")
|
| 29 |
-
model = KModel(
|
| 30 |
-
repo_id='hexgrad/Kokoro-82M',
|
| 31 |
-
cache_dir=cache_dir
|
| 32 |
-
).to('cpu').eval()
|
| 33 |
|
| 34 |
# Initialize pipeline
|
| 35 |
logger.info("Initializing pipeline...")
|
| 36 |
-
pipeline = KPipeline(lang_code='a', model=False
|
| 37 |
|
| 38 |
# Test with a small text
|
| 39 |
test_text = "Hello world"
|
|
|
|
| 8 |
from kokoro import KModel, KPipeline
|
| 9 |
import logging
|
| 10 |
|
| 11 |
+
# SET CACHE ENVIRONMENT VARIABLES AT MODULE LEVEL - BEFORE ANY IMPORTS
|
| 12 |
+
CACHE_DIR = "/app/cache"
|
| 13 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 14 |
+
os.environ['HF_HOME'] = CACHE_DIR
|
| 15 |
+
os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR
|
| 16 |
+
|
| 17 |
logging.basicConfig(level=logging.INFO)
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
def preload_model():
|
| 21 |
"""Pre-load the model to simple cache directory"""
|
| 22 |
+
logger.info(f"Pre-loading Kokoro-82M model to {CACHE_DIR}...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
try:
|
| 25 |
+
# Initialize model (no cache_dir parameter needed)
|
| 26 |
logger.info("Downloading model...")
|
| 27 |
+
model = KModel(repo_id='hexgrad/Kokoro-82M').to('cpu').eval()
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# Initialize pipeline
|
| 30 |
logger.info("Initializing pipeline...")
|
| 31 |
+
pipeline = KPipeline(lang_code='a', model=False)
|
| 32 |
|
| 33 |
# Test with a small text
|
| 34 |
test_text = "Hello world"
|