File size: 4,006 Bytes
a602628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Data models for dataset builder.
"""

from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Any, Dict, Optional
import uuid


SUPPORTED_AUDIO_FORMATS = {".wav", ".mp3", ".flac", ".ogg", ".opus"}


@dataclass
class AudioSample:
    """Represents a single audio sample with its metadata."""

    id: str = ""
    audio_path: str = ""
    filename: str = ""
    caption: str = ""
    genre: str = ""  # Genre tags from LLM
    lyrics: str = "[Instrumental]"
    raw_lyrics: str = ""  # Original user-provided lyrics from .txt file
    formatted_lyrics: str = ""  # LM-formatted lyrics
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: int = 0
    language: str = "unknown"
    is_instrumental: bool = True
    custom_tag: str = ""
    labeled: bool = False
    prompt_override: Optional[str] = None  # None=use global ratio, "caption" or "genre"

    def __post_init__(self):
        if not self.id:
            self.id = str(uuid.uuid4())[:8]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AudioSample":
        """Create from dictionary.

        Handles backward compatibility for datasets without raw_lyrics/formatted_lyrics/genre.
        """
        valid_keys = {f.name for f in cls.__dataclass_fields__.values()}
        filtered_data = {k: v for k, v in data.items() if k in valid_keys}
        return cls(**filtered_data)

    def get_full_caption(self, tag_position: str = "prepend") -> str:
        """Get caption with custom tag applied."""
        if not self.custom_tag:
            return self.caption

        if tag_position == "prepend":
            return f"{self.custom_tag}, {self.caption}" if self.caption else self.custom_tag
        if tag_position == "append":
            return f"{self.caption}, {self.custom_tag}" if self.caption else self.custom_tag
        if tag_position == "replace":
            return self.custom_tag
        return self.caption

    def get_full_genre(self, tag_position: str = "prepend") -> str:
        """Get genre with custom tag applied."""
        if not self.custom_tag:
            return self.genre

        if tag_position == "prepend":
            return f"{self.custom_tag}, {self.genre}" if self.genre else self.custom_tag
        if tag_position == "append":
            return f"{self.genre}, {self.custom_tag}" if self.genre else self.custom_tag
        if tag_position == "replace":
            return self.custom_tag
        return self.genre

    def get_training_prompt(self, tag_position: str = "prepend", use_genre: bool = False) -> str:
        """Get the prompt to use for training."""
        if self.prompt_override == "genre":
            return self.get_full_genre(tag_position)
        if self.prompt_override == "caption":
            return self.get_full_caption(tag_position)
        if use_genre:
            return self.get_full_genre(tag_position)
        return self.get_full_caption(tag_position)

    def has_raw_lyrics(self) -> bool:
        """Check if sample has user-provided raw lyrics from .txt file."""
        return bool(self.raw_lyrics and self.raw_lyrics.strip())

    def has_formatted_lyrics(self) -> bool:
        """Check if sample has LM-formatted lyrics."""
        return bool(self.formatted_lyrics and self.formatted_lyrics.strip())


@dataclass
class DatasetMetadata:
    """Metadata for the entire dataset."""

    name: str = "untitled_dataset"
    custom_tag: str = ""
    tag_position: str = "prepend"
    created_at: str = ""
    num_samples: int = 0
    all_instrumental: bool = True
    genre_ratio: int = 0  # 0-100, percentage of samples using genre

    def __post_init__(self):
        if not self.created_at:
            self.created_at = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return asdict(self)