|
|
import os |
|
|
import yaml |
|
|
|
|
|
def load_config():
    """Load the first available YAML configuration file.

    Candidate files are checked in priority order, relative to the current
    working directory: ``./configs/local.yaml``, then
    ``./configs/modelscope.yaml``, then ``./configs/huggingface.yaml``.

    Returns:
        The parsed YAML document as produced by ``yaml.safe_load``, or
        ``None`` when none of the candidate files exists.
    """
    candidates = (
        "./configs/local.yaml",
        "./configs/modelscope.yaml",
        "./configs/huggingface.yaml",
    )
    # First existing candidate wins; None means nothing was found.
    config_path = next((p for p in candidates if os.path.exists(p)), None)
    if config_path is None:
        print("No local.yaml found, using default configurations")
        return None

    print(f"Loading configuration from {config_path}")
    # YAML is UTF-8 by default; an explicit encoding keeps parsing independent
    # of the platform's locale encoding (e.g. cp1252/GBK on Windows).
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
|
|
|
|
|
def _split_hub_path(path_str, prefix):
    """Split ``<prefix>owner/repo/path/to/file`` into ``(repo_id, filename)``.

    Returns ``None`` when the owner, repo, or file component is missing or
    empty, so callers can reject the path before attempting any download.
    """
    parts = path_str[len(prefix):].split('/', 2)
    if len(parts) < 3 or not all(parts):
        return None
    owner, repo, filename = parts
    return f"{owner}/{repo}", filename


def resolve_path(path_str):
    """Resolve a path string, downloading from a hub when it uses a hub scheme.

    Supported dataset formats:
        - ``hf://repo_owner/repo_name/path/to/file`` (HuggingFace)
        - ``ms://repo_owner/repo_name/path/to/file`` (ModelScope)

    Note that the repo id is ``repo_owner/repo_name`` and therefore contains
    a ``/``; everything after the second ``/`` is the file path inside the
    repo. Both hubs are queried with ``repo_type='dataset'``.

    Args:
        path_str: The path to resolve. ``None`` and non-string values are
            returned unchanged, as are strings without a hub prefix.

    Returns:
        The local path of the downloaded file for ``hf://`` / ``ms://``
        inputs, ``None`` if the hub path is malformed, the required client
        library is missing, or the download fails; otherwise ``path_str``
        itself.
    """
    # None and non-string values pass straight through.
    if not isinstance(path_str, str):
        return path_str

    if path_str.startswith("hf://"):
        parsed = _split_hub_path(path_str, "hf://")
        if parsed is None:
            print(f"Invalid HuggingFace path format: {path_str}")
            print("Expected format: hf://owner/repo/path/to/file")
            return None
        repo_id, filename = parsed

        # Imported lazily so the dependency is only needed for hf:// paths.
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            print("huggingface_hub not installed, cannot download from HuggingFace")
            return None

        try:
            print(f"Downloading from HuggingFace: {repo_id}/{filename}")
            return hf_hub_download(repo_id, filename, repo_type='dataset')
        except Exception as e:
            # Best-effort: report the failure and let the caller see None.
            print(f"Error downloading from HuggingFace: {e}")
            return None

    if path_str.startswith("ms://"):
        parsed = _split_hub_path(path_str, "ms://")
        if parsed is None:
            print(f"Invalid ModelScope path format: {path_str}")
            print("Expected format: ms://owner/repo/path/to/file")
            return None
        repo_id, filename = parsed

        # Imported lazily so the dependency is only needed for ms:// paths.
        try:
            from modelscope.hub.snapshot_download import snapshot_download
        except ImportError:
            print("modelscope not installed, cannot download from ModelScope")
            return None

        try:
            print(f"Downloading from ModelScope: {repo_id}/{filename}")
            # snapshot_download returns a directory; restrict it to the one
            # requested file and then locate that file inside the snapshot.
            cache_dir = snapshot_download(
                repo_id=repo_id,
                repo_type='dataset',
                allow_file_pattern=filename,
            )
            downloaded_file = os.path.join(cache_dir, filename)
            if os.path.exists(downloaded_file):
                return downloaded_file
            print(f"File not found after download: {downloaded_file}")
            return None
        except Exception as e:
            print(f"Error downloading from ModelScope: {e}")
            import traceback
            traceback.print_exc()
            return None

    # Plain local path: nothing to resolve.
    return path_str
|
|
|
|
|
def process_config(config):
    """Return a copy of ``config`` with every ``*_path`` entry resolved.

    ``config`` is expected to map section names (model names) to mappings of
    settings. Within each section, any value whose key is a string ending in
    ``_path`` is passed through ``resolve_path``; all other values are copied
    unchanged.

    Args:
        config: The loaded configuration mapping, or ``None``.

    Returns:
        A new dict with the same structure and resolved paths, or ``None``
        when ``config`` is ``None``.
    """
    if config is None:
        return None

    processed = {}
    for model_name, model_config in config.items():
        # Top-level scalars or empty sections (parsed as None) have no keys
        # to inspect; keep them as-is instead of failing on ``.items()``.
        if not isinstance(model_config, dict):
            processed[model_name] = model_config
            continue

        section = {}
        for key, value in model_config.items():
            # YAML keys are not necessarily strings (e.g. integers), so check
            # the type before testing the suffix.
            if isinstance(key, str) and key.endswith('_path'):
                section[key] = resolve_path(value)
            else:
                section[key] = value
        processed[model_name] = section

    return processed
|
|
|
|
|
def load_and_process_config():
    """Load the configuration file and resolve its paths in a single call.

    Returns whatever ``process_config`` returns for the result of
    ``load_config`` (``None`` when no configuration was loaded).
    """
    return process_config(load_config())
|
|
|