Ltlu commited on
Commit
bd8d5b3
·
verified ·
1 Parent(s): 1e28077

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -23
app.py CHANGED
@@ -1,49 +1,60 @@
1
  import gradio as gr
2
 
3
- def make_html(training_text, iterations, gen_length):
4
- # Sanitize text for JS (escape quotes and backslashes)
5
  safe_text = training_text.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n")
6
-
7
  html = f"""
8
  <div>
9
- <h3>Output:</h3>
10
- <pre id="output">Training...</pre>
11
  <script src="https://karpathy.github.io/recurrentjs/recurrent.js"></script>
12
  <script>
13
  const text = "{safe_text}".split("");
14
- const iterations = {iterations};
15
  const genLength = {gen_length};
16
 
17
- const lstm = new RNN("lstm", {{ hiddenSizes: [128] }});
18
  const trainer = new RNNTrainer(lstm, {{ learningRate: 0.01, momentum: 0.1, batchSize: 5 }});
19
 
20
- for (let i = 0; i < iterations; i++) {{
21
- const idx = Math.floor(Math.random() * (text.length - 10));
22
- const input = text.slice(idx, idx + 5);
23
- const output = text.slice(idx + 1, idx + 6);
24
- trainer.train(input, output);
25
- }}
 
 
26
 
27
- const generated = lstm.sample(["H"], genLength).join("");
28
- document.getElementById("output").innerText = generated;
 
 
 
 
 
 
29
  </script>
30
  </div>
31
  """
32
  return html
33
 
34
  with gr.Blocks() as demo:
35
- gr.Markdown("## 🧠 LSTM Trainer with recurrent.js (Browser-Based)")
36
 
 
 
37
  with gr.Row():
38
- training_text = gr.Textbox(label="Training Text", lines=8, placeholder="Paste your training text here...")
39
-
 
40
  with gr.Row():
41
- iterations = gr.Slider(1, 100, value=20, step=1, label="Training Iterations")
42
- gen_length = gr.Slider(10, 500, value=100, step=10, label="Characters to Generate")
43
-
44
- run_button = gr.Button("Train and Generate")
45
  output_html = gr.HTML()
46
 
47
- run_button.click(fn=make_html, inputs=[training_text, iterations, gen_length], outputs=output_html)
48
 
49
  demo.launch()
 
1
  import gradio as gr
2
 
3
+ def make_html(training_text, iterations, gen_length, layer1_size, layer2_size):
4
+ # Sanitize training text for JS
5
  safe_text = training_text.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n")
6
+
7
  html = f"""
8
  <div>
9
+ <h3>Live LSTM Training Output (updates every 10s)</h3>
10
+ <pre id="output">Starting training...</pre>
11
  <script src="https://karpathy.github.io/recurrentjs/recurrent.js"></script>
12
  <script>
13
  const text = "{safe_text}".split("");
14
+ const totalIterations = {iterations};
15
  const genLength = {gen_length};
16
 
17
+ const lstm = new RNN("lstm", {{ hiddenSizes: [{layer1_size}, {layer2_size}] }});
18
  const trainer = new RNNTrainer(lstm, {{ learningRate: 0.01, momentum: 0.1, batchSize: 5 }});
19
 
20
+ let iteration = 0;
21
+ const interval = setInterval(() => {{
22
+ for (let i = 0; i < 5 && iteration < totalIterations; i++, iteration++) {{
23
+ const idx = Math.floor(Math.random() * (text.length - 10));
24
+ const input = text.slice(idx, idx + 5);
25
+ const output = text.slice(idx + 1, idx + 6);
26
+ trainer.train(input, output);
27
+ }}
28
 
29
+ const sample = lstm.sample(["H"], genLength).join("");
30
+ document.getElementById("output").innerText = `Epochs completed: ${iteration} / ${totalIterations}\\n\\n` + sample;
31
+
32
+ if (iteration >= totalIterations) {{
33
+ clearInterval(interval);
34
+ document.getElementById("output").innerText += "\\n\\nTraining complete!";
35
+ }}
36
+ }}, 10000); // every 10 seconds
37
  </script>
38
  </div>
39
  """
40
  return html
41
 
42
  with gr.Blocks() as demo:
43
+ gr.Markdown("## 🧠 Live LSTM Trainer with recurrent.js (Browser-Based)")
44
 
45
+ training_text = gr.Textbox(label="Training Text", lines=6, placeholder="Paste your training data here...")
46
+
47
  with gr.Row():
48
+ iterations = gr.Slider(10, 200, value=50, step=10, label="Total Training Epochs")
49
+ gen_length = gr.Slider(20, 500, value=100, step=10, label="Characters to Generate")
50
+
51
  with gr.Row():
52
+ layer1_size = gr.Slider(16, 256, value=128, step=16, label="Neurons in Layer 1")
53
+ layer2_size = gr.Slider(16, 256, value=128, step=16, label="Neurons in Layer 2")
54
+
55
+ run_button = gr.Button("Start Training")
56
  output_html = gr.HTML()
57
 
58
+ run_button.click(fn=make_html, inputs=[training_text, iterations, gen_length, layer1_size, layer2_size], outputs=output_html)
59
 
60
  demo.launch()