File size: 4,649 Bytes
eb1aec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import yaml

def load_config(config_dir="./configs"):
    """Load the first available YAML configuration file.

    Candidates are tried in priority order inside *config_dir*:
    ``local.yaml``, then ``modelscope.yaml``, then ``huggingface.yaml``.
    Only regular files count; a directory with one of these names is skipped.

    Args:
        config_dir: Directory to search. Defaults to ``"./configs"``, which is
            resolved relative to the current working directory.

    Returns:
        The result of ``yaml.safe_load`` on the first candidate found (None
        for an empty file), or None when no candidate exists.
    """
    candidates = ("local.yaml", "modelscope.yaml", "huggingface.yaml")
    for name in candidates:
        config_path = os.path.join(config_dir, name)
        # isfile (not exists): a directory with a candidate name would make open() fail.
        if os.path.isfile(config_path):
            print(f"Loading configuration from {config_path}")
            # safe_load never constructs arbitrary Python objects from the file.
            with open(config_path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)

    print(f"No configuration file found in {config_dir} "
          f"(tried {', '.join(candidates)}), using default configurations")
    return None
    
def _split_hub_path(path_str, prefix):
    """Split ``<prefix>owner/repo/path/to/file`` into ``(repo_id, filename)``.

    ``repo_id`` is ``"owner/repo"``. ``filename`` is the rest of the path
    inside the repository, and it may itself contain '/'. Returns None when the
    owner, repo, or file component is missing or empty.
    """
    parts = path_str[len(prefix):].split('/', 2)
    # Rejects e.g. "hf://owner", "hf://owner/repo", "hf://owner//f", "hf://owner/repo/".
    if len(parts) != 3 or not all(parts):
        return None
    owner, repo, filename = parts
    return f"{owner}/{repo}", filename


def _download_from_huggingface(repo_id, filename):
    """Download one file from a HuggingFace *dataset* repo.

    Returns the local path, or None on a missing library or download error.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        print("huggingface_hub not installed, cannot download from HuggingFace")
        return None
    print(f"Downloading from HuggingFace: {repo_id}/{filename}")
    try:
        return hf_hub_download(repo_id, filename, repo_type='dataset')
    except Exception as e:
        # Best-effort: callers treat None as "path unavailable".
        print(f"Error downloading from HuggingFace: {e}")
        return None


def _download_from_modelscope(repo_id, filename):
    """Download one file from a ModelScope *dataset* repo.

    Returns the local path, or None on a missing library, a download error, or
    when the file is absent from the downloaded snapshot.
    """
    try:
        from modelscope.hub.snapshot_download import snapshot_download
    except ImportError:
        print("modelscope not installed, cannot download from ModelScope")
        return None
    print(f"Downloading from ModelScope: {repo_id}/{filename}")
    try:
        # allow_file_pattern restricts the snapshot to this single file.
        cache_dir = snapshot_download(
            repo_id=repo_id,
            repo_type='dataset',
            allow_file_pattern=filename,
        )
        downloaded_file = os.path.join(cache_dir, filename)
        if os.path.exists(downloaded_file):
            return downloaded_file
        print(f"File not found after download: {downloaded_file}")
        return None
    except Exception as e:
        # Best-effort: callers treat None as "path unavailable".
        print(f"Error downloading from ModelScope: {e}")
        import traceback
        traceback.print_exc()
        return None


def resolve_path(path_str):
    """
    Resolve path string, supporting HuggingFace Hub and ModelScope dataset downloads.

    Supported remote formats (files are fetched from *dataset* repos):
      - hf://repo_owner/repo_name/path/to/file (HuggingFace)
      - ms://repo_owner/repo_name/path/to/file (ModelScope)

    The repo_id is ``repo_owner/repo_name``, so it contains a '/'.

    Returns:
        The local path of the downloaded file for hf:// and ms:// inputs, or None
        when the path is malformed, the hub library is missing, or the download
        fails. Errors are printed, not raised. Any other value, including None
        and non-strings, is returned unchanged.
    """
    if not isinstance(path_str, str):
        return path_str

    if path_str.startswith("hf://"):
        prefix, label, download = "hf://", "HuggingFace", _download_from_huggingface
    elif path_str.startswith("ms://"):
        prefix, label, download = "ms://", "ModelScope", _download_from_modelscope
    else:
        return path_str

    parsed = _split_hub_path(path_str, prefix)
    if parsed is None:
        print(f"Invalid {label} path format: {path_str}")
        print(f"Expected format: {prefix}owner/repo/path/to/file")
        return None
    repo_id, filename = parsed
    return download(repo_id, filename)

def process_config(config):
    """Return a copy of *config* with every ``*_path`` value resolved.

    *config* maps top-level section names (e.g. model names) to values. For
    each section that is a dict, the values of string keys ending in ``_path``
    are passed through ``resolve_path``. Other keys are copied as-is. Sections
    that are not dicts are copied through unchanged instead of raising. This
    covers scalars such as ``version: 2``, lists, and empty sections that YAML
    parses as None.

    Returns None when *config* is None. Otherwise it returns a new dict, and each
    section dict is also a new dict.
    """
    if config is None:
        return None

    def _resolve_section(section):
        # Only mapping sections carry "*_path" settings to resolve.
        if not isinstance(section, dict):
            return section
        return {
            key: (resolve_path(value)
                  if isinstance(key, str) and key.endswith('_path')
                  else value)
            for key, value in section.items()
        }

    return {name: _resolve_section(section) for name, section in config.items()}

def load_and_process_config():
    """Read the configuration file and resolve its ``*_path`` entries.

    Equivalent to ``process_config(load_config())``. The result is None whenever
    ``load_config`` yields None, for example when no config file is found.
    """
    return process_config(load_config())