ancv commited on
Commit
c394c41
·
verified ·
1 Parent(s): 16a928a

Update modeling_spark_tts.py

Browse files
Files changed (1) hide show
  1. modeling_spark_tts.py +25 -24
modeling_spark_tts.py CHANGED
@@ -3092,31 +3092,32 @@ class SparkTTSModel(PreTrainedModel, GenerationMixin):
3092
  # Check for trust_remote_code - needed for config loading if custom code involved there too
3093
  trust_remote_code = model_kwargs.pop("trust_remote_code", False) # Important
3094
 
3095
- # Determine actual model directory (could be cache path)
3096
- if pretrained_model_name_or_path is not None:
3097
- resolved_model_path = Path(pretrained_model_name_or_path)
3098
- if not resolved_model_path.is_dir():
3099
- # Attempt to download and resolve cache path if it's an ID
3100
- # This requires internet connection if not cached
3101
- try:
3102
- resolved_model_path = Path(cached_file(
3103
- pretrained_model_name_or_path,
3104
- filename=cls.config_class.config_files[0], # e.g., "config.json"
3105
- cache_dir=cache_dir,
3106
- force_download=force_download,
3107
- local_files_only=local_files_only,
3108
- token=token,
3109
- revision=revision,
3110
- _raise_exceptions_for_missing_entries=False,
3111
- _raise_exceptions_for_connection_errors=False,
3112
- )).parent
3113
- except Exception as e:
3114
- logger.warning(f"Could not resolve cache path for {pretrained_model_name_or_path}: {e}. Assuming it's a local path.")
3115
- resolved_model_path = Path(pretrained_model_name_or_path) # Fallback
3116
- if not resolved_model_path.is_dir():
3117
- raise EnvironmentError(f"Cannot find model directory at {resolved_model_path}")
3118
  else:
3119
- raise ValueError("pretrained_model_name_or_path must be provided.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3120
 
3121
 
3122
  # Helper function to resolve paths relative to the main model directory
 
3092
  # Check for trust_remote_code - needed for config loading if custom code involved there too
3093
  trust_remote_code = model_kwargs.pop("trust_remote_code", False) # Important
3094
 
3095
+
3096
+ # NEW IMPROVED PATH RESOLUTION
3097
+ from huggingface_hub import snapshot_download
3098
+ import os
3099
+ # Check if it's a local path first
3100
+ if os.path.isdir(pretrained_model_name_or_path):
3101
+ resolved_model_path = Path(pretrained_model_name_or_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3102
  else:
3103
+ # Try to get from Hugging Face Hub
3104
+ try:
3105
+ logger.info(f"Downloading/locating model from Hugging Face Hub: {pretrained_model_name_or_path}")
3106
+ # This will download the model if needed and return the cached path
3107
+ resolved_model_path = Path(snapshot_download(
3108
+ pretrained_model_name_or_path,
3109
+ revision=revision,
3110
+ cache_dir=cache_dir,
3111
+ force_download=force_download,
3112
+ local_files_only=local_files_only,
3113
+ token=token,
3114
+ ))
3115
+ except Exception as e:
3116
+ logger.error(f"Error downloading model: {e}")
3117
+ raise EnvironmentError(f"Failed to find or download model '{pretrained_model_name_or_path}': {e}")
3118
+
3119
+ if not resolved_model_path.is_dir():
3120
+ raise EnvironmentError(f"Cannot find model directory at {resolved_model_path}")
3121
 
3122
 
3123
  # Helper function to resolve paths relative to the main model directory