Ltlu commited on
Commit
6c2da0c
·
verified ·
1 Parent(s): 6093371

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ html = """
4
+ <!DOCTYPE html>
5
+ <html>
6
+ <head>
7
+ <script src="https://karpathy.github.io/recurrentjs/recurrent.js"></script>
8
+ </head>
9
+ <body>
10
+ <h2>LSTM Text Trainer (Client-Side with recurrent.js)</h2>
11
+ <textarea id="inputText" rows="6" cols="60" placeholder="Paste training text here..."></textarea><br>
12
+ <button onclick="trainModel()">Train Model</button>
13
+ <button onclick="generateText()">Generate Text</button>
14
+ <pre id="output"></pre>
15
+
16
+ <script>
17
+ let model;
18
+
19
+ function trainModel() {
20
+ const text = document.getElementById('inputText').value;
21
+ const data = text.split(""); // Character-level
22
+
23
+ const lstm = new RNN("lstm", { hiddenSizes: [128] });
24
+ const trainer = new RNNTrainer(lstm, { learningRate: 0.01, momentum: 0.1, batchSize: 5 });
25
+
26
+ for (let i = 0; i < 20; i++) {
27
+ const idx = Math.floor(Math.random() * (data.length - 10));
28
+ const input = data.slice(idx, idx + 5);
29
+ const output = data.slice(idx + 1, idx + 6);
30
+ trainer.train(input, output);
31
+ }
32
+
33
+ model = lstm;
34
+ document.getElementById('output').innerText = "Training complete!";
35
+ }
36
+
37
+ function generateText() {
38
+ if (!model) {
39
+ alert("Train the model first.");
40
+ return;
41
+ }
42
+
43
+ let txt = model.sample(["H"], 100).join("");
44
+ document.getElementById('output').innerText = txt;
45
+ }
46
+ </script>
47
+ </body>
48
+ </html>
49
+ """
50
+
51
+ gr.Interface(fn=lambda: None, inputs=[], outputs=gr.HTML(html), live=True).launch()