# Source: Hugging Face Space file view (Space status when captured: Sleeping)
# File size: 2,780 Bytes
# Revisions: ac2c6e9 22b8213 e865ed1 22b8213
import os
import shutil
import subprocess
import sys

import gradio as gr
from transformers import pipeline

# Paths shared by the upload, training, log and download handlers.
uploaded_file_path = "dataset.jsonl"  # where the uploaded dataset is copied for train.py
log_path = "train.log"  # combined stdout/stderr of the training subprocess
model_dir = "trained_model"  # directory the trained model is loaded from
zip_file = "trained_model.zip"  # archive of model_dir offered for download

# Load the generator model once at import time so the first request does not
# pay the load cost. If loading fails for any reason (e.g. no model has been
# trained yet), fall back to None; generate_response() reports that to the user.
try:
    generator = pipeline("text-generation", model=model_dir, tokenizer="distilgpt2")
except Exception:  # not a bare except: don't swallow KeyboardInterrupt/SystemExit
    generator = None


def upload_file(file):
    """Copy the file uploaded through the ``gr.File`` widget to ``uploaded_file_path``.

    Args:
        file: The widget's value. Depending on the Gradio version this is a
            path string or a tempfile-like object exposing ``.name``; ``None``
            when nothing was uploaded.

    Returns:
        A status message for the UI.
    """
    if file is None:
        return "No file uploaded."
    # Accept both shapes of gr.File value so the handler works across Gradio versions.
    source_path = file if isinstance(file, str) else file.name
    shutil.copy(source_path, uploaded_file_path)
    return "✅ File uploaded successfully."


def start_training():
    """Run ``train.py`` on the uploaded dataset and archive the trained model.

    The trainer's stdout and stderr are written to ``log_path``. The call blocks
    until the trainer exits. On success ``model_dir`` is zipped to ``zip_file``
    and the module-level ``generator`` is reloaded from it, so the Test tab
    uses the freshly trained weights.

    Returns:
        ``(status_message, zip_file)`` on success, ``(status_message, None)`` on failure.
    """
    global generator

    with open(log_path, "w") as log_file:
        # sys.executable runs the trainer under the same interpreter/venv as
        # the app; a bare "python" may resolve to a different one on PATH.
        process = subprocess.run(
            [sys.executable, "train.py", "--dataset", uploaded_file_path],
            stdout=log_file,
            stderr=subprocess.STDOUT,
            check=False,
        )

    # A model_dir left over from an earlier run must not mask a failed run,
    # so require a clean exit status as well as the output directory.
    if process.returncode != 0 or not os.path.isdir(model_dir):
        return "❌ Training failed.", None

    # Derive the archive base name from zip_file so the two cannot drift apart.
    shutil.make_archive(os.path.splitext(zip_file)[0], "zip", model_dir)

    # The generator loaded at startup is either None or the previous model;
    # reload so generate_response() serves the new weights.
    try:
        generator = pipeline("text-generation", model=model_dir, tokenizer="distilgpt2")
    except Exception:
        generator = None

    return "✅ Training complete!", zip_file


def read_logs(path=None):
    """Return the current contents of the training log.

    Args:
        path: Log file to read; defaults to the module-level ``log_path``.

    Returns:
        The log text, or ``"Waiting for logs..."`` if the file does not exist yet.
    """
    # EAFP instead of exists()+open(): the log can appear/disappear in between.
    # errors="replace": the log may be read while the trainer is still writing,
    # so a multibyte character can be cut off at EOF. Assumes the trainer
    # writes UTF-8 — confirm if training runs on a non-UTF-8 locale.
    try:
        with open(
            log_path if path is None else path, "r", encoding="utf-8", errors="replace"
        ) as f:
            return f.read()
    except FileNotFoundError:
        return "Waiting for logs..."


def generate_response(prompt):
    """Generate text from ``prompt`` with the module-level ``generator`` pipeline.

    Args:
        prompt: Text from the prompt textbox; may be empty.

    Returns:
        The pipeline's ``generated_text``, or a ``❌``-prefixed message when the
        prompt is empty, no model is loaded, or generation raises.
    """
    # An empty prompt gives the model nothing to condition on; tell the user
    # instead of surfacing a low-level pipeline error.
    if not (prompt or "").strip():
        return "❌ Please enter a prompt."
    if generator is None:
        return "❌ Model not loaded. Please train or upload a valid model."
    try:
        outputs = generator(
            prompt,
            max_length=256,
            do_sample=True,
            temperature=0.7,
            truncation=True,  # truncate over-long prompts explicitly (silences the truncation warning)
        )
        return outputs[0]["generated_text"]
    except Exception as e:  # UI boundary: report model errors as text rather than failing the handler
        return f"❌ Error: {e}"


# === UI ===
# NOTE(review): the emoji in the labels below were restored from mojibake
# (UTF-8 bytes mis-decoded as ISO-8859-7); the exact icons are best guesses —
# confirm against the original app.
with gr.Blocks() as app:
    with gr.Tab("🧠 Train AI"):
        gr.Markdown("## 📥 Upload your dataset and 🎯 train a Godot AI")
        file_input = gr.File(label="Upload JSONL Dataset")
        upload_btn = gr.Button("Upload")
        status_box = gr.Textbox(label="Upload Status")
        start_btn = gr.Button("🚀 Start Training")
        log_output = gr.Textbox(label="📜 Training Logs", lines=15)
        # Hidden until a training run returns an archive path (revealed below).
        download_btn = gr.File(label="📥 Download Trained Model", visible=False)

        upload_btn.click(fn=upload_file, inputs=file_input, outputs=status_box)

        # start_training blocks until train.py exits, so chain the log read after
        # it: a second listener on the same click would fire immediately and only
        # see the log as it was at that moment. Setting a value on a hidden
        # component does not show it, so reveal the download widget explicitly
        # when an archive path came back.
        (
            start_btn.click(fn=start_training, outputs=[status_box, download_btn])
            .then(fn=read_logs, outputs=log_output)
            .then(
                fn=lambda archive: gr.update(visible=archive is not None),
                inputs=download_btn,
                outputs=download_btn,
            )
        )

    with gr.Tab("🔍 Test AI"):
        gr.Markdown("## 💡 Try your trained Godot AI below")
        prompt_input = gr.Textbox(label="Enter Prompt")
        test_btn = gr.Button("🚀 Test AI")
        response_output = gr.Textbox(label="AI Response", lines=10)
        test_btn.click(fn=generate_response, inputs=prompt_input, outputs=response_output)

app.launch()