File size: 2,780 Bytes
ac2c6e9
22b8213
e865ed1
22b8213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import shutil
import subprocess
import gradio as gr
from transformers import pipeline

uploaded_file_path = "dataset.jsonl"  # where upload_file copies the user's dataset
log_path = "train.log"  # combined stdout/stderr of the train.py child process
model_dir = "trained_model"  # model directory, presumably written by train.py
zip_file = "trained_model.zip"  # archive offered for download after training

# Try loading a previously trained model once at startup so the first test
# request is fast. Only attempt it when the local directory actually exists:
# otherwise `pipeline` would resolve "trained_model" as a Hugging Face Hub
# model id and make a pointless network request.
generator = None
if os.path.isdir(model_dir):
    try:
        generator = pipeline("text-generation", model=model_dir, tokenizer="distilgpt2")
    except Exception:
        # Incomplete/corrupt model directory: leave the generator unset so the
        # UI reports "Model not loaded" instead of the app failing to start.
        # (Catching Exception, not a bare except, lets Ctrl-C still work.)
        generator = None

def upload_file(file):
    """Copy a user-uploaded dataset to ``uploaded_file_path`` for training.

    Args:
        file: Value delivered by the ``gr.File`` component. Depending on the
            Gradio version this is an object whose ``.name`` is the temp-file
            path, or the path string itself. ``None`` when nothing was uploaded.

    Returns:
        A status message shown in the UI.
    """
    if file is None:
        return "No file uploaded."
    # Accept both a tempfile-like wrapper (with .name) and a plain path string.
    src_path = getattr(file, "name", file)
    try:
        shutil.copy(src_path, uploaded_file_path)
    except OSError as e:
        # Report the failure in the status box rather than raising into Gradio.
        return f"❌ Upload failed: {e}"
    return "✅ File uploaded successfully."

def start_training():
    """Run ``train.py`` on the uploaded dataset and package the trained model.

    The script runs as a child process using the same Python interpreter as
    this app. Its stdout and stderr are both written to ``log_path``.

    Returns:
        A ``(status_message, archive_path_or_None)`` tuple for the status box
        and the download component.
    """
    import sys  # local import: only needed here to locate the running interpreter

    if not os.path.exists(uploaded_file_path):
        return "❌ No dataset uploaded. Please upload a JSONL file first.", None

    with open(log_path, "w") as log_file:
        # sys.executable rather than a bare "python": the child must use the
        # same interpreter/virtualenv (and installed packages) as the app.
        result = subprocess.run(
            [sys.executable, "train.py", "--dataset", uploaded_file_path],
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )

    # Require a clean exit as well as the output directory: a model directory
    # left over from an earlier run must not make a failed run look successful.
    if result.returncode != 0 or not os.path.exists(model_dir):
        return "❌ Training failed.", None

    # Derive the archive base name from zip_file so the two cannot drift apart
    # (make_archive appends the ".zip" extension itself).
    shutil.make_archive(os.path.splitext(zip_file)[0], "zip", model_dir)
    return "✅ Training complete!", zip_file

def read_logs(path=None):
    """Return the current contents of the training log.

    Args:
        path: Log file to read. Defaults to ``log_path``, the file
            ``start_training`` redirects the training process output into.

    Returns:
        The log text, or ``"Waiting for logs..."`` if the file does not exist yet.
    """
    if path is None:
        path = log_path
    try:
        # The log is raw child-process output. Decode it as UTF-8 explicitly
        # (not the platform default) and replace undecodable bytes instead of
        # raising, so odd output never breaks the log view.
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    except FileNotFoundError:
        return "Waiting for logs..."

def generate_response(prompt):
    """Generate a completion for *prompt* with the trained model.

    If no model was available when the app started, the model is loaded
    lazily on first use. A model trained during this session therefore becomes
    usable without restarting the app.

    Args:
        prompt: Text from the "Enter Prompt" box.

    Returns:
        The pipeline's ``generated_text`` for the first sequence, or an error
        message prefixed with "❌".
    """
    global generator

    if not prompt or not prompt.strip():
        return "❌ Please enter a prompt."

    # Lazy (re)load: the module-level load only runs once, at import time.
    if generator is None and os.path.isdir(model_dir):
        try:
            generator = pipeline("text-generation", model=model_dir, tokenizer="distilgpt2")
        except Exception as e:
            return f"❌ Error loading model: {e}"
    if generator is None:
        return "❌ Model not loaded. Please train or upload a valid model."

    try:
        outputs = generator(
            prompt,
            max_length=256,
            do_sample=True,
            temperature=0.7,
            truncation=True,  # truncate over-long prompts instead of warning
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        # Show generation failures in the response box instead of raising.
        return f"❌ Error: {e}"

# === UI ===
# Two tabs: "Train AI" (upload a dataset, run train.py, view logs, download the
# model archive) and "Test AI" (prompt the trained model).
with gr.Blocks() as app:
    with gr.Tab("🧠 Train AI"):
        gr.Markdown("## 📥 Upload your dataset and 🎯 train a Godot AI")
        file_input = gr.File(label="Upload JSONL Dataset")
        upload_btn = gr.Button("Upload")
        status_box = gr.Textbox(label="Upload Status")

        start_btn = gr.Button("🚀 Start Training")
        log_output = gr.Textbox(label="📜 Training Logs", lines=15)
        # Visible from the start and output-only. start_training returns just
        # a file path, which would never un-hide a component created with
        # visible=False, so the archive could never be downloaded.
        download_btn = gr.File(label="📥 Download Trained Model", interactive=False)

        upload_btn.click(fn=upload_file, inputs=file_input, outputs=status_box)
        # Read the log only after training has finished. As a separate
        # .click() handler, read_logs was triggered independently of training
        # and could run before any output existed, never refreshing afterwards.
        start_btn.click(
            fn=start_training, outputs=[status_box, download_btn]
        ).then(fn=read_logs, outputs=log_output)

    with gr.Tab("🚀 Test AI"):
        gr.Markdown("## 💡 Try your trained Godot AI below")
        prompt_input = gr.Textbox(label="Enter Prompt")
        test_btn = gr.Button("🔍 Test AI")
        response_output = gr.Textbox(label="AI Response", lines=10)
        test_btn.click(fn=generate_response, inputs=prompt_input, outputs=response_output)

app.launch()