File size: 4,661 Bytes
63f0b06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import json
import logging
from pathlib import Path
from typing import Dict, Any, Optional

logger = logging.getLogger(__name__)


def fast_scandir(dir_path, ext_list):
    """Recursively collect sub-directories and matching files under *dir_path*.

    Args:
        dir_path: Directory to scan (``str`` or path-like).
        ext_list: Iterable of extensions to keep, with or without the leading
            dot (``"wav"`` and ``".wav"`` are equivalent). Matching is
            case-insensitive; empty entries are ignored.

    Returns:
        ``(subfolders, files)``: two lists of path strings. ``subfolders``
        holds every directory found below *dir_path* (not *dir_path* itself).
        ``files`` holds every file whose name does not start with ``"."`` and
        whose extension is in *ext_list*. The order follows ``os.scandir`` and
        is not sorted.

    Directories that cannot be read (missing, permission denied, ...) and
    entries whose type cannot be determined are skipped instead of aborting
    the walk. Symlinked directories are followed, because ``DirEntry.is_dir()``
    follows symlinks by default.
    """
    import os

    # Normalise once to a leading dot and lower case. This lets the entries
    # match the lower-cased os.path.splitext() result below.
    wanted = {
        (ext if ext.startswith('.') else '.' + ext).lower()
        for ext in ext_list
        if ext
    }

    subfolders, files = [], []
    try:
        # The context manager closes the scandir iterator even on early exit.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                try:
                    if entry.is_dir():
                        subfolders.append(entry.path)
                    elif entry.is_file():
                        file_ext = os.path.splitext(entry.name)[1].lower()
                        is_hidden = entry.name.startswith('.')
                        if file_ext in wanted and not is_hidden:
                            files.append(entry.path)
                except OSError:
                    # The entry vanished or is unreadable: skip it and keep scanning.
                    continue
    except OSError:
        # Unreadable directory: keep whatever was collected before the error.
        pass

    # Iterate over a snapshot, because subfolders grows while we recurse.
    for sub_dir in list(subfolders):
        sub_dirs, sub_files = fast_scandir(sub_dir, wanted)
        subfolders.extend(sub_dirs)
        files.extend(sub_files)
    return subfolders, files


class SimpleAudioProcessor:
    """Gather audio files into one directory and build an ``audio_dir`` dataset config.

    Model parameters (sample size, sample rate, channel count) are read from an
    optional JSON model config only so they can be reported back in the
    result of ``create_dataset_config``. They do not influence how files are
    found or copied.
    """

    # Fallback model parameters. They are used when no model config is
    # available or a key is missing from it.
    DEFAULT_SAMPLE_SIZE = 2097152
    DEFAULT_SAMPLE_RATE = 44100
    DEFAULT_AUDIO_CHANNELS = 2

    def __init__(self, model_config_path: Optional[Path] = None):
        """Set up the processor, optionally reading model info from a JSON file.

        Args:
            model_config_path: Optional path (``Path`` or ``str``) to a JSON
                model config. If it points to an existing file,
                ``sample_size``, ``sample_rate`` and ``audio_channels`` are read
                from it, and any missing key falls back to the class defaults.
                Otherwise the defaults are used and a warning is logged.

        Raises:
            OSError: If the config file exists but cannot be read.
            json.JSONDecodeError: If the config file is not valid JSON.
        """
        self.audio_extensions = (".wav", ".mp3", ".flac", ".m4a")

        # The model config is informational only (see class docstring).
        model_config: Dict[str, Any] = {}
        if model_config_path:
            model_config_path = Path(model_config_path)
            if model_config_path.is_file():
                with open(model_config_path, 'r', encoding='utf-8') as f:
                    model_config = json.load(f)
            else:
                logger.warning(f"Model config not found at {model_config_path}; using defaults")

        self.sample_size = model_config.get("sample_size", self.DEFAULT_SAMPLE_SIZE)
        self.sample_rate = model_config.get("sample_rate", self.DEFAULT_SAMPLE_RATE)
        self.audio_channels = model_config.get("audio_channels", self.DEFAULT_AUDIO_CHANNELS)

    def load_prompts(self, prompts_file: Path) -> Dict[str, str]:
        """Parse a ``filename|prompt`` text file into a ``{filename: prompt}`` dict.

        Blank lines, lines starting with ``#``, lines without a ``|`` and lines
        with an empty filename are ignored. Only the first ``|`` splits a line,
        so prompts may themselves contain ``|``. Surrounding whitespace is
        stripped from both parts. If a filename appears twice, the later line
        wins. A leading UTF-8 BOM is tolerated.

        Returns an empty or partial dict, after logging an error, if the file
        cannot be opened or decoded.
        """
        prompts: Dict[str, str] = {}
        try:
            # utf-8-sig also reads plain UTF-8. It only strips a BOM, which
            # would otherwise be glued onto the first filename.
            with open(prompts_file, 'r', encoding='utf-8-sig') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line or line.startswith('#') or '|' not in line:
                        continue
                    filename, prompt = line.split('|', 1)
                    filename = filename.strip()
                    if filename:
                        prompts[filename] = prompt.strip()
        except (OSError, UnicodeError) as e:
            logger.error(f"Error loading prompts file: {e}")
        return prompts

    def create_dataset_config(
        self,
        input_dir: Path,
        output_dir: Path,
        prompts_file: Optional[Path] = None
    ) -> Dict[str, Any]:
        """Collect audio files from *input_dir* into *output_dir* and describe them.

        Files with an extension in ``self.audio_extensions`` are searched for
        recursively under *input_dir*, using ``fast_scandir``. The match is
        case-insensitive and hidden files are skipped. If the two directories
        resolve to different locations, each file is copied flat (by file name
        only) into *output_dir*. A destination file that already exists with
        the same size is left alone. When two source files share a name, only
        the first one found is copied and a warning is logged. If a prompts
        file is given and contains prompts, they are written to
        ``output_dir/prompts_metadata.json``.

        Args:
            input_dir: Directory to search (``Path`` or ``str``).
            output_dir: Destination directory (``Path`` or ``str``); created
                if missing.
            prompts_file: Optional ``filename|prompt`` file (see
                ``load_prompts``).

        Returns:
            A dict with ``dataset_config`` and ``file_count`` (the number of
            audio files found), plus ``sample_size``, ``sample_rate`` and
            ``audio_channels`` from the model config.

        Raises:
            ValueError: If no audio files are found under *input_dir*.
        """
        input_dir = Path(input_dir)
        output_dir = Path(output_dir)

        # A single recursive scan covers every extension at once.
        _, audio_files = fast_scandir(str(input_dir), list(self.audio_extensions))
        if not audio_files:
            raise ValueError(f"No audio files found in {input_dir}")
        logger.info(f"Found {len(audio_files)} audio files")

        output_dir.mkdir(exist_ok=True, parents=True)

        # Compare resolved paths, so that "./data" and "/abs/data" count as the same directory.
        if input_dir.resolve() != output_dir.resolve():
            self._copy_flat(audio_files, output_dir)
        else:
            logger.info("Input and output directories are the same - no copying needed")

        dataset_config = {
            "dataset_type": "audio_dir",
            "datasets": [
                {
                    "id": "custom_dataset",
                    "path": str(output_dir),
                    "custom_metadata_module": "custom_metadata"
                }
            ],
            "random_crop": True,  # CRITICAL - enables random cropping during training
            "drop_last": True
        }

        if prompts_file:
            prompts_file = Path(prompts_file)
            if prompts_file.exists():
                self._write_prompts_metadata(prompts_file, output_dir)
            else:
                logger.warning(f"Prompts file not found: {prompts_file}")

        return {
            "dataset_config": dataset_config,
            "file_count": len(audio_files),
            "sample_size": self.sample_size,
            "sample_rate": self.sample_rate,
            "audio_channels": self.audio_channels
        }

    @staticmethod
    def _copy_flat(audio_files: list, output_dir: Path) -> None:
        """Copy *audio_files* into *output_dir* by bare file name, skipping name collisions."""
        import shutil

        first_source: Dict[str, Path] = {}
        for audio_file in audio_files:
            src_path = Path(audio_file)
            dst_path = output_dir / src_path.name

            earlier = first_source.setdefault(src_path.name, src_path)
            if earlier != src_path:
                # Copying flat would overwrite the earlier file with the same name.
                logger.warning(
                    f"Skipping {src_path}: name collides with {earlier} in {output_dir}"
                )
                continue

            if dst_path.exists() and dst_path.resolve() == src_path.resolve():
                # The file already lives at its destination (output_dir inside input_dir).
                continue

            if not dst_path.exists() or dst_path.stat().st_size != src_path.stat().st_size:
                shutil.copy2(src_path, dst_path)
                logger.info(f"Copied {src_path.name}")

    def _write_prompts_metadata(self, prompts_file: Path, output_dir: Path) -> None:
        """Write ``prompts_metadata.json`` into *output_dir* when *prompts_file* yields prompts."""
        prompts = self.load_prompts(prompts_file)
        if not prompts:
            return
        metadata_file = output_dir / "prompts_metadata.json"
        records = [{"file_name": name, "prompt": text} for name, text in prompts.items()]
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(records, f, indent=2)
        logger.info(f"Saved prompts metadata to {metadata_file}")