File size: 7,355 Bytes
d61c73b
 
 
 
f5ec16f
7a7d9e7
 
d61c73b
 
 
 
f5ec16f
 
 
4ea0a00
f5ec16f
 
 
 
 
 
 
d61c73b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5ec16f
 
 
 
 
 
 
 
 
 
7a7d9e7
 
f5ec16f
 
 
d61c73b
 
 
 
 
 
 
 
 
 
 
 
f5ec16f
 
 
 
 
 
 
 
 
 
d61c73b
 
f5ec16f
 
 
 
 
 
 
 
 
 
4ea0a00
 
d61c73b
 
 
 
 
 
 
 
f5ec16f
d61c73b
 
 
 
 
 
 
 
 
 
 
 
880916c
f5ec16f
7a7d9e7
d61c73b
f5ec16f
 
 
 
d61c73b
 
 
7a7d9e7
f5ec16f
 
d61c73b
7a7d9e7
d61c73b
 
 
 
 
 
f5ec16f
7a7d9e7
f5ec16f
7a7d9e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d61c73b
 
7a7d9e7
 
 
 
 
 
 
d61c73b
7a7d9e7
d61c73b
 
 
 
 
 
 
 
 
 
 
7a7d9e7
 
 
d61c73b
 
7a7d9e7
 
d61c73b
 
f5ec16f
 
7a7d9e7
 
 
 
 
 
 
 
 
 
 
 
f5ec16f
7a7d9e7
 
 
d61c73b
 
 
 
 
 
b18febb
d61c73b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import os
import re
import time
import math
import io
from typing import List, Tuple, Optional

import numpy as np
import gradio as gr
import soundfile as sf
from scipy.signal import resample_poly
from scipy.io import wavfile as wav_write
from pymongo import MongoClient
from gridfs import GridFS

# MongoDB connection settings, read from the process environment
# (the UI status message refers to these as Space secrets).
MONGO_URI = os.environ.get("MONGO_URI", "")
MONGO_DB = os.environ.get("MONGO_DB", "spells")
MONGO_BUCKET = os.environ.get("MONGO_BUCKET", "recordings")

# Module-level cache filled lazily by get_gridfs(); both stay None until the
# first successful call with MONGO_URI set.
_mongo_client: Optional[MongoClient] = None
_mongo_fs: Optional[GridFS] = None

# Every stored recording is resampled to this rate (Hz) so the training set is uniform.
TARGET_SR = 16_000

# Spell names to collect.
# NOTE(review): not referenced by the other code visible in this file; the
# UI and submit_recordings() spell the same names out by hand.
SPELLS = ["Lumos", "Nox", "Alohomora", "Wingardium Leviosa", "Accio", "Reparo"]


def sanitize_username(name: Optional[str]) -> str:
    """Turn a free-form display name into a filename-safe, lower-case slug.

    Steps:
    - surrounding whitespace is removed and each inner whitespace run
      becomes a single underscore;
    - every character other than ASCII letters, digits, ``_`` and ``-``
      is dropped;
    - the result is lower-cased.

    ``None``, ``""`` or a name that sanitizes to nothing yields ``"anon"``.
    """
    fallback = "anon"
    if not name:
        return fallback
    underscored = re.sub(r"\s+", "_", name.strip())
    slug = re.sub(r"[^a-zA-Z0-9_-]", "", underscored).lower()
    return slug if slug else fallback


def to_mono(audio: np.ndarray) -> np.ndarray:
    """Collapse multi-channel audio into a single channel.

    A 2-D array is averaged along axis 1 (treated as the channel axis).
    Arrays of any other rank are returned unchanged, as the same object.
    """
    if audio.ndim != 2:
        return audio
    return np.mean(audio, axis=1)


def resample_to_target(audio: np.ndarray, sr: int, target_sr: int = TARGET_SR) -> np.ndarray:
    """Resample ``audio`` from ``sr`` Hz to ``target_sr`` Hz.

    Uses polyphase filtering with the up/down factors reduced by their
    greatest common divisor. When the rates already match, the input
    array is returned untouched (no copy).
    """
    if sr != target_sr:
        divisor = math.gcd(sr, target_sr)
        audio = resample_poly(audio, up=target_sr // divisor, down=sr // divisor)
    return audio


def get_gridfs() -> Optional[GridFS]:
    """Return the shared GridFS handle, creating it on first use.

    Returns None whenever MONGO_URI is empty. Otherwise a MongoClient and
    a GridFS bound to ``MONGO_DB`` / ``MONGO_BUCKET`` are built once and
    cached in the module globals ``_mongo_client`` / ``_mongo_fs``.

    NOTE(review): constructing MongoClient is not guaranteed to contact the
    server, so connection problems may only surface on the first real
    operation (e.g. ``fs.put``). Confirm against the pymongo docs.
    """
    global _mongo_client, _mongo_fs
    if not MONGO_URI:
        return None
    if _mongo_fs is None:
        _mongo_client = MongoClient(MONGO_URI)
        database = _mongo_client[MONGO_DB]
        _mongo_fs = GridFS(database, collection=MONGO_BUCKET)
    return _mongo_fs


def _encode_pcm16_wav(audio: np.ndarray, sr: int) -> bytes:
    """Convert float audio sampled at ``sr`` Hz into 16-bit mono WAV bytes at TARGET_SR."""
    mono = to_mono(audio)
    resampled = resample_to_target(mono, sr, TARGET_SR)
    # NaN/inf survive np.clip, and casting them to int16 is undefined.
    # Map NaN to silence and infinities to full scale before clipping.
    cleaned = np.nan_to_num(resampled, nan=0.0, posinf=1.0, neginf=-1.0)
    clipped = np.clip(cleaned, -1.0, 1.0)
    pcm16 = (clipped * 32767.0).astype(np.int16)
    buf = io.BytesIO()
    wav_write.write(buf, TARGET_SR, pcm16)
    return buf.getvalue()


def save_one_from_path(filepath: Optional[str], spell: str, username: str) -> Optional[str]:
    """Load one recording, normalize it to 16 kHz mono WAV and store it in GridFS.

    Args:
        filepath: Path to the audio file from the mic/upload widget, or
            None/"" when nothing was provided for this spell.
        spell: Spell name. It is stored in the metadata and slugified into
            the filename.
        username: User name embedded in the filename and metadata. Callers
            are expected to pass an already-sanitized value.

    Returns:
        The inserted GridFS file id as a string. Returns None when:
        - no path was given;
        - the database is not configured (``get_gridfs()`` returned None);
        - the file decoded to zero samples.

    Errors raised while decoding the file (soundfile) or storing it
    (GridFS) propagate to the caller.
    """
    if not filepath:
        return None

    # Check storage first: without a database there is no point decoding
    # and resampling audio that cannot be saved.
    fs = get_gridfs()
    if fs is None:
        return None

    audio, sr = sf.read(filepath, dtype="float32", always_2d=False)
    if audio is None:
        return None
    audio = np.asarray(audio)
    if audio.size == 0:
        return None

    wav_bytes = _encode_pcm16_wav(audio, sr)

    ts = int(time.time() * 1000)
    # Reduce the spell to a lower-case [a-z0-9_] slug so it is safe in a filename.
    spell_slug = re.sub(r"[^a-zA-Z0-9]+", "_", spell).strip("_").lower()
    filename = f"{spell_slug}_{username}_{ts}.wav"
    metadata = {
        "username": username,
        "spell": spell,
        "timestamp_ms": ts,
        "sample_rate": TARGET_SR,
        "format": "wav",
    }
    file_id = fs.put(wav_bytes, filename=filename, contentType="audio/wav", metadata=metadata)
    return str(file_id)


def submit_recordings(
    username: str,
    lumos_path: Optional[str],
    nox_path: Optional[str],
    alohomora_path: Optional[str],
    wingardium_path: Optional[str],
    accio_path: Optional[str],
    reparo_path: Optional[str],
) -> Tuple[str, int]:
    """Save every provided recording and build a Markdown status report.

    Args:
        username: Raw name typed by the user. It is sanitized before use.
        *_path: File path for each spell's recording, or None/"" when that
            spell was not recorded.

    Returns:
        ``(report, inserted)``. ``report`` is Markdown text for the result
        panel. ``inserted`` is the number of files stored during this call.

    Each provided recording is handled independently. A failure on one
    spell is reported in the "Not saved" section instead of aborting the
    rest of the batch.
    """
    pairs: List[Tuple[str, Optional[str]]] = [
        ("Lumos", lumos_path),
        ("Nox", nox_path),
        ("Alohomora", alohomora_path),
        ("Wingardium Leviosa", wingardium_path),
        ("Accio", accio_path),
        ("Reparo", reparo_path),
    ]

    # Nothing recorded at all: say so directly. Otherwise the report would
    # just list all six spells as missing.
    if not any(path for _, path in pairs):
        return "No audio captured. Please record at least one spell.", 0

    user = sanitize_username(username)

    saved: List[str] = []
    not_saved: List[str] = []
    missing: List[str] = []

    for spell, path in pairs:
        if not path:
            missing.append(spell)
            continue
        try:
            file_id = save_one_from_path(path, spell, user)
        except Exception as exc:  # UI boundary: report this spell, keep saving the others
            # Only the exception type is shown, so driver/DB error text is
            # not echoed to end users.
            not_saved.append(f"{spell} (error: {type(exc).__name__})")
            continue
        if file_id:
            saved.append(f"{spell} -> id {file_id}")
        else:
            # Provided but not stored: DB not configured or empty audio.
            not_saved.append(spell)

    lines: List[str] = []
    if not MONGO_URI:
        lines.append("Database not configured: set MONGO_URI secret in the Space.")
    if saved:
        lines.append("Saved recordings:")
        lines += [f"- {s}" for s in saved]
    if not_saved:
        lines.append("")
        lines.append("Not saved:")
        lines += [f"- {s}" for s in not_saved]
    if missing:
        lines.append("")
        lines.append("Missing (not provided):")
        lines += [f"- {s}" for s in missing]

    return "\n".join(lines), len(saved)


def count_selected(
    lumos_path: Optional[str],
    nox_path: Optional[str],
    alohomora_path: Optional[str],
    wingardium_path: Optional[str],
    accio_path: Optional[str],
    reparo_path: Optional[str],
) -> str:
    """Build the live "Selected: n/6" label.

    ``n`` is the number of spell inputs currently holding a truthy path.
    """
    provided = [
        path
        for path in (lumos_path, nox_path, alohomora_path, wingardium_path, accio_path, reparo_path)
        if path
    ]
    return f"Selected: {len(provided)}/6"


def build_ui() -> gr.Blocks:
    """Assemble the Spell Recorder Gradio interface.

    Layout, top to bottom:
    - intro text;
    - name box;
    - two columns of per-spell audio inputs;
    - live selection counter;
    - Submit button, result panel and saved-count field;
    - closing notes.

    Submit is wired to submit_recordings(). Any audio change refreshes the
    counter through count_selected().
    """

    def _spell_input(label: str) -> gr.Audio:
        # Mic or upload; "filepath" hands the callbacks a path on disk, which
        # save_one_from_path() then reads.
        return gr.Audio(label=label, sources=["microphone", "upload"], type="filepath")

    with gr.Blocks(title="Spell Recorder") as demo:
        gr.Markdown("""
        # Spell Recorder
        Record any of the listed spells and press Submit. You can use your microphone directly (preferred) or upload a file.

        Spells to collect: Lumos, Nox, Alohomora, Wingardium Leviosa, Accio, Reparo.
        """)

        with gr.Row():
            username = gr.Textbox(
                label="Your Name (for filename)",
                placeholder="e.g., harry_p",
                autofocus=True,
            )

        with gr.Row():
            with gr.Column():
                lumos = _spell_input("Lumos")
                nox = _spell_input("Nox")
                alohomora = _spell_input("Alohomora")
            with gr.Column():
                wingardium = _spell_input("Wingardium Leviosa")
                accio = _spell_input("Accio")
                reparo = _spell_input("Reparo")

        # Same order as the spell parameters of submit_recordings/count_selected.
        spell_inputs = [lumos, nox, alohomora, wingardium, accio, reparo]

        with gr.Row():
            selected_counter = gr.Markdown(value="Selected: 0/6")

        submit = gr.Button("Submit")
        result = gr.Markdown()
        submitted_count = gr.Number(label="New files saved this submit", value=0)

        submit.click(
            fn=submit_recordings,
            inputs=[username, *spell_inputs],
            outputs=[result, submitted_count],
        )

        # Every audio widget refreshes the counter; each listener gets its own input list.
        for widget in spell_inputs:
            widget.change(
                fn=count_selected,
                inputs=list(spell_inputs),
                outputs=[selected_counter],
            )

        gr.Markdown("""
        Notes:
        - Submissions are stored directly in MongoDB (GridFS) using environment secrets.
        - 16 kHz mono WAV is used to make model training consistent.
        - You don't have to record all spells at once—submit whatever you have.
        """)

    return demo


# The interface is built at import time, so importing this module exposes a
# ready `demo` object. The server is only launched when the file runs as a script.
demo = build_ui()

if __name__ == "__main__":
    demo.launch()