File size: 3,109 Bytes
9e1e4ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef76c0c
9e1e4ee
a60f44b
ef76c0c
9e1e4ee
 
ef76c0c
 
 
9e1e4ee
 
4d18ab0
ef76c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
9e1e4ee
ef76c0c
9e1e4ee
ef76c0c
9e1e4ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d18ab0
 
a60f44b
9e1e4ee
 
 
 
 
 
 
 
 
ef76c0c
9e1e4ee
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# backend/data_manager.py

import os
import base64
import io
from typing import List, Optional, Any, Dict

import numpy as np
from datasets import load_dataset

from .config import AUDIO_DATASET_ID
from .models import Clip

# soundfile is an optional dependency: when it cannot be imported, `sf` is set
# to None and DataManager._get_audio_data (which checks `sf`) returns None
# instead of writing a WAV file.
try:
    import soundfile as sf
except ImportError:
    sf = None


class DataManager:
    """Handles loading and processing data from Hugging Face.

    The dataset identified by ``dataset_id`` is loaded at most once per
    instance; the resulting list of ``Clip`` objects is cached in
    ``self._clips`` and returned on subsequent calls to :meth:`load_clips`.
    """

    def __init__(self, dataset_id: str = AUDIO_DATASET_ID):
        self.dataset_id = dataset_id
        # Cache of loaded clips; None until the first successful load.
        self._clips: Optional[List[Clip]] = None
        # Simple re-entrancy flag: True while load_clips() is running.
        self._loading = False

    def _get_audio_data(self, audio_val) -> Optional[str]:
        """Write one dataset audio value to a temporary WAV file.

        ``audio_val`` is expected to expose ``"array"`` and
        ``"sampling_rate"`` either as mapping keys (dict or any object
        supporting subscription) or as attributes.

        Returns:
            The path of a newly created ``.wav`` file, or ``None`` when the
            samples/sampling rate cannot be read, ``soundfile`` is not
            installed, or writing the file fails. The file is not deleted
            by this class.
        """
        wav_path = None
        try:
            array = None
            sr = None
            if isinstance(audio_val, dict):
                array = audio_val.get("array")
                sr = audio_val.get("sampling_rate")

            if array is None or sr is None:
                # Non-dict audio objects may still support subscription;
                # fall back to attribute access when they do not.
                try:
                    array = audio_val["array"]
                    sr = audio_val["sampling_rate"]
                except Exception:
                    array = getattr(audio_val, "array", None)
                    sr = getattr(audio_val, "sampling_rate", None)

            if array is not None and sr is not None:
                if sf is None:
                    print("[WARN] soundfile is not installed; cannot write audio to a WAV file")
                else:
                    import tempfile

                    # Convert before creating the file so a bad value does
                    # not leave an empty temp file behind.
                    samples = np.asarray(array)
                    rate = int(sr)

                    # mkstemp + close: the descriptor must be closed before
                    # soundfile reopens the path (required on Windows).
                    fd, wav_path = tempfile.mkstemp(suffix=".wav")
                    os.close(fd)
                    sf.write(wav_path, samples, rate)
                    return wav_path
        except Exception as e:
            print(f"[WARN] Failed to process audio data: {e}")
            # Do not leak a partially written temp file on failure.
            if wav_path is not None:
                try:
                    os.remove(wav_path)
                except OSError:
                    pass

        print("[WARN] Could not process audio data for this example")
        return None

    def load_clips(self) -> List[Clip]:
        """Load the dataset's ``train`` split and convert rows to ``Clip`` objects.

        Rows whose audio cannot be converted by :meth:`_get_audio_data` are
        skipped with a warning.

        Returns:
            The cached list of clips if a previous call succeeded; an empty
            list if a load is already in progress; otherwise the freshly
            loaded clips.

        Raises:
            Whatever ``load_dataset`` or row access raises (e.g. ``KeyError``
            for a missing column). The in-progress flag is reset in that
            case so a later call can retry.
        """
        if self._clips is not None:
            return self._clips

        if self._loading:
            print("Dataset loading already in progress...")
            return []

        self._loading = True
        try:
            print(f"Loading dataset {self.dataset_id}...")
            dataset = load_dataset(self.dataset_id, split="train")

            clips: List[Clip] = []
            for row in dataset:
                audio_val = row.get("audio")

                audio_data = self._get_audio_data(audio_val)
                if audio_data is None:
                    print(f"[WARN] Skipping clip {row.get('exercise_id')} – could not process audio data")
                    continue

                clip = Clip(
                    id=f"{row['model']}_{row['speaker']}_{row['exercise_id']}",
                    model=row["model"],
                    speaker=row["speaker"],
                    exercise=row["exercise"],
                    exercise_id=row["exercise_id"],
                    transcript=row["rt"],
                    audio_url=audio_data,  # path of the temporary WAV file
                )
                clips.append(clip)

            self._clips = clips
        finally:
            # Always clear the flag; otherwise one failed load would make
            # every later call return [] forever.
            self._loading = False

        print(f"Loaded {len(clips)} clips")
        return clips