Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
def generate_html(training_text, iterations, gen_length, layer1, layer2):
    """Write a self-contained training page to /tmp and return an iframe for it.

    The page loads recurrent.js from karpathy.github.io, trains on random
    5-character windows of ``training_text`` inside a ``setInterval`` callback
    that fires every 10 seconds, and shows a generated sample after each tick.

    Args:
        training_text: Text the in-browser model trains on; embedded in the
            page as a JavaScript string literal.
        iterations: Total number of training steps before the timer stops.
        gen_length: Number of characters requested from ``lstm.sample``.
        layer1: Size of the first hidden layer.
        layer2: Size of the second hidden layer.

    Returns:
        An ``<iframe>`` HTML snippet whose ``src`` points at the written file.
    """
    import json

    # json.dumps yields a quoted, fully escaped literal (quotes, backslashes,
    # newlines, \r, control and non-ASCII characters). "<" is additionally
    # escaped so the text can never close the surrounding <script> element.
    text_literal = json.dumps(training_text or "").replace("<", "\\u003c")

    # NOTE(review): the script assumes recurrent.js exposes global `RNN` and
    # `RNNTrainer` constructors with these signatures -- confirm against the
    # library before relying on it.
    #
    # Inside this f-string, JavaScript braces are doubled; that includes the
    # `${{...}}` template-literal placeholders, which would otherwise be
    # evaluated as (undefined) Python names and raise NameError.
    html_content = f"""
<html>
<head>
<script src="https://karpathy.github.io/recurrentjs/recurrent.js"></script>
</head>
<body>
<h3>Live LSTM Training Output (updates every 10s)</h3>
<pre id="output">Starting training...</pre>
<script>
const text = {text_literal}.split("");
const totalIterations = {iterations};
const genLength = {gen_length};
const lstm = new RNN("lstm", {{ hiddenSizes: [{layer1}, {layer2}] }});
const trainer = new RNNTrainer(lstm, {{ learningRate: 0.01, momentum: 0.1, batchSize: 5 }});
let iteration = 0;
const interval = setInterval(() => {{
for (let i = 0; i < 5 && iteration < totalIterations; i++, iteration++) {{
// Clamp at 0 so short inputs never produce a negative start index.
const idx = Math.floor(Math.random() * Math.max(0, text.length - 10));
const input = text.slice(idx, idx + 5);
const output = text.slice(idx + 1, idx + 6);
trainer.train(input, output);
}}
const sample = lstm.sample(["H"], genLength).join("");
document.getElementById("output").innerText = `Epochs completed: ${{iteration}} / ${{totalIterations}}\\n\\n` + sample;
if (iteration >= totalIterations) {{
clearInterval(interval);
document.getElementById("output").innerText += "\\n\\nTraining complete!";
}}
}}, 10000);
</script>
</body>
</html>
"""
    # Save the HTML to the /tmp directory (accessible in Hugging Face Spaces).
    # The path is fixed, so each call overwrites the previous page.
    html_path = "/tmp/train.html"
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(html_content)
    # Return iframe HTML to embed the file
    iframe_code = f'<iframe src="file={html_path}" width="100%" height="400px"></iframe>'
    return iframe_code
# Assemble the Gradio interface: a text box and four sliders feed
# generate_html, and the HTML it returns is rendered in an HTML component.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Live recurrent.js LSTM Trainer (in-browser training)")

    # Input controls, created in the order they appear on the page.
    training_text = gr.Textbox(
        label="Training Text",
        lines=6,
        placeholder="Paste text here",
    )
    iterations = gr.Slider(
        minimum=10, maximum=200, value=50, step=10, label="Training Epochs"
    )
    gen_length = gr.Slider(
        minimum=20, maximum=500, value=100, step=10, label="Characters to Generate"
    )
    layer1 = gr.Slider(
        minimum=16, maximum=256, value=128, step=16, label="Neurons in Layer 1"
    )
    layer2 = gr.Slider(
        minimum=16, maximum=256, value=128, step=16, label="Neurons in Layer 2"
    )

    run_button = gr.Button("Start Training")
    html_output = gr.HTML()

    # Clicking the button passes the five control values to generate_html
    # and shows its return value in html_output.
    run_button.click(
        generate_html,
        [training_text, iterations, gen_length, layer1, layer2],
        html_output,
    )

# NOTE(review): original indentation was lost; launch is assumed to sit
# outside the Blocks context, as is conventional.
demo.launch()