# Magenta / app.py
# Last commit by Leo71288: "Update app.py" (cab19ee, verified)
# app.py
import gradio as gr
from html import escape
# --- Magenta UI in HTML (UMD) styled to look identical to Gradio ---
# Standalone HTML document meant to be rendered inside an <iframe srcdoc>.
# It loads Tone.js and @magenta/music (UMD builds) from jsDelivr, generates a
# continuation of a fixed 5-note seed with MusicRNN in the browser, plays it with
# a SoundFontPlayer and can export it as a .mid file. Kept as a raw string so the
# embedded CSS/JS passes through untouched; it must never contain a triple
# double quote (that would terminate the Python literal).
INNER_HTML = r"""<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Magenta MusicRNN — UMD (in Gradio)</title>
<meta name="viewport" content="width=device-width, initial-scale=1" />
<!-- Load Tone.js THEN Magenta (UMD) -->
<script src="https://cdn.jsdelivr.net/npm/tone@14.7.77"></script>
<script src="https://cdn.jsdelivr.net/npm/@magenta/music@1.23.1"></script>
<style>
/* 1) CSS variables with fallbacks. They get overridden by the parent Gradio theme sync */
:root{
  --radius-lg: 12px;
  --block-background-fill: #ffffff;
  --block-border-color: #e5e7eb;
  --body-text-color: #111827;
  --muted-text-color: #6b7280;
  --shadow-drop: 0 1px 2px rgba(0,0,0,.05);
  --button-primary-background-fill: #3b82f6;
  --button-primary-background-fill-hover: #2563eb;
  --button-primary-text-color: #ffffff;
  --button-secondary-background-fill: #f9fafb;
  --button-secondary-background-fill-hover: #f3f4f6;
  --button-secondary-text-color: #111827;
  --button-secondary-border-color: #e5e7eb;
  --slider-track: #e5e7eb; /* light gray */
  --slider-fill: var(--button-primary-background-fill);
  /* Fallback font stack; replaced by Gradio's --font when the theme sync runs */
  --font: ui-sans-serif, system-ui, -apple-system, "Segoe UI", Roboto, sans-serif;
}
/* 2) Base: blend into Gradio page (transparent bg, Gradio font) */
html, body {
  padding: 0; margin: 0;
  background: transparent;
  color: var(--body-text-color);
  /* An iframe document does not inherit fonts from the parent page, so
     "inherit" would fall back to the browser default; use the synced --font. */
  font-family: var(--font);
  line-height: 1.4;
}
.wrap { padding: 0; } /* Gradio handles outer margins already */
.card{
  max-width: 940px;
  margin: 0 auto;
  background: var(--block-background-fill);
  border: 1px solid var(--block-border-color);
  border-radius: var(--radius-lg);
  box-shadow: var(--shadow-drop);
  padding: 16px 16px 24px;
}
h1{ margin: 0 0 8px; font-size: 20px; font-weight: 600; }
.muted{ color: var(--muted-text-color); font-size: 13px; }
/* Grid labels / controls / values, Gradio-like */
.row{
  display: grid;
  grid-template-columns: 180px 1fr 72px;
  gap: 12px;
  align-items: center;
  margin-top: 12px;
}
label{ font-size: 14px; color: var(--body-text-color); }
.mono{ font-family: ui-monospace, Menlo, Consolas, monospace; }
.buttons{
  display: flex; gap: 10px; flex-wrap: wrap;
  margin: 16px 0 8px;
}
button{
  font-family: inherit; /* form controls do not inherit the font by default */
  border-radius: calc(var(--radius-lg) - 4px);
  padding: 10px 14px;
  font-weight: 600;
  font-size: 14px;
  border: 1px solid transparent;
  cursor: pointer;
  transition: background-color .15s ease, border-color .15s ease, opacity .15s ease, transform .02s;
  will-change: background-color, border-color;
}
button:active{ transform: translateY(0.5px); }
.primary{
  background: var(--button-primary-background-fill);
  color: var(--button-primary-text-color);
}
.primary:hover{ background: var(--button-primary-background-fill-hover); }
.secondary{
  background: var(--button-secondary-background-fill);
  color: var(--button-secondary-text-color);
  border-color: var(--button-secondary-border-color);
}
.secondary:hover{
  background: var(--button-secondary-background-fill-hover);
  border-color: var(--button-secondary-border-color);
}
button:disabled{
  opacity: .6;
  cursor: not-allowed;
}
button:focus-visible{
  outline: 2px solid var(--button-primary-background-fill);
  outline-offset: 2px;
}
/* Progress bar, Gradio-style */
.track{
  height: 10px;
  background: var(--slider-track);
  border-radius: 999px;
  overflow: hidden;
  border: 1px solid var(--block-border-color);
}
.bar{
  height: 100%;
  width: 0%;
  background: var(--slider-fill);
  transition: width .12s linear;
}
.meta{
  display: flex;
  justify-content: space-between;
  align-items: center;
  margin-top: 6px;
  color: var(--muted-text-color);
  font-size: 13px;
}
/* Range slider styled like Gradio (track + fill + thumb) */
input[type=range]{
  -webkit-appearance: none; appearance: none;
  width: 100%;
  background: transparent;
  height: 1.5rem;
}
/* Track: filled up to --range-fill-pct, which the script sets per input */
input[type=range]::-webkit-slider-runnable-track{
  height: 6px; border-radius: 999px;
  background: linear-gradient(to right, var(--slider-fill) var(--range-fill-pct, 0%), var(--slider-track) var(--range-fill-pct, 0%));
}
input[type=range]::-moz-range-track{
  height: 6px; border-radius: 999px;
  background: linear-gradient(to right, var(--slider-fill) var(--range-fill-pct, 0%), var(--slider-track) var(--range-fill-pct, 0%));
}
/* Thumb */
input[type=range]::-webkit-slider-thumb{
  -webkit-appearance: none; appearance: none;
  width: 18px; height: 18px; border-radius: 50%;
  background: #fff; border: 2px solid var(--slider-fill);
  margin-top: -6px; /* vertically centers on 6px track */
  box-shadow: 0 1px 1px rgba(0,0,0,.06);
}
input[type=range]::-moz-range-thumb{
  width: 18px; height: 18px; border-radius: 50%;
  background: #fff; border: 2px solid var(--slider-fill);
  box-shadow: 0 1px 1px rgba(0,0,0,.06);
}
/* Replace the default focus outline with a ring on the thumb so keyboard
   users still see which slider has focus. */
input[type=range]:focus-visible{
  outline: none;
}
input[type=range]:focus-visible::-webkit-slider-thumb{
  box-shadow: 0 0 0 3px var(--slider-track), 0 0 0 5px var(--slider-fill);
}
input[type=range]:focus-visible::-moz-range-thumb{
  box-shadow: 0 0 0 3px var(--slider-track), 0 0 0 5px var(--slider-fill);
}
/* Right-aligned numeric value */
.val{
  text-align: right;
  color: var(--body-text-color);
  font-weight: 600;
  font-size: 13px;
}
/* Backend note (small text on the right) */
#backend{
  align-self: center;
  margin-left: auto;
  font-size: 12px;
  color: var(--muted-text-color);
}
</style>
</head>
<body>
<div class="wrap">
  <div class="card" role="group" aria-label="Magenta MusicRNN">
    <h1>Magenta MusicRNN</h1>
    <div class="muted">Local generation + SoundFont (UMD).</div>
    <div class="row">
      <label for="steps">Steps (length)</label>
      <input id="steps" type="range" min="8" max="128" value="32" step="1" aria-label="Steps (length)">
      <div id="stepsVal" class="val mono">32</div>
    </div>
    <div class="row">
      <label for="temp">Temperature</label>
      <input id="temp" type="range" min="0.5" max="2.0" value="1.1" step="0.1" aria-label="Temperature">
      <div id="tempVal" class="val mono">1.1</div>
    </div>
    <div class="buttons">
      <button id="genBtn" class="primary" aria-label="Generate and play">🎼 Generate & Play</button>
      <button id="stopBtn" class="secondary" aria-label="Stop playback">⏹️ Stop</button>
      <button id="saveBtn" class="secondary" aria-label="Export as MIDI" disabled>💾 Export MIDI</button>
      <span id="backend" class="muted"></span>
    </div>
    <div class="track" aria-hidden="true"><div id="bar" class="bar"></div></div>
    <div class="meta">
      <span id="phase">Initialization…</span>
      <span id="ascii" class="mono">[--------------------] 0%</span>
    </div>
    <div id="log" class="muted" style="margin-top:10px">Log: ready.</div>
  </div>
</div>
<script>
// ---- 1) Sync CSS custom properties from the parent Gradio theme ----
// Copies every computed custom property (--xxx) in effect where this iframe
// sits in the parent page. Reading from the <iframe> element itself
// (window.frameElement) also captures overrides scoped to its ancestors, such
// as a dark-theme class; the parent's <html> element is the fallback.
function syncCssVarsFromParent(){
  try{
    const parentWin = window.parent;
    if(!parentWin || parentWin === window) return; // not framed: keep fallbacks
    const source = window.frameElement || parentWin.document.documentElement;
    const styles = parentWin.getComputedStyle(source);
    for(let i = 0; i < styles.length; i++){
      const name = styles[i];
      if(name && name.startsWith('--')){
        const val = styles.getPropertyValue(name);
        if(val) document.documentElement.style.setProperty(name, val);
      }
    }
  }catch(e){ /* cross-origin parent: keep the fallback values */ }
}
syncCssVarsFromParent();
// Re-sync in case Gradio theme switches at runtime
setInterval(syncCssVarsFromParent, 1000);

// ---- 2) UI helpers (range fill coloring like Gradio) ----
// Publishes the slider position as --range-fill-pct; the track CSS turns it
// into the filled portion. Guards against a zero-width range.
function updateRangeFill(el){
  const min = +el.min || 0, max = +el.max || 100, val = +el.value || 0;
  const pct = max > min ? ((val - min) / (max - min)) * 100 : 0;
  el.style.setProperty('--range-fill-pct', pct + '%');
}
// Keeps a slider's readout and fill in sync on every input event.
function bindRange(el, out, format){
  const sync = () => { out.textContent = format(el.value); updateRangeFill(el); };
  el.addEventListener('input', sync);
  sync();
}
const steps = document.getElementById('steps');
const stepsVal = document.getElementById('stepsVal');
const temp = document.getElementById('temp');
const tempVal = document.getElementById('tempVal');
bindRange(steps, stepsVal, v => v);
bindRange(temp, tempVal, v => Number(v).toFixed(1));

// ---- 3) Magenta/Tone + app logic ----
const genBtn = document.getElementById('genBtn');
const stopBtn = document.getElementById('stopBtn');
const saveBtn = document.getElementById('saveBtn');
const bar = document.getElementById('bar');
const ascii = document.getElementById('ascii');
const phase = document.getElementById('phase');
const logEl = document.getElementById('log');
const backend = document.getElementById('backend');
function log(msg){ logEl.textContent = `[${new Date().toLocaleTimeString()}] ${msg}`; }
function setPhase(p){ phase.textContent = p; }

const mm = window.mm;
if (!mm) {
  // The UMD bundle did not load (CDN/network failure): report it visibly and
  // stop here instead of failing silently on the first mm.* access.
  setPhase('Load error');
  log('Could not load @magenta/music from the CDN.');
  genBtn.disabled = true; stopBtn.disabled = true;
  throw new Error('window.mm is undefined (Magenta UMD bundle failed to load)');
}
const CHECKPOINT = 'https://storage.googleapis.com/magentadata/js/checkpoints/music_rnn/basic_rnn';
const SOUNDFONT = 'https://storage.googleapis.com/magentadata/js/soundfonts/sgm_plus';
const rnn = new mm.MusicRNN(CHECKPOINT);
const player = new mm.SoundFontPlayer(SOUNDFONT);

// Progress display: a real bar plus an ASCII mirror ('#' filled, '-' empty,
// matching the initial placeholder text).
function asciiBar(p){ const n = 20, f = Math.round(p / 100 * n); return '[' + '#'.repeat(f) + '-'.repeat(n - f) + `] ${Math.round(p)}%`; }
function render(p){ p = Math.max(0, Math.min(100, p)); bar.style.width = p + '%'; ascii.textContent = asciiBar(p); }
// Simulated progress: creeps toward 95% with slight jitter until finishProg().
let tmr = null, prog = 0;
function startProg(){
  stopProg(); prog = 0; render(0);
  tmr = setInterval(() => {
    const rem = Math.max(0, 95 - prog);
    const step = Math.max(.2, rem * .04) + (Math.random() * .6 - .3);
    prog = Math.min(95, prog + step);
    render(prog);
  }, 120);
}
function stopProg(){ if (tmr) clearInterval(tmr); tmr = null; }
function finishProg(){ stopProg(); render(100); }

// Seed melody C-D-E-F-G (0.5 s per note), quantized to 4 steps per quarter.
const seed = mm.sequences.quantizeNoteSequence({
  ticksPerQuarter:220, totalTime:2.5,
  notes:[
    {pitch:60,startTime:0.0,endTime:0.5},
    {pitch:62,startTime:0.5,endTime:1.0},
    {pitch:64,startTime:1.0,endTime:1.5},
    {pitch:65,startTime:1.5,endTime:2.0},
    {pitch:67,startTime:2.0,endTime:2.5},
  ],
}, 4);

(async () => {
  try {
    try {
      await mm.tf.setBackend('webgl'); await mm.tf.ready();
      backend.textContent = 'TFJS backend: ' + mm.tf.getBackend();
    } catch (e) {
      // Best effort: TFJS keeps its default backend if WebGL is unavailable.
      console.warn('Could not select the WebGL backend; TFJS will use its default.', e);
    }
    setPhase('Loading model…');
    await rnn.initialize();
    setPhase('Ready');
    log('Model initialized.');
  } catch(e){
    setPhase('Init error'); log('Init error: ' + (e?.message || e));
  }
})();

// Most recently generated sequence, used by the MIDI export.
let lastSeq = null;

genBtn.onclick = async () => {
  try {
    // Audio stays suspended until a user gesture; resume it on this click.
    if (window.Tone && window.Tone.context?.state !== 'running') { await window.Tone.start(); }
    genBtn.disabled = true; saveBtn.disabled = true;
    setPhase('Generating…'); startProg();
    const stepsN = Number(steps.value);
    const tempN = Number(temp.value);
    const t0 = performance.now();
    const seq = await rnn.continueSequence(seed, stepsN, tempN);
    const t1 = performance.now();
    setPhase('Loading instruments…');
    await player.loadSamples(seq);
    finishProg();
    setPhase(`Playing (${stepsN} steps, temp ${tempN.toFixed(1)})`);
    log(`Generated in ${(t1 - t0).toFixed(0)} ms`);
    if (player?.isPlaying()) player.stop();
    player.start(seq);
    lastSeq = seq;
  } catch(e){
    stopProg(); render(0); setPhase('Error'); log('Error: ' + (e?.message || e)); console.error(e);
    alert('Error during generation. See frame console.');
  } finally {
    genBtn.disabled = false;
    // A failed run must not hide the export of an earlier successful sequence.
    saveBtn.disabled = !lastSeq;
  }
};

stopBtn.onclick = () => { if (player?.isPlaying()) player.stop(); setPhase('Stopped'); };

saveBtn.onclick = () => {
  if (!lastSeq) return alert('Please generate a sequence first.');
  const midi = mm.sequenceProtoToMidi(lastSeq);
  const blob = new Blob([midi], { type: 'audio/midi' });
  const url = URL.createObjectURL(blob);
  const a = Object.assign(document.createElement('a'), { href:url, download:'magenta_rnn.mid' });
  document.body.appendChild(a); a.click(); a.remove();
  // Revoking synchronously can cancel the download in some browsers; defer it.
  setTimeout(() => URL.revokeObjectURL(url), 1000);
};
</script>
</body>
</html>
"""
# Wrap in an iframe via srcdoc (scripts inside will execute).
# html.escape (quote=True by default) turns quotes and markup characters into
# entities so the whole document survives as a double-quoted attribute value;
# the browser un-escapes it before parsing the srcdoc document. The title
# attribute gives the frame an accessible name for screen readers.
IFRAME = f"""
<iframe
title="Magenta MusicRNN"
srcdoc="{escape(INNER_HTML)}"
style="width:100%;height:760px;border:none;border-radius:12px;overflow:hidden;background:transparent;">
</iframe>
"""
# Gradio page: a short heading plus the iframe hosting the in-browser generator.
with gr.Blocks(title="Magenta RNN (Gradio)") as demo:
    intro = (
        "## 🎶 Magenta MusicRNN inside a Gradio interface\n"
        "Click **Generate & Play** in the frame below."
    )
    gr.Markdown(intro)
    gr.HTML(value=IFRAME)

# Start the local server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()