File size: 2,760 Bytes
6c2da0c
a7584b6
6c2da0c
a7584b6
1e28077
bd8d5b3
a7584b6
 
 
 
 
 
bd8d5b3
 
1e28077
 
bd8d5b3
1e28077
 
a7584b6
1e28077
 
bd8d5b3
 
 
 
 
 
 
 
1e28077
bd8d5b3
 
 
 
 
 
 
a7584b6
1e28077
a7584b6
 
1e28077
 
a7584b6
 
 
 
 
 
 
 
bd8d5b3
a7584b6
 
bd8d5b3
a7584b6
 
 
 
 
bd8d5b3
 
a7584b6
1e28077
a7584b6
 
 
1e28077
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json
import os
import tempfile

import gradio as gr

# The in-page trainer picks a random window start in [0, len - 11] and reads
# up to six characters from it, so shorter texts would yield negative offsets.
_MIN_TEXT_LENGTH = 11


def _js_string_literal(value):
    r"""Return *value* encoded as a JavaScript string literal safe inside <script>.

    ``json.dumps`` yields a valid JS string literal (quotes, backslashes,
    control characters and all non-ASCII are escaped). ``<``, ``>`` and ``&``
    are additionally rewritten as ``\u003c``-style escapes so user text such
    as ``</script>`` or ``<!--`` cannot terminate or alter the script element.
    """
    literal = json.dumps(value)
    return (
        literal.replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )


def generate_html(training_text, iterations, gen_length, layer1, layer2):
    """Write an in-browser LSTM training page and return an iframe embedding it.

    The generated page trains on *training_text* in the browser, and every
    10 seconds shows the step count plus a freshly generated sample.

    Args:
        training_text: Text to train on (``None`` is treated as empty).
        iterations: Total number of training steps to run.
        gen_length: Number of characters to sample at each refresh.
        layer1: Hidden size of the first LSTM layer.
        layer2: Hidden size of the second LSTM layer.

    Returns:
        An ``<iframe>`` HTML snippet pointing at the written page, or a short
        HTML message when the text is too short to train on.

    Raises:
        ValueError, TypeError: if a numeric argument cannot be converted to int.
    """
    text = training_text or ""
    if len(text) < _MIN_TEXT_LENGTH:
        return (
            f"<p>Please provide at least {_MIN_TEXT_LENGTH} characters of "
            "training text.</p>"
        )

    # Coerce to int so slider floats (or anything odd) never get pasted raw
    # into the script; int() raises on non-numeric input.
    total_steps = int(iterations)
    sample_length = int(gen_length)
    hidden1 = int(layer1)
    hidden2 = int(layer2)
    text_js = _js_string_literal(text)

    # NOTE(review): the RNN / RNNTrainer constructors and train()/sample()
    # methods are assumed to be provided by the loaded script. The published
    # recurrent.js API (R.initLSTM, R.Solver, ...) looks different, so TODO
    # confirm both the script URL and these calls.
    # Braces meant for JavaScript are doubled ({{ }}) because this is a Python
    # f-string; that includes the ${{...}} template-literal placeholders.
    html_content = f"""
    <html>
    <head>
      <script src="https://karpathy.github.io/recurrentjs/recurrent.js"></script>
    </head>
    <body>
      <h3>Live LSTM Training Output (updates every 10s)</h3>
      <pre id="output">Starting training...</pre>
      <script>
        // Array.from splits by code point, keeping surrogate pairs intact.
        const text = Array.from({text_js});
        const totalIterations = {total_steps};
        const genLength = {sample_length};

        const lstm = new RNN("lstm", {{ hiddenSizes: [{hidden1}, {hidden2}] }});
        const trainer = new RNNTrainer(lstm, {{ learningRate: 0.01, momentum: 0.1, batchSize: 5 }});

        let iteration = 0;
        const interval = setInterval(() => {{
          for (let i = 0; i < 5 && iteration < totalIterations; i++, iteration++) {{
            const idx = Math.floor(Math.random() * (text.length - 10));
            const input = text.slice(idx, idx + 5);
            const output = text.slice(idx + 1, idx + 6);
            trainer.train(input, output);
          }}

          // Seed with the first training character rather than a fixed "H",
          // which may not occur in the text at all.
          const sample = lstm.sample([text[0]], genLength).join("");
          document.getElementById("output").innerText = `Epochs completed: ${{iteration}} / ${{totalIterations}}\\n\\n` + sample;

          if (iteration >= totalIterations) {{
            clearInterval(interval);
            document.getElementById("output").innerText += "\\n\\nTraining complete!";
          }}
        }}, 10000);
      </script>
    </body>
    </html>
    """

    # Save the page under /tmp with a unique name per call. Concurrent users
    # no longer overwrite each other's page, and the iframe src changes on
    # every click so the browser loads the new page instead of reusing the
    # previous one. Files are not cleaned up (one small page per click).
    # NOTE(review): newer Gradio versions only serve paths permitted via
    # launch(allowed_paths=...); confirm /tmp is reachable through `file=`.
    with tempfile.NamedTemporaryFile(
        "w",
        encoding="utf-8",
        suffix=".html",
        prefix="train_",
        dir="/tmp",
        delete=False,
    ) as page_file:
        page_file.write(html_content)
        html_path = page_file.name

    # Return iframe HTML to embed the file
    iframe_code = f'<iframe src="file={html_path}" width="100%" height="400px"></iframe>'
    return iframe_code

# Assemble the Gradio interface: a text box plus four sliders feeding
# generate_html, a start button, and an HTML area that shows the returned
# iframe snippet.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Live recurrent.js LSTM Trainer (in-browser training)")

    # Controls are listed in the same order as generate_html's parameters.
    training_text = gr.Textbox(
        label="Training Text",
        lines=6,
        placeholder="Paste text here",
    )
    iterations = gr.Slider(
        minimum=10, maximum=200, value=50, step=10, label="Training Epochs"
    )
    gen_length = gr.Slider(
        minimum=20, maximum=500, value=100, step=10, label="Characters to Generate"
    )
    layer1 = gr.Slider(
        minimum=16, maximum=256, value=128, step=16, label="Neurons in Layer 1"
    )
    layer2 = gr.Slider(
        minimum=16, maximum=256, value=128, step=16, label="Neurons in Layer 2"
    )

    run_button = gr.Button("Start Training")
    html_output = gr.HTML()

    # Clicking the button renders generate_html's return value into html_output.
    run_button.click(
        generate_html,
        inputs=[training_text, iterations, gen_length, layer1, layer2],
        outputs=html_output,
    )

demo.launch()