WillemVH committed on
Commit
edf04a2
·
verified ·
1 Parent(s): 15eb2d4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +591 -0
app.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import os
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import pickle
10
+ from pathlib import Path
11
+ from datetime import datetime
12
+ import threading
13
+ import glob
14
+ from collections import Counter
15
+ import struct
16
+
17
class SimpleTokenizer:
    """Whitespace-based, word-level tokenizer with a growable vocabulary.

    Three special tokens are registered on construction, in this order, so
    that the IDs handed out by ``add_token`` match the ``*_token_id``
    attributes: ``<pad>`` (0), ``<eos>`` (1) and ``<unk>`` (2).
    """

    def __init__(self):
        self.vocab = {}          # token string -> integer id
        self.inverse_vocab = {}  # integer id -> token string
        self.vocab_size = 0
        self.pad_token = "<pad>"
        self.pad_token_id = 0
        self.eos_token = "<eos>"
        self.eos_token_id = 1
        self.unk_token = "<unk>"
        self.unk_token_id = 2

        # Registration order fixes the IDs: pad=0, eos=1, unk=2.
        for special in (self.pad_token, self.eos_token, self.unk_token):
            self.add_token(special)

    def add_token(self, token):
        """Register ``token`` under the next free id.

        Returns:
            True if the token was new and has been added, False if it was
            already part of the vocabulary.
        """
        if token in self.vocab:
            return False
        self.vocab[token] = self.vocab_size
        self.inverse_vocab[self.vocab_size] = token
        self.vocab_size += 1
        return True

    def build_vocab_from_texts(self, texts, max_vocab_size=10000):
        """Grow the vocabulary with the most frequent words found in ``texts``.

        Words are obtained with ``str.split()``. Tokens are added in order of
        decreasing frequency until the vocabulary holds ``max_vocab_size``
        entries (special tokens included). Existing ids are never changed.

        Args:
            texts: iterable of strings to count words in.
            max_vocab_size: upper bound on the total vocabulary size.
        """
        print("Building vocabulary from training data...")

        token_counter = Counter()
        for text in texts:
            token_counter.update(text.split())

        # Walk the whole frequency ranking and stop once the vocabulary is
        # full. Asking only for the top (max - current) entries would waste
        # the remaining budget on tokens that are already known whenever this
        # method is called again on a grown corpus.
        for token, _ in token_counter.most_common():
            if self.vocab_size >= max_vocab_size:
                break
            self.add_token(token)

        print(f"Vocabulary built with {self.vocab_size} tokens")

    def tokenize(self, text):
        """Split ``text`` on whitespace and map every word to its id.

        Words missing from the vocabulary are mapped to ``unk_token_id``.
        """
        return [self.vocab.get(word, self.unk_token_id) for word in text.split()]

    def encode(self, text, max_length=None, padding=False, truncation=False):
        """Tokenize ``text`` and optionally truncate and/or right-pad it.

        Args:
            text: string to encode.
            max_length: target length; truncation and padding are only
                applied when this is a non-zero value.
            padding: right-pad with ``pad_token_id`` up to ``max_length``.
            truncation: cut the sequence down to ``max_length`` ids.

        Returns:
            A list of integer token ids.
        """
        token_ids = self.tokenize(text)

        if truncation and max_length and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]

        if padding and max_length and len(token_ids) < max_length:
            token_ids = token_ids + [self.pad_token_id] * (max_length - len(token_ids))

        return token_ids

    def decode(self, token_ids):
        """Turn ids back into a space-joined string.

        Padding ids are dropped; ids without a vocabulary entry are rendered
        as the ``<unk>`` token.
        """
        words = [
            self.inverse_vocab.get(token_id, self.unk_token)
            for token_id in token_ids
            if token_id != self.pad_token_id
        ]
        return " ".join(words)
85
+
86
class TextDataset(Dataset):
    """Dataset that serves fixed-length token-id tensors for a list of texts.

    Whitespace-only texts are dropped at construction time. Every item is a
    dict with ``input_ids`` and ``labels`` holding the same ids.
    """

    def __init__(self, texts, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Keep only entries that contain something besides whitespace.
        self.texts = [entry for entry in texts if entry.strip()]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        sample = self.texts[idx]

        # Fall back to a single space if the stored text is blank.
        if not sample.strip():
            sample = " "

        encoded = self.tokenizer.encode(
            sample,
            max_length=self.max_length,
            padding=True,
            truncation=True,
        )

        # Cap every id at the largest id the tokenizer currently knows.
        highest_valid = self.tokenizer.vocab_size - 1
        capped = [highest_valid if tid > highest_valid else tid for tid in encoded]

        input_ids = torch.tensor(capped, dtype=torch.long)
        labels = torch.tensor(capped, dtype=torch.long)
        return {'input_ids': input_ids, 'labels': labels}
119
+
120
class SimpleGPT(nn.Module):
    """A small GPT-style (decoder-only) language model.

    Token and learned position embeddings are summed, passed through a stack
    of ``nn.TransformerEncoderLayer`` blocks under a causal attention mask,
    and projected to vocabulary logits. Token id 0 is treated as padding.

    Args:
        vocab_size: number of entries in the vocabulary.
        d_model: embedding / hidden size (must be divisible by ``n_heads``).
        n_layers: number of transformer layers.
        n_heads: attention heads per layer.
        max_seq_len: longest sequence the position embedding can address.
    """

    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8, max_seq_len=512):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        # Token and position embeddings (id 0 is the pad token).
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            batch_first=True,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Output layer with dropout for regularization
        self.dropout = nn.Dropout(0.1)
        self.output_layer = nn.Linear(d_model, vocab_size)

        # Initialize weights properly
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # normal_ overwrites the zero row nn.Embedding set up for
            # padding_idx; restore it so padding keeps a zero embedding.
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()

    def forward(self, input_ids, labels=None):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: LongTensor of shape (batch, seq_len); 0 marks padding.
            labels: optional LongTensor with the same shape as ``input_ids``.
                Position ``t`` of the logits is scored against
                ``labels[:, t + 1]``; label id 0 (padding) is ignored.

        Returns:
            dict with ``'logits'`` of shape (batch, seq_len, vocab_size) and
            ``'loss'`` (scalar tensor, or None when ``labels`` is None).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_len``.
        """
        batch_size, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )

        # Ensure all token IDs are within valid range
        input_ids = torch.clamp(input_ids, 0, self.vocab_size - 1)

        positions = torch.arange(seq_len, device=input_ids.device)
        positions = positions.unsqueeze(0).expand(batch_size, seq_len)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)

        # Causal mask: True above the diagonal means "may not attend", so a
        # position only sees itself and earlier positions. Without it the
        # model could read the very token it is asked to predict.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )
        # True where the key is a padding token and must be ignored.
        padding_mask = input_ids == 0

        x = self.transformer(x, mask=causal_mask, src_key_padding_mask=padding_mask)
        x = self.dropout(x)
        logits = self.output_layer(x)

        loss = None
        if labels is not None:
            # Ensure labels are within valid range
            labels = torch.clamp(labels, 0, self.vocab_size - 1)

            # Next-token objective: logits at position t predict label t + 1.
            shift_logits = logits[:, :-1, :].reshape(-1, self.vocab_size)
            shift_labels = labels[:, 1:].reshape(-1)

            loss_fn = nn.CrossEntropyLoss(ignore_index=0, reduction='sum')
            # Average over real (non-padding) targets; the clamp avoids a
            # 0/0 NaN when a batch contains no target at all.
            n_targets = (shift_labels != 0).sum().clamp(min=1)
            loss = loss_fn(shift_logits, shift_labels) / n_targets

        return {'logits': logits, 'loss': loss}
199
+
200
class AITrainerApp:
    """Application state and callbacks behind the Gradio training UI.

    Holds the tokenizer, the model, the loaded text chunks and the running
    status/log strings. Training runs on a background thread that reports
    progress through ``training_status`` and ``output_log``.
    """

    def __init__(self):
        # Use simple tokenizer for faster startup
        self.tokenizer = SimpleTokenizer()
        self.model = None
        self.training_data = []

        # Default model configuration
        self.model_config = {
            "d_model": 512,
            "n_layers": 6,
            "n_heads": 8,
            "max_seq_len": 512
        }

        # Training control
        self.training_thread = None
        self.stop_training_flag = False
        self.training_status = "Ready - Load training data to begin"
        self.output_log = "Training output will appear here...\n"

    def get_device(self, device_type="auto"):
        """Return the torch device for ``device_type``.

        ``"auto"`` and ``"cuda"`` both resolve to CUDA when it is available
        and fall back to the CPU otherwise; any other value yields the CPU.
        """
        wants_gpu = device_type in ("auto", "cuda")
        if wants_gpu and torch.cuda.is_available():
            return torch.device('cuda')
        return torch.device('cpu')

    def log_output(self, message):
        """Append ``message`` plus a newline to the log and return the log."""
        self.output_log += message + "\n"
        return self.output_log

    def verify_model_file(self, file_path):
        """Run cheap sanity checks on a model file before loading it.

        Returns:
            ``(ok, reason)`` where ``ok`` is False if the file is missing,
            smaller than 1 KB, or could not be inspected.
        """
        try:
            if not os.path.exists(file_path):
                return False, "File does not exist"

            if os.path.getsize(file_path) < 1024:  # Less than 1KB
                return False, "File is too small to be a valid model"

            return True, "File appears valid"
        except (OSError, TypeError, ValueError) as e:
            return False, f"Error verifying file: {str(e)}"

    def load_training_files(self, files):
        """Read the uploaded text files, chunk them and rebuild the vocabulary.

        Args:
            files: list of uploaded files; each entry is either an object
                exposing the path as ``.name`` or the path itself.

        Returns:
            ``(status_message, output_log)``. If any file cannot be read the
            error is returned immediately and no data is added.
        """
        if not files:
            return "No files selected", self.output_log

        # Chunks longer than the model's context are truncated by the
        # dataset, so chunk at exactly that length to keep every word.
        chunk_size = self.model_config["max_seq_len"]

        total_texts = []
        for file_info in files:
            # Resolve the path before the try block so the error message
            # below can always refer to it. Plain path strings have no
            # ``.name`` attribute and are used as they are.
            file_path = getattr(file_info, "name", file_info)
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                chunks = self.split_into_chunks(content, chunk_size)
                total_texts.extend(chunks)
                self.log_output(f"Loaded {len(chunks)} chunks from {os.path.basename(file_path)}")
            except Exception as e:  # UI boundary: report the failure, don't crash
                error_msg = f"Error reading {file_path}: {str(e)}"
                self.log_output(error_msg)
                return error_msg, self.output_log

        self.training_data.extend(total_texts)

        # Build vocabulary from all training texts
        self.tokenizer.build_vocab_from_texts(self.training_data, max_vocab_size=10000)

        status_msg = f"Loaded {len(total_texts)} text chunks from {len(files)} files"
        self.log_output(status_msg)
        self.log_output(f"Vocabulary size: {self.tokenizer.vocab_size}")

        return status_msg, self.output_log

    def split_into_chunks(self, text, chunk_size):
        """Split ``text`` into space-joined pieces of at most ``chunk_size`` words."""
        words = text.split()
        return [
            ' '.join(words[start:start + chunk_size])
            for start in range(0, len(words), chunk_size)
        ]

    def view_training_data(self):
        """Return a printable preview of the first 50 loaded chunks."""
        if not self.training_data:
            return "No training data loaded"

        return "".join(
            f"Chunk {i+1}:\n{text}\n\n{'='*50}\n\n"
            for i, text in enumerate(self.training_data[:50])  # first 50 chunks
        )

    def start_training(self, d_model, n_layers, n_heads, batch_size, learning_rate, epochs, device_type):
        """Validate the settings and launch training on a background thread.

        Returns:
            ``(status, output_log, start_button_update)``. The start button
            is only disabled when a training run is actually in progress.
        """
        if not self.training_data:
            self.log_output("Error: No training data loaded!")
            # Keep the button usable so the user can retry after loading data.
            return "Error: No training data loaded!", self.output_log, gr.update(interactive=True)

        # Two threads would both write self.model and the status fields.
        if self.training_thread is not None and self.training_thread.is_alive():
            busy_msg = "Error: Training is already running!"
            self.log_output(busy_msg)
            return busy_msg, self.output_log, gr.update(interactive=False)

        d_model, n_layers, n_heads = int(d_model), int(n_layers), int(n_heads)
        # Multi-head attention splits d_model evenly across the heads.
        if n_heads <= 0 or d_model <= 0 or d_model % n_heads != 0:
            config_msg = "Error: Embedding Size must be a positive multiple of Number of Heads!"
            self.log_output(config_msg)
            return config_msg, self.output_log, gr.update(interactive=True)

        self.stop_training_flag = False
        self.training_status = "Training started..."
        self.log_output("Starting training...")

        # Update model config from UI
        self.model_config.update({
            "d_model": d_model,
            "n_layers": n_layers,
            "n_heads": n_heads
        })

        # Start training in separate thread
        self.training_thread = threading.Thread(
            target=self.train_model,
            args=(int(batch_size), float(learning_rate), int(epochs), device_type)
        )
        self.training_thread.daemon = True
        self.training_thread.start()

        return "Training started...", self.output_log, gr.update(interactive=False)

    def stop_training(self):
        """Ask the training thread to stop; it exits at the next batch boundary."""
        self.stop_training_flag = True
        self.training_status = "Stopping training..."
        self.log_output("Stopping training...")
        return "Stopping training...", self.output_log, gr.update(interactive=True)

    def train_model(self, batch_size, learning_rate, epochs, device_type):
        """Training loop executed on the background thread.

        Builds a fresh model from ``model_config``, trains it on
        ``training_data`` and reports progress through ``training_status``
        and ``output_log``. Errors are logged rather than raised because
        nobody can catch them on this thread.
        """
        try:
            # Create dataset and dataloader; sequences must fit the model's
            # position embedding, so both use the same maximum length.
            dataset = TextDataset(
                self.training_data,
                self.tokenizer,
                max_length=self.model_config["max_seq_len"]
            )
            dataloader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True
            )

            # Initialize model
            self.model = SimpleGPT(
                vocab_size=self.tokenizer.vocab_size,
                d_model=self.model_config["d_model"],
                n_layers=self.model_config["n_layers"],
                n_heads=self.model_config["n_heads"],
                max_seq_len=self.model_config["max_seq_len"]
            )

            # Setup optimizer
            optimizer = optim.AdamW(
                self.model.parameters(),
                lr=learning_rate
            )

            device = self.get_device(device_type)
            self.model.to(device)
            self.log_output(f"Using device: {device}")

            # Ids are checked against the size the model was built with.
            vocab_size = self.model.vocab_size
            num_batches = len(dataloader)

            for epoch in range(epochs):
                if self.stop_training_flag:
                    break

                self.model.train()
                total_loss = 0
                total_batches = 0

                for batch_idx, batch in enumerate(dataloader):
                    if self.stop_training_flag:
                        break

                    optimizer.zero_grad()

                    input_ids = batch['input_ids'].to(device)
                    labels = batch['labels'].to(device)

                    # Debug: Check for invalid token IDs
                    max_id = input_ids.max().item()
                    if max_id >= vocab_size:
                        self.log_output(f"Warning: Found token ID {max_id} but vocab size is {vocab_size}")
                        # Clamp values to valid range
                        input_ids = torch.clamp(input_ids, 0, vocab_size - 1)
                        labels = torch.clamp(labels, 0, vocab_size - 1)

                    outputs = self.model(input_ids=input_ids, labels=labels)
                    loss = outputs['loss']

                    if not torch.isfinite(loss):
                        self.log_output("Warning: NaN or Inf loss detected, skipping batch")
                        continue

                    loss.backward()

                    # Gradient clipping to prevent explosions
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                    optimizer.step()

                    loss_value = loss.item()
                    total_loss += loss_value
                    total_batches += 1

                    if batch_idx % 10 == 0:
                        status_msg = f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{num_batches}, Loss: {loss_value:.4f}"
                        self.training_status = status_msg
                        if batch_idx % 50 == 0:  # Log less frequently to avoid UI slowdown
                            self.log_output(status_msg)

                if total_batches > 0:
                    avg_loss = total_loss / total_batches
                    epoch_msg = f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}"
                    self.training_status = epoch_msg
                    self.log_output(epoch_msg)

            if self.stop_training_flag:
                stopped_msg = "Training stopped."
                self.training_status = stopped_msg
                self.log_output(stopped_msg)
            else:
                completion_msg = "Training completed successfully!"
                self.training_status = completion_msg
                self.log_output(completion_msg)

        except Exception as e:  # thread boundary: surface the error in the UI log
            error_msg = f"Training error: {str(e)}"
            self.training_status = error_msg
            self.log_output(error_msg)
            import traceback
            self.log_output(traceback.format_exc())

        finally:
            # No `return` here: returning from `finally` would swallow any
            # in-flight exception, and a thread target's return value is
            # discarded anyway.
            self.stop_training_flag = False

    def save_model(self, file_path):
        """Save weights, tokenizer and config to ``file_path`` with ``torch.save``.

        Returns:
            ``(status_message, output_log)``.
        """
        if self.model is None:
            self.log_output("Error: No model to save!")
            return "Error: No model to save!", self.output_log

        try:
            # NOTE: the tokenizer is stored as a pickled Python object, so
            # the checkpoint can only be read back with full unpickling.
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'tokenizer': self.tokenizer,
                'config': self.model_config,
                'training_data_info': {
                    'num_chunks': len(self.training_data),
                    'vocab_size': self.tokenizer.vocab_size
                }
            }, file_path)

            success_msg = f"Model saved to {file_path}"
            self.training_status = success_msg
            self.log_output(success_msg)
            return success_msg, self.output_log

        except Exception as e:  # UI boundary: report instead of crashing
            error_msg = f"Error saving model: {str(e)}"
            self.log_output(error_msg)
            return error_msg, self.output_log

    def load_model(self, file_path):
        """Restore model, tokenizer and config from a saved checkpoint.

        Returns:
            A 5-tuple ``(status, output_log, d_model, n_layers, n_heads)``
            on every path; the last three are "no change" updates when
            loading did not succeed.
        """
        if not file_path:
            return "No file selected", self.output_log, gr.update(), gr.update(), gr.update()

        file_ok, reason = self.verify_model_file(file_path)
        if not file_ok:
            error_msg = f"Error loading model: {reason}"
            self.log_output(error_msg)
            return error_msg, self.output_log, gr.update(), gr.update(), gr.update()

        try:
            # SECURITY: the checkpoint embeds a pickled tokenizer object, so
            # it needs full unpickling (weights_only=False), which can
            # execute arbitrary code. Only load files you trust.
            checkpoint = torch.load(file_path, map_location='cpu', weights_only=False)

            # Recreate the model architecture
            config = checkpoint['config']
            tokenizer = checkpoint['tokenizer']
            model = SimpleGPT(
                vocab_size=tokenizer.vocab_size,
                d_model=config["d_model"],
                n_layers=config["n_layers"],
                n_heads=config["n_heads"],
                max_seq_len=config["max_seq_len"]
            )

            # Load weights
            model.load_state_dict(checkpoint['model_state_dict'])

            # Commit only after everything loaded, so a broken checkpoint
            # cannot leave the app with a half-replaced state.
            self.model_config = config
            self.model = model
            self.tokenizer = tokenizer

            success_msg = f"Model loaded from {file_path}"
            self.training_status = success_msg
            self.log_output(success_msg)
            return success_msg, self.output_log, str(self.model_config['d_model']), str(self.model_config['n_layers']), str(self.model_config['n_heads'])

        except Exception as e:  # UI boundary: report instead of crashing
            error_msg = f"Error loading model: {str(e)}"
            self.log_output(error_msg)
            return error_msg, self.output_log, gr.update(), gr.update(), gr.update()
492
+
493
# Create the app instance
app = AITrainerApp()

# Create Gradio interface.
# NOTE: training runs on a background thread that only updates
# app.training_status / app.output_log; the Status and Output Log boxes are
# refreshed only when one of the button callbacks below returns.
with gr.Blocks(title="AI Text Generation Trainer") as demo:
    gr.Markdown("# AI Text Generation Trainer")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Controls")

            # Data management
            gr.Markdown("### Data Management")
            file_input = gr.File(file_count="multiple", label="Training Files")
            load_files_btn = gr.Button("Load Text Files")
            view_data_btn = gr.Button("View Training Data")
            data_preview = gr.Textbox(label="Training Data Preview", lines=10, interactive=False)

            # Device selection
            gr.Markdown("### Device Selection")
            device_type = gr.Radio(
                choices=["auto", "cpu", "cuda"],
                value="auto",
                label="Processing Device"
            )
            device_info = gr.Textbox(
                label="Device Info",
                value=f"GPU available: {'Yes' if torch.cuda.is_available() else 'No'}",
                interactive=False
            )

            # Model configuration
            gr.Markdown("### Model Configuration")
            d_model = gr.Number(value=512, label="Embedding Size")
            n_layers = gr.Number(value=6, label="Number of Layers")
            n_heads = gr.Number(value=8, label="Number of Heads")

            # Training parameters
            gr.Markdown("### Training Parameters")
            batch_size = gr.Number(value=4, label="Batch Size")
            learning_rate = gr.Number(value=0.001, label="Learning Rate")
            epochs = gr.Number(value=3, label="Epochs")

            # Training controls
            gr.Markdown("### Training Control")
            start_btn = gr.Button("Start Training", variant="primary")
            stop_btn = gr.Button("Stop Training")

            # Export buttons
            gr.Markdown("### Export Model")
            save_path = gr.Textbox(label="Save Path", value="model.pth")
            save_btn = gr.Button("Save Model")
            load_path = gr.Textbox(label="Load Path", value="model.pth")
            # Distinct name: reusing the data-loading button's variable here
            # would rebind it and leave "Load Text Files" without a handler.
            load_model_btn = gr.Button("Load Model")

        with gr.Column(scale=2):
            gr.Markdown("## Status & Output")
            status = gr.Textbox(label="Status", value=app.training_status, interactive=False)
            output = gr.Textbox(label="Output Log", value=app.output_log, lines=20, interactive=False)

    # Define event handlers
    load_files_btn.click(
        app.load_training_files,
        inputs=[file_input],
        outputs=[status, output]
    )

    view_data_btn.click(
        app.view_training_data,
        inputs=[],
        outputs=[data_preview]
    )

    start_btn.click(
        app.start_training,
        inputs=[d_model, n_layers, n_heads, batch_size, learning_rate, epochs, device_type],
        outputs=[status, output, start_btn]
    )

    stop_btn.click(
        app.stop_training,
        inputs=[],
        outputs=[status, output, start_btn]
    )

    save_btn.click(
        app.save_model,
        inputs=[save_path],
        outputs=[status, output]
    )

    load_model_btn.click(
        app.load_model,
        inputs=[load_path],
        outputs=[status, output, d_model, n_layers, n_heads]
    )

if __name__ == "__main__":
    demo.launch()