Spaces:
Sleeping
Sleeping
File size: 2,760 Bytes
6c2da0c a7584b6 6c2da0c a7584b6 1e28077 bd8d5b3 a7584b6 bd8d5b3 1e28077 bd8d5b3 1e28077 a7584b6 1e28077 bd8d5b3 1e28077 bd8d5b3 a7584b6 1e28077 a7584b6 1e28077 a7584b6 bd8d5b3 a7584b6 bd8d5b3 a7584b6 bd8d5b3 a7584b6 1e28077 a7584b6 1e28077 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import json
import os
import tempfile

import gradio as gr
def generate_html(training_text, iterations, gen_length, layer1, layer2):
    """Write a recurrent.js training page to a temp file and return an iframe embedding it.

    The generated page loads recurrent.js, trains an LSTM on ``training_text``
    in the browser, and refreshes a sample of generated text every 10 seconds.

    Args:
        training_text: Text the in-browser model trains on (``None`` is treated
            as empty).
        iterations: Total number of training steps the page runs.
        gen_length: Number of characters sampled after each update.
        layer1: Hidden size of the first LSTM layer.
        layer2: Hidden size of the second LSTM layer.

    Returns:
        An ``<iframe>`` snippet whose ``src`` points at the generated page via
        Gradio's ``file=`` route, or a short ``<p>`` message when the training
        text has fewer than 10 characters.
    """
    text = training_text or ""
    # The page picks windows with Math.random() * (text.length - 10); shorter
    # text makes that factor negative and produces negative slice indices.
    min_chars = 10
    if len(text) < min_chars:
        return f"<p>Please provide at least {min_chars} characters of training text.</p>"

    # Sliders may deliver floats; cast so only plain integers reach the JS source.
    iterations = int(iterations)
    gen_length = int(gen_length)
    layer1 = int(layer1)
    layer2 = int(layer2)

    # json.dumps yields a valid JS string literal (quotes included). Its default
    # ensure_ascii=True escapes every non-ASCII character, including U+2028/2029.
    # Escaping "<" keeps pasted text such as "</script>" or "<!--" from ending
    # the <script> element early.
    js_text = json.dumps(text).replace("<", "\\u003c")

    # NOTE(review): karpathy's recurrent.js exposes its API on a global `R`
    # (R.initLSTM, R.forwardLSTM, R.Solver, ...). `RNN` / `RNNTrainer` used
    # below may not exist in that library -- confirm against the script, or
    # the page will throw a ReferenceError on load.
    # JS template-literal placeholders are written as ${{...}} so the Python
    # f-string emits a literal ${...} instead of looking up Python names.
    html_content = f"""
<html>
<head>
<meta charset="utf-8">
<script src="https://karpathy.github.io/recurrentjs/recurrent.js"></script>
</head>
<body>
<h3>Live LSTM Training Output (updates every 10s)</h3>
<pre id="output">Starting training...</pre>
<script>
// Array.from splits by code point, so emoji and other astral characters stay whole.
const text = Array.from({js_text});
const totalIterations = {iterations};
const genLength = {gen_length};
const lstm = new RNN("lstm", {{ hiddenSizes: [{layer1}, {layer2}] }});
const trainer = new RNNTrainer(lstm, {{ learningRate: 0.01, momentum: 0.1, batchSize: 5 }});
let iteration = 0;
const interval = setInterval(() => {{
  for (let i = 0; i < 5 && iteration < totalIterations; i++, iteration++) {{
    const idx = Math.floor(Math.random() * (text.length - 10));
    const input = text.slice(idx, idx + 5);
    const output = text.slice(idx + 1, idx + 6);
    trainer.train(input, output);
  }}
  const sample = lstm.sample(["H"], genLength).join("");
  document.getElementById("output").innerText = `Epochs completed: ${{iteration}} / ${{totalIterations}}\\n\\n` + sample;
  if (iteration >= totalIterations) {{
    clearInterval(interval);
    document.getElementById("output").innerText += "\\n\\nTraining complete!";
  }}
}}, 10000);
</script>
</body>
</html>
"""
    # One file per call: a fixed path would let concurrent sessions overwrite
    # each other's page. The file lands in the system temp dir (/tmp on Linux
    # and Hugging Face Spaces).
    fd, html_path = tempfile.mkstemp(prefix="train_", suffix=".html")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(html_content)

    # gr.HTML does not execute inline scripts, so embed the page via an iframe.
    return f'<iframe src="file={html_path}" width="100%" height="400px"></iframe>'
# UI: a text box, four sliders and a button. Clicking the button renders the
# generated training page into the HTML pane below it.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Live recurrent.js LSTM Trainer (in-browser training)")

    # Components appear top-to-bottom in creation order.
    corpus_box = gr.Textbox(label="Training Text", lines=6, placeholder="Paste text here")
    epochs_slider = gr.Slider(10, 200, value=50, step=10, label="Training Epochs")
    sample_len_slider = gr.Slider(20, 500, value=100, step=10, label="Characters to Generate")
    hidden1_slider = gr.Slider(16, 256, value=128, step=16, label="Neurons in Layer 1")
    hidden2_slider = gr.Slider(16, 256, value=128, step=16, label="Neurons in Layer 2")
    start_btn = gr.Button("Start Training")
    page_pane = gr.HTML()

    # Order must match generate_html's positional parameters.
    trainer_inputs = [
        corpus_box,
        epochs_slider,
        sample_len_slider,
        hidden1_slider,
        hidden2_slider,
    ]
    start_btn.click(generate_html, inputs=trainer_inputs, outputs=page_pane)

demo.launch()