Update modeling_spark_tts.py
Browse files- modeling_spark_tts.py +25 -24
modeling_spark_tts.py
CHANGED
|
@@ -3092,31 +3092,32 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
|
|
| 3092 |
# Check for trust_remote_code - needed for config loading if custom code involved there too
|
| 3093 |
trust_remote_code = model_kwargs.pop("trust_remote_code", False) # Important
|
| 3094 |
|
| 3095 |
-
|
| 3096 |
-
|
| 3097 |
-
|
| 3098 |
-
|
| 3099 |
-
|
| 3100 |
-
|
| 3101 |
-
|
| 3102 |
-
resolved_model_path = Path(cached_file(
|
| 3103 |
-
pretrained_model_name_or_path,
|
| 3104 |
-
filename=cls.config_class.config_files[0], # e.g., "config.json"
|
| 3105 |
-
cache_dir=cache_dir,
|
| 3106 |
-
force_download=force_download,
|
| 3107 |
-
local_files_only=local_files_only,
|
| 3108 |
-
token=token,
|
| 3109 |
-
revision=revision,
|
| 3110 |
-
_raise_exceptions_for_missing_entries=False,
|
| 3111 |
-
_raise_exceptions_for_connection_errors=False,
|
| 3112 |
-
)).parent
|
| 3113 |
-
except Exception as e:
|
| 3114 |
-
logger.warning(f"Could not resolve cache path for {pretrained_model_name_or_path}: {e}. Assuming it's a local path.")
|
| 3115 |
-
resolved_model_path = Path(pretrained_model_name_or_path) # Fallback
|
| 3116 |
-
if not resolved_model_path.is_dir():
|
| 3117 |
-
raise EnvironmentError(f"Cannot find model directory at {resolved_model_path}")
|
| 3118 |
else:
|
| 3119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3120 |
|
| 3121 |
|
| 3122 |
# Helper function to resolve paths relative to the main model directory
|
|
|
|
| 3092 |
# Check for trust_remote_code - needed for config loading if custom code involved there too
|
| 3093 |
trust_remote_code = model_kwargs.pop("trust_remote_code", False) # Important
|
| 3094 |
|
| 3095 |
+
|
| 3096 |
+
# NEW IMPROVED PATH RESOLUTION
|
| 3097 |
+
from huggingface_hub import snapshot_download
|
| 3098 |
+
import os
|
| 3099 |
+
# Check if it's a local path first
|
| 3100 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
| 3101 |
+
resolved_model_path = Path(pretrained_model_name_or_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3102 |
else:
|
| 3103 |
+
# Try to get from Hugging Face Hub
|
| 3104 |
+
try:
|
| 3105 |
+
logger.info(f"Downloading/locating model from Hugging Face Hub: {pretrained_model_name_or_path}")
|
| 3106 |
+
# This will download the model if needed and return the cached path
|
| 3107 |
+
resolved_model_path = Path(snapshot_download(
|
| 3108 |
+
pretrained_model_name_or_path,
|
| 3109 |
+
revision=revision,
|
| 3110 |
+
cache_dir=cache_dir,
|
| 3111 |
+
force_download=force_download,
|
| 3112 |
+
local_files_only=local_files_only,
|
| 3113 |
+
token=token,
|
| 3114 |
+
))
|
| 3115 |
+
except Exception as e:
|
| 3116 |
+
logger.error(f"Error downloading model: {e}")
|
| 3117 |
+
raise EnvironmentError(f"Failed to find or download model '{pretrained_model_name_or_path}': {e}")
|
| 3118 |
+
|
| 3119 |
+
if not resolved_model_path.is_dir():
|
| 3120 |
+
raise EnvironmentError(f"Cannot find model directory at {resolved_model_path}")
|
| 3121 |
|
| 3122 |
|
| 3123 |
# Helper function to resolve paths relative to the main model directory
|