Percy3822 commited on
Commit
22b8213
Β·
verified Β·
1 Parent(s): cc4f041

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -94
app.py CHANGED
@@ -1,96 +1,83 @@
1
  import os
2
- import time
3
- import zipfile
4
- import gradio as gr
5
  import subprocess
6
- import threading
7
-
8
- # Paths
9
- OUTPUT_DIR = "train_output"
10
- ZIP_FILE = "python_ai_trained_model.zip"
11
- LOG_FILE = "train_log.txt"
12
-
13
- # Zip function
14
- def zip_trained_model():
15
- with zipfile.ZipFile(ZIP_FILE, 'w', zipfile.ZIP_DEFLATED) as zipf:
16
- for root, _, files in os.walk(OUTPUT_DIR):
17
- for file in files:
18
- filepath = os.path.join(root, file)
19
- arcname = os.path.relpath(filepath, OUTPUT_DIR)
20
- zipf.write(filepath, arcname)
21
-
22
- # Tail logs function
23
- def tail_logs(n=20):
24
- if not os.path.exists(LOG_FILE):
25
- return ""
26
- with open(LOG_FILE, 'r') as f:
27
- return ''.join(f.readlines()[-n:])
28
-
29
- # Background training runner
30
- def run_training(status_box, time_box, download_file, model_size_box, log_box):
31
- start_time = time.time()
32
- status_box.value = "πŸš€ Training started..."
33
- time_box.value = "Training in progress..."
34
- log_box.value = ""
35
-
36
- # Create log file
37
- with open(LOG_FILE, "w") as log:
38
- log.write("πŸš€ Launching train.py...\n")
39
-
40
- # Start training
41
- with open(LOG_FILE, "a") as log:
42
- process = subprocess.Popen(["python", "train.py"], stdout=log, stderr=subprocess.STDOUT)
43
- while process.poll() is None:
44
- log_box.update(value=tail_logs())
45
- time.sleep(5)
46
-
47
- # Check exit status
48
- if process.returncode != 0:
49
- status_box.value = f"❌ Training failed with exit code {process.returncode}"
50
- log_box.value = tail_logs()
51
- return
52
-
53
- # Training success
54
- elapsed = round(time.time() - start_time, 2)
55
- time_box.value = f"βœ… Completed in {elapsed // 60:.0f} min {elapsed % 60:.0f} sec"
56
- status_box.value = "πŸ”„ Compressing model..."
57
-
58
- # Zip it
59
- zip_trained_model()
60
- size_mb = round(os.path.getsize(ZIP_FILE) / (1024 * 1024), 2)
61
- model_size_box.value = f"πŸ“¦ Model Size: {size_mb} MB"
62
-
63
- download_file.value = ZIP_FILE
64
- download_file.visible = True
65
- status_box.value = "βœ… Training complete. Download below."
66
-
67
- # Button trigger
68
- def start_training(status_box, time_box, download_file, model_size_box, log_box):
69
- thread = threading.Thread(target=run_training, args=(status_box, time_box, download_file, model_size_box, log_box))
70
- thread.start()
71
- return "Training process started."
72
-
73
- # Gradio UI
74
- with gr.Blocks() as demo:
75
- gr.Markdown("## 🧠 Python AI Trainer (StarCoder 7B)")
76
- gr.Markdown("Train your Python AI with 1 click. Watch logs. Download model when done.")
77
-
78
- with gr.Row():
79
- train_btn = gr.Button("πŸš€ Start Training")
80
- status_box = gr.Textbox(label="Status", value="Ready", interactive=False)
81
-
82
- with gr.Row():
83
- time_box = gr.Textbox(label="Training Time", interactive=False)
84
- model_size_box = gr.Textbox(label="Final Model Size", interactive=False)
85
-
86
- log_box = gr.Textbox(label="Live Training Logs", lines=20, interactive=False, value="")
87
- download_file = gr.File(label="πŸ“₯ Download Trained Model (.zip)", visible=False)
88
-
89
- train_btn.click(
90
- fn=start_training,
91
- inputs=[status_box, time_box, download_file, model_size_box, log_box],
92
- outputs=[status_box]
93
- )
94
-
95
-
96
- demo.launch()
 
1
  import os
2
+ import shutil
 
 
3
  import subprocess
4
+ import gradio as gr
5
+ from transformers import pipeline
6
+
7
+ uploaded_file_path = "dataset.jsonl"
8
+ log_path = "train.log"
9
+ model_dir = "trained_model"
10
+ zip_file = "trained_model.zip"
11
+
12
+ # Try loading the generator model once at the top (for better performance)
13
+ try:
14
+ generator = pipeline("text-generation", model=model_dir, tokenizer="distilgpt2")
15
+ except:
16
+ generator = None
17
+
18
+ def upload_file(file):
19
+ if file is None:
20
+ return "No file uploaded."
21
+ shutil.copy(file.name, uploaded_file_path)
22
+ return "βœ… File uploaded successfully."
23
+
24
+ def start_training():
25
+ with open(log_path, "w") as log_file:
26
+ process = subprocess.Popen(
27
+ ["python", "train.py", "--dataset", uploaded_file_path],
28
+ stdout=log_file,
29
+ stderr=subprocess.STDOUT
30
+ )
31
+ process.wait()
32
+
33
+ if os.path.exists(model_dir):
34
+ shutil.make_archive("trained_model", "zip", model_dir)
35
+ return "βœ… Training complete!", zip_file
36
+ else:
37
+ return "❌ Training failed.", None
38
+
39
+ def read_logs():
40
+ if os.path.exists(log_path):
41
+ with open(log_path, "r") as f:
42
+ return f.read()
43
+ return "Waiting for logs..."
44
+
45
+ def generate_response(prompt):
46
+ try:
47
+ if generator is None:
48
+ return "❌ Model not loaded. Please train or upload a valid model."
49
+ result = generator(
50
+ prompt,
51
+ max_length=256,
52
+ do_sample=True,
53
+ temperature=0.7,
54
+ truncation=True # βœ… Fix warning and enforce consistent length
55
+ )[0]["generated_text"]
56
+ return result
57
+ except Exception as e:
58
+ return f"❌ Error: {e}"
59
+
60
+ # === UI ===
61
+ with gr.Blocks() as app:
62
+ with gr.Tab("🧠 Train AI"):
63
+ gr.Markdown("## πŸ“₯ Upload your dataset and 🎯 train a Godot AI")
64
+ file_input = gr.File(label="Upload JSONL Dataset")
65
+ upload_btn = gr.Button("Upload")
66
+ status_box = gr.Textbox(label="Upload Status")
67
+
68
+ start_btn = gr.Button("πŸš€ Start Training")
69
+ log_output = gr.Textbox(label="πŸ“œ Training Logs", lines=15)
70
+ download_btn = gr.File(label="πŸ“₯ Download Trained Model", visible=False)
71
+
72
+ upload_btn.click(fn=upload_file, inputs=file_input, outputs=status_box)
73
+ start_btn.click(fn=start_training, outputs=[status_box, download_btn])
74
+ start_btn.click(fn=read_logs, outputs=log_output)
75
+
76
+ with gr.Tab("πŸš€ Test AI"):
77
+ gr.Markdown("## πŸ’‘ Try your trained Godot AI below")
78
+ prompt_input = gr.Textbox(label="Enter Prompt")
79
+ test_btn = gr.Button("πŸ” Test AI")
80
+ response_output = gr.Textbox(label="AI Response", lines=10)
81
+ test_btn.click(fn=generate_response, inputs=prompt_input, outputs=response_output)
82
+
83
+ app.launch()