File size: 5,781 Bytes
8d7f196
 
 
 
ec123d0
8d7f196
 
044def4
ec123d0
 
 
 
 
 
2b36725
044def4
 
 
 
 
cab9b9b
ec123d0
044def4
cab9b9b
 
 
2b36725
 
 
044def4
cab9b9b
2b36725
ec123d0
cab9b9b
044def4
2b36725
ec123d0
cab9b9b
2b36725
cab9b9b
ec123d0
044def4
 
 
 
 
 
 
 
 
ec123d0
044def4
ec123d0
 
 
 
 
 
044def4
ec123d0
 
 
cab9b9b
 
ec123d0
 
044def4
 
ec123d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
044def4
ec123d0
 
 
 
 
 
 
 
 
 
 
 
 
044def4
ec123d0
 
 
 
cab9b9b
2b36725
044def4
2b36725
cab9b9b
2b36725
cab9b9b
8d7f196
 
044def4
cab9b9b
 
2b36725
8d7f196
cab9b9b
 
2b36725
cab9b9b
2b36725
044def4
cab9b9b
2b36725
044def4
 
 
 
ec123d0
cab9b9b
2b36725
044def4
 
 
 
 
 
 
 
 
cab9b9b
 
8d7f196
cab9b9b
2b36725
8d7f196
cab9b9b
 
ec123d0
2b36725
cab9b9b
8d7f196
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import base64
import io
import os
import time

import gradio as gr
import requests
import soundfile as sf

# ── Backend config ─────────────────────────────────────────────────────────────
# Root URL of the generation backend (no trailing slash). It can be overridden
# with the AUTOMIX_BACKEND_URL environment variable, e.g. when the tunnel or
# backend moves, without editing this file.
BASE_URL = os.environ.get(
    "AUTOMIX_BACKEND_URL",
    "https://bustled-hertha-unprojective.ngrok-free.dev",
).rstrip("/")
# Endpoint that accepts generation requests; built from BASE_URL so the two
# can never disagree (replaces the old str.replace-based derivation).
API_URL  = f"{BASE_URL}/generate"
HEADERS  = {
    # Ask ngrok not to serve its HTML browser-warning page in place of the
    # JSON response.
    "ngrok-skip-browser-warning": "true",
    "Content-Type": "application/json",
}

CFG = 7.0  # fixed optimal value from paper — not exposed in UI


# ── Generation function ────────────────────────────────────────────────────────
def _decode_audio(b64_audio):
    """Decode a base64-encoded audio file into ``(sample_rate, samples)``.

    ``b64_audio`` holds a complete audio file in whatever container the
    backend chose (anything ``soundfile`` can read). Malformed base64 or
    unreadable audio raises; the caller's error handling deals with it.
    """
    samples, sample_rate = sf.read(io.BytesIO(base64.b64decode(b64_audio)))
    return sample_rate, samples


def _poll_for_audio(job_id, poll_interval=5, max_wait=600):
    """Poll ``{BASE_URL}/result/{job_id}`` until the job yields audio.

    Returns ``(sample_rate, samples)`` once the job reports ``status == "done"``
    together with an ``audio`` field. Returns ``None`` if the job reports
    ``status == "error"`` or ``max_wait`` seconds of wall-clock time pass.
    Network errors, non-200 replies and non-JSON bodies (e.g. an HTML page
    from the tunnel) are logged and retried rather than aborting generation.
    """
    # Measure real elapsed time: each poll may itself block for up to its
    # 10 s timeout, so counting sleeps alone would let the wait overrun max_wait.
    start = time.monotonic()
    while time.monotonic() - start < max_wait:
        time.sleep(poll_interval)
        print(f"🔄 Polling... ({time.monotonic() - start:.0f}s elapsed)")

        try:
            poll = requests.get(
                f"{BASE_URL}/result/{job_id}",
                headers=HEADERS,
                timeout=10,
            )
        except requests.exceptions.RequestException as e:
            print(f"⚠️ Poll request failed: {e} — retrying")
            continue

        if poll.status_code != 200:
            print(f"⚠️ Poll returned {poll.status_code} — retrying")
            continue

        try:
            job = poll.json()
        except ValueError:
            # Previously this escaped the loop and killed the whole request.
            print(f"⚠️ Poll returned non-JSON body: {poll.text[:200]} — retrying")
            continue

        status = job.get("status")
        print(f"   status = {status}")

        if status == "error":
            print(f"❌ Job failed: {job.get('error')}")
            return None

        if status == "done":
            if "audio" in job:
                sample_rate, samples = _decode_audio(job["audio"])
                print(f"✅ Got audio: {len(samples)/sample_rate:.2f}s @ {sample_rate}Hz")
                return sample_rate, samples
            # Keep polling as before (presumably audio may be attached after the
            # status flips), but say so instead of waiting silently.
            print("⚠️ Job reported 'done' without audio — retrying")
            continue

        if status not in ("pending", "processing", "running"):
            print(f"⚠️ Unknown status '{status}' — continuing to poll")

    print(f"❌ Timed out after {max_wait}s waiting for job {job_id}")
    return None


def generate(prompt, duration, steps):
    """Request a beat from the backend and return it for ``gr.Audio(type="numpy")``.

    Posts ``prompt``, ``duration`` (seconds, sent as float), ``steps`` (sent as
    int) and the fixed ``CFG`` scale to ``API_URL``. Two response shapes are
    handled:

    * ``{"audio": <base64>}`` — synchronous backend; decoded immediately.
    * ``{"job_id": ...}`` — asynchronous backend; polled via ``_poll_for_audio``.

    Returns ``(sample_rate, samples)`` on success and ``None`` on any failure.
    Failure details are printed to the server log; the UI output stays empty.
    """
    # An empty prompt would only waste a backend generation job.
    if not prompt or not prompt.strip():
        print("❌ Empty prompt — nothing to generate")
        return None

    try:
        print(f"📤 Sending request → prompt='{prompt[:60]}...' duration={duration}s steps={steps}")

        res = requests.post(
            API_URL,
            json={
                "prompt":    prompt,
                "duration":  float(duration),
                "steps":     int(steps),
                "cfg_scale": CFG,
            },
            headers=HEADERS,
            timeout=30,
        )

        print(f"STATUS: {res.status_code}")

        if res.status_code != 200:
            print(f"❌ Backend error: {res.text[:500]}")
            return None

        if not res.text:
            print("❌ Empty response from backend")
            return None

        try:
            data = res.json()
        except ValueError:  # requests' JSONDecodeError subclasses ValueError
            print(f"❌ Invalid JSON (got HTML?): {res.text[:300]}")
            return None

        if not isinstance(data, dict):
            print(f"❌ Unexpected response shape: {str(data)[:300]}")
            return None

        # ── Sync backend: audio returned immediately ───────────────────────────
        if "audio" in data:
            sample_rate, samples = _decode_audio(data["audio"])
            print(f"✅ Got audio instantly: {len(samples)/sample_rate:.2f}s @ {sample_rate}Hz")
            return sample_rate, samples

        # ── Async backend: job_id returned, poll for result ────────────────────
        job_id = data.get("job_id")
        if not job_id:
            print(f"❌ No audio and no job_id in response: {data}")
            return None

        print(f"⏳ Job queued: {job_id} — polling for result...")
        return _poll_for_audio(job_id)

    except requests.exceptions.Timeout:
        print("❌ Initial request timed out — Colab backend may be busy")
        return None
    except Exception as e:
        # Top-level boundary of a UI callback: log and show an empty result
        # instead of letting the exception reach Gradio.
        print(f"❌ Request failed: {e}")
        return None


# ── UI ─────────────────────────────────────────────────────────────────────────
# Single-page app: prompt / duration / steps controls on the left, the
# generated clip on the right. `generate` returns (sample_rate, samples) or
# None, which is the value shape gr.Audio(type="numpy") accepts.
with gr.Blocks(title="AutoMix AI 🎡") as demo:
    gr.Markdown("# 🎡 AutoMix AI Beat Generator")
    gr.Markdown("Generate AI beats using a diffusion model fine-tuned on trap/rap/R&B 🚀")

    with gr.Row():
        with gr.Column():
            prompt_in = gr.Textbox(
                label="🎧 Prompt",
                placeholder="A dark trap beat at 140 BPM in C minor, featuring 808 bass and synth bells.",
                lines=4,
            )
            duration_in = gr.Slider(
                minimum=5,
                maximum=95,
                value=30,
                step=1,
                label="⏱ Duration (seconds)",
            )
            steps_in = gr.Slider(
                minimum=20,
                maximum=300,
                value=100,
                step=10,
                label="⚙️ Diffusion Steps  (more = better quality, slower)",
            )
            # User-facing guidance about slow / timing-out configurations.
            gr.Markdown(
                "> ⚠️ Durations above 47s take significantly longer to generate. "
                "Keep steps ≤ 150 for beats over 60s to avoid timeouts."
            )
            btn = gr.Button("🚀 Generate Beat", variant="primary")

        with gr.Column():
            output = gr.Audio(label="🎡 Generated Beat", type="numpy")

    btn.click(
        fn=generate,
        inputs=[prompt_in, duration_in, steps_in],
        outputs=output,
    )

demo.launch()