Ltlu commited on
Commit
1e28077
·
verified ·
1 Parent(s): 6c2da0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -49
app.py CHANGED
@@ -1,51 +1,49 @@
1
  import gradio as gr
2
 
3
- html = """
4
- <!DOCTYPE html>
5
- <html>
6
- <head>
7
- <script src="https://karpathy.github.io/recurrentjs/recurrent.js"></script>
8
- </head>
9
- <body>
10
- <h2>LSTM Text Trainer (Client-Side with recurrent.js)</h2>
11
- <textarea id="inputText" rows="6" cols="60" placeholder="Paste training text here..."></textarea><br>
12
- <button onclick="trainModel()">Train Model</button>
13
- <button onclick="generateText()">Generate Text</button>
14
- <pre id="output"></pre>
15
-
16
- <script>
17
- let model;
18
-
19
- function trainModel() {
20
- const text = document.getElementById('inputText').value;
21
- const data = text.split(""); // Character-level
22
-
23
- const lstm = new RNN("lstm", { hiddenSizes: [128] });
24
- const trainer = new RNNTrainer(lstm, { learningRate: 0.01, momentum: 0.1, batchSize: 5 });
25
-
26
- for (let i = 0; i < 20; i++) {
27
- const idx = Math.floor(Math.random() * (data.length - 10));
28
- const input = data.slice(idx, idx + 5);
29
- const output = data.slice(idx + 1, idx + 6);
30
- trainer.train(input, output);
31
- }
32
-
33
- model = lstm;
34
- document.getElementById('output').innerText = "Training complete!";
35
- }
36
-
37
- function generateText() {
38
- if (!model) {
39
- alert("Train the model first.");
40
- return;
41
- }
42
-
43
- let txt = model.sample(["H"], 100).join("");
44
- document.getElementById('output').innerText = txt;
45
- }
46
- </script>
47
- </body>
48
- </html>
49
- """
50
-
51
- gr.Interface(fn=lambda: None, inputs=[], outputs=gr.HTML(html), live=True).launch()
 
1
  import gradio as gr
2
 
3
+ def make_html(training_text, iterations, gen_length):
4
+ # Sanitize text for JS (escape quotes and backslashes)
5
+ safe_text = training_text.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n")
6
+
7
+ html = f"""
8
+ <div>
9
+ <h3>Output:</h3>
10
+ <pre id="output">Training...</pre>
11
+ <script src="https://karpathy.github.io/recurrentjs/recurrent.js"></script>
12
+ <script>
13
+ const text = "{safe_text}".split("");
14
+ const iterations = {iterations};
15
+ const genLength = {gen_length};
16
+
17
+ const lstm = new RNN("lstm", {{ hiddenSizes: [128] }});
18
+ const trainer = new RNNTrainer(lstm, {{ learningRate: 0.01, momentum: 0.1, batchSize: 5 }});
19
+
20
+ for (let i = 0; i < iterations; i++) {{
21
+ const idx = Math.floor(Math.random() * (text.length - 10));
22
+ const input = text.slice(idx, idx + 5);
23
+ const output = text.slice(idx + 1, idx + 6);
24
+ trainer.train(input, output);
25
+ }}
26
+
27
+ const generated = lstm.sample(["H"], genLength).join("");
28
+ document.getElementById("output").innerText = generated;
29
+ </script>
30
+ </div>
31
+ """
32
+ return html
33
+
34
+ with gr.Blocks() as demo:
35
+ gr.Markdown("## 🧠 LSTM Trainer with recurrent.js (Browser-Based)")
36
+
37
+ with gr.Row():
38
+ training_text = gr.Textbox(label="Training Text", lines=8, placeholder="Paste your training text here...")
39
+
40
+ with gr.Row():
41
+ iterations = gr.Slider(1, 100, value=20, step=1, label="Training Iterations")
42
+ gen_length = gr.Slider(10, 500, value=100, step=10, label="Characters to Generate")
43
+
44
+ run_button = gr.Button("Train and Generate")
45
+ output_html = gr.HTML()
46
+
47
+ run_button.click(fn=make_html, inputs=[training_text, iterations, gen_length], outputs=output_html)
48
+
49
+ demo.launch()