Spaces:
Running
on
Zero
Running
on
Zero
| import json | |
| import os | |
| from datetime import datetime | |
| from typing import List, Tuple | |
| from loguru import logger | |
| from .models import AudioSample, DatasetMetadata | |
| class SerializationMixin: | |
| """Save/load dataset JSON.""" | |
| def save_dataset(self, output_path: str, dataset_name: str = None) -> str: | |
| """Save the dataset to a JSON file.""" | |
| if not self.samples: | |
| return "β No samples to save" | |
| if dataset_name: | |
| self.metadata.name = dataset_name | |
| self.metadata.num_samples = len(self.samples) | |
| self.metadata.created_at = datetime.now().isoformat() | |
| dataset = { | |
| "metadata": self.metadata.to_dict(), | |
| "samples": [sample.to_dict() for sample in self.samples], | |
| } | |
| try: | |
| os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True) | |
| with open(output_path, "w", encoding="utf-8") as f: | |
| json.dump(dataset, f, indent=2, ensure_ascii=False) | |
| return f"β Dataset saved to {output_path}\n{len(self.samples)} samples, tag: '{self.metadata.custom_tag}'" | |
| except Exception as e: | |
| logger.exception("Error saving dataset") | |
| return f"β Failed to save dataset: {str(e)}" | |
| def load_dataset(self, dataset_path: str) -> Tuple[List[AudioSample], str]: | |
| """Load a dataset from a JSON file.""" | |
| if not os.path.exists(dataset_path): | |
| return [], f"β Dataset not found: {dataset_path}" | |
| try: | |
| with open(dataset_path, "r", encoding="utf-8") as f: | |
| data = json.load(f) | |
| if "metadata" in data: | |
| meta_dict = data["metadata"] | |
| self.metadata = DatasetMetadata( | |
| name=meta_dict.get("name", "untitled"), | |
| custom_tag=meta_dict.get("custom_tag", ""), | |
| tag_position=meta_dict.get("tag_position", "prepend"), | |
| created_at=meta_dict.get("created_at", ""), | |
| num_samples=meta_dict.get("num_samples", 0), | |
| all_instrumental=meta_dict.get("all_instrumental", True), | |
| genre_ratio=meta_dict.get("genre_ratio", 0), | |
| ) | |
| self.samples = [] | |
| for sample_dict in data.get("samples", []): | |
| sample = AudioSample.from_dict(sample_dict) | |
| self.samples.append(sample) | |
| return self.samples, f"β Loaded {len(self.samples)} samples from {dataset_path}" | |
| except Exception as e: | |
| logger.exception("Error loading dataset") | |
| return [], f"β Failed to load dataset: {str(e)}" | |