| """ |
| Data models for dataset builder. |
| """ |
|
|
| from dataclasses import dataclass, asdict |
| from datetime import datetime |
| from typing import Any, Dict, Optional |
| import uuid |
|
|
|
|
# Audio file extensions (each with its leading dot) listed as supported.
# NOTE(review): entries are lowercase; whether callers lowercase a suffix
# before testing membership is not visible from this module -- confirm.
SUPPORTED_AUDIO_FORMATS = {
    ".wav",
    ".mp3",
    ".flac",
    ".ogg",
    ".opus",
}
|
|
|
|
@dataclass
class AudioSample:
    """A single audio sample together with its labeling metadata.

    Every field has a default, so an instance can be built from a partial
    mapping (see ``from_dict``). When ``id`` is left empty,
    ``__post_init__`` replaces it with the first eight characters of a
    random UUID4 string.
    """

    id: str = ""
    audio_path: str = ""
    filename: str = ""
    caption: str = ""
    genre: str = ""
    lyrics: str = "[Instrumental]"
    # Lyrics as provided by the user (checked by ``has_raw_lyrics``).
    raw_lyrics: str = ""
    # Lyrics after LM formatting (checked by ``has_formatted_lyrics``).
    formatted_lyrics: str = ""
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: int = 0
    language: str = "unknown"
    is_instrumental: bool = True
    # Tag merged into the caption/genre by ``get_full_caption`` /
    # ``get_full_genre``; an empty tag leaves them untouched.
    custom_tag: str = ""
    labeled: bool = False
    # "genre" or "caption" forces that source in ``get_training_prompt``;
    # any other value defers to its ``use_genre`` argument.
    prompt_override: Optional[str] = None

    def __post_init__(self):
        """Assign a short random identifier when none was supplied."""
        if not self.id:
            # Keep only the first 8 characters of the UUID4 string form.
            self.id = str(uuid.uuid4())[:8]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AudioSample":
        """Create a sample from a dictionary.

        Keys that are not dataclass fields are dropped, and fields missing
        from ``data`` fall back to their defaults. This keeps datasets
        written without ``raw_lyrics``/``formatted_lyrics``/``genre``
        loadable.

        Args:
            data: Mapping of field names to values; extra keys are ignored.

        Returns:
            A new instance populated from the recognised keys.
        """
        # ``__dataclass_fields__`` is keyed by field name, so membership can
        # be tested against it directly.
        known_fields = cls.__dataclass_fields__
        accepted = {key: value for key, value in data.items() if key in known_fields}
        return cls(**accepted)

    def _apply_custom_tag(self, text: str, tag_position: str) -> str:
        """Merge ``custom_tag`` into ``text`` according to ``tag_position``.

        Shared by ``get_full_caption`` and ``get_full_genre`` so both follow
        the same placement rules.

        Args:
            text: The caption or genre string to decorate.
            tag_position: "prepend", "append" or "replace". Any other value
                returns ``text`` unchanged.

        Returns:
            ``text`` when no custom tag is set or the position is unknown;
            the tag alone for "replace" or when ``text`` is empty; otherwise
            tag and text joined with ", " in the requested order.
        """
        tag = self.custom_tag
        if not tag:
            return text
        if tag_position == "replace":
            return tag
        if tag_position not in ("prepend", "append"):
            return text
        if not text:
            # Nothing to join with: avoid a dangling ", " separator.
            return tag
        if tag_position == "prepend":
            return f"{tag}, {text}"
        return f"{text}, {tag}"

    def get_full_caption(self, tag_position: str = "prepend") -> str:
        """Get caption with custom tag applied.

        Args:
            tag_position: "prepend", "append" or "replace"; other values
                leave the caption unchanged.
        """
        return self._apply_custom_tag(self.caption, tag_position)

    def get_full_genre(self, tag_position: str = "prepend") -> str:
        """Get genre with custom tag applied.

        Args:
            tag_position: "prepend", "append" or "replace"; other values
                leave the genre unchanged.
        """
        return self._apply_custom_tag(self.genre, tag_position)

    def get_training_prompt(self, tag_position: str = "prepend", use_genre: bool = False) -> str:
        """Get the prompt to use for training.

        A per-sample ``prompt_override`` of "genre" or "caption" takes
        precedence over ``use_genre``; any other override value is ignored.

        Args:
            tag_position: Forwarded to the caption/genre tag placement.
            use_genre: Use the genre instead of the caption when the sample
                carries no recognised override.

        Returns:
            The tagged genre or the tagged caption.
        """
        if self.prompt_override == "genre":
            pick_genre = True
        elif self.prompt_override == "caption":
            pick_genre = False
        else:
            pick_genre = use_genre

        if pick_genre:
            return self.get_full_genre(tag_position)
        return self.get_full_caption(tag_position)

    def has_raw_lyrics(self) -> bool:
        """Check if sample has user-provided raw lyrics from .txt file.

        Empty and whitespace-only values count as absent.
        """
        return bool(self.raw_lyrics and self.raw_lyrics.strip())

    def has_formatted_lyrics(self) -> bool:
        """Check if sample has LM-formatted lyrics.

        Empty and whitespace-only values count as absent.
        """
        return bool(self.formatted_lyrics and self.formatted_lyrics.strip())
|
|
|
|
@dataclass
class DatasetMetadata:
    """Dataset-wide settings and summary values.

    ``created_at`` is stamped with ``datetime.now().isoformat()`` (a naive,
    local-time ISO 8601 string) whenever it is left empty.
    """

    name: str = "untitled_dataset"
    custom_tag: str = ""
    tag_position: str = "prepend"
    created_at: str = ""
    num_samples: int = 0
    all_instrumental: bool = True
    genre_ratio: int = 0

    def __post_init__(self):
        """Fill in the creation timestamp unless the caller provided one."""
        if self.created_at:
            return
        self.created_at = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dictionary."""
        as_mapping = asdict(self)
        return as_mapping
|
|