# Hugging Face Space (status when captured: Sleeping)
import os
import shutil
import subprocess
import sys

import gradio as gr
from transformers import pipeline
uploaded_file_path = "dataset.jsonl"  # where the uploaded dataset is copied for train.py
log_path = "train.log"  # combined stdout/stderr of the training subprocess
model_dir = "trained_model"  # directory the trained model is read from
zip_file = "trained_model.zip"  # downloadable archive of model_dir

# Try loading a previously trained model once at startup, so requests do not
# pay the load cost. If no usable model exists yet (e.g. before any training
# run), leave `generator` as None so callers can detect that no model is
# available. Catch Exception rather than using a bare except, so Ctrl-C and
# SystemExit still propagate, and print the reason instead of hiding it.
try:
    generator = pipeline("text-generation", model=model_dir, tokenizer="distilgpt2")
except Exception as exc:
    print(f"Model not loaded from {model_dir!r}: {exc}")
    generator = None
def upload_file(file, dest=None):
    """Copy an uploaded dataset to the path that training reads from.

    Args:
        file: Value of the ``gr.File`` component. It is ``None`` when nothing
            was uploaded. Otherwise it is either a path (str / PathLike, as
            newer Gradio versions pass) or a tempfile-like object exposing
            ``.name`` (as older Gradio versions pass).
        dest: Destination path. Defaults to ``uploaded_file_path``.

    Returns:
        A status message for the UI.
    """
    if file is None:
        return "No file uploaded."
    if dest is None:
        dest = uploaded_file_path
    # Gradio's File component hands back either a plain path or a file wrapper
    # depending on version; calling `.name` on a str would raise AttributeError.
    source = file if isinstance(file, (str, os.PathLike)) else file.name
    shutil.copy(source, dest)
    return "β File uploaded successfully."
def start_training():
    """Run ``train.py`` on the uploaded dataset and package the result.

    The training subprocess's stdout and stderr are written to ``log_path``.
    On success, ``model_dir`` is zipped into ``zip_file``. The module-level
    ``generator`` is then reloaded so later generation uses the new model.

    Returns:
        A ``(status_message, archive_path_or_None)`` tuple for the status box
        and the download widget.
    """
    global generator

    with open(log_path, "w") as log_file:
        # sys.executable runs train.py under the same interpreter (and
        # virtualenv) as this app, not whatever "python" happens to be on PATH.
        completed = subprocess.run(
            [sys.executable, "train.py", "--dataset", uploaded_file_path],
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )

    # Check the exit status: an older model_dir left over from a previous
    # successful run must not make a failed run look like a success.
    if completed.returncode != 0 or not os.path.exists(model_dir):
        return "β Training failed.", None

    # Derive the archive base name from zip_file so the two cannot drift apart.
    shutil.make_archive(os.path.splitext(zip_file)[0], "zip", model_dir)

    # Reload so generation picks up the freshly trained weights instead of
    # whatever (if anything) was loaded when the app started.
    try:
        generator = pipeline("text-generation", model=model_dir, tokenizer="distilgpt2")
    except Exception as exc:
        print(f"Trained model could not be loaded from {model_dir!r}: {exc}")

    return "β Training complete!", zip_file
def read_logs():
    """Return the full text of the training log, or a placeholder if it is absent."""
    if not os.path.exists(log_path):
        return "Waiting for logs..."
    with open(log_path, "r") as handle:
        return handle.read()
def generate_response(prompt):
    """Generate a sampled continuation of *prompt* with the trained model.

    If no model is loaded yet (``generator`` is None), first try loading it
    from ``model_dir``. This lets a model trained after the app started be
    used without a restart.

    Returns:
        The generated text, or a user-facing error string. Errors are turned
        into messages rather than raised, so the UI always gets a string.
    """
    global generator
    try:
        if generator is None:
            try:
                generator = pipeline("text-generation", model=model_dir, tokenizer="distilgpt2")
            except Exception:
                return "β Model not loaded. Please train or upload a valid model."
        outputs = generator(
            prompt,
            max_length=256,
            do_sample=True,
            temperature=0.7,
            # Truncate over-long prompts explicitly; this also silences the
            # truncation warning the original comment referred to.
            truncation=True,
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        return f"β Error: {e}"
# === UI ===
with gr.Blocks() as app:
    with gr.Tab("π§ Train AI"):
        gr.Markdown("## π₯ Upload your dataset and π― train a Godot AI")
        file_input = gr.File(label="Upload JSONL Dataset")
        upload_btn = gr.Button("Upload")
        status_box = gr.Textbox(label="Upload Status")
        start_btn = gr.Button("π Start Training")
        log_output = gr.Textbox(label="π Training Logs", lines=15)
        # Hidden until a training run actually returns an archive path.
        download_btn = gr.File(label="π₯ Download Trained Model", visible=False)

        upload_btn.click(fn=upload_file, inputs=file_input, outputs=status_box)

        # Chain the steps: train, then show the finished log, then reveal the
        # download widget only if an archive came back. Reading the log in a
        # separate, parallel click handler would capture it before training
        # had written anything. A value returned into an invisible component
        # does not make it visible on its own.
        start_btn.click(
            fn=start_training, outputs=[status_box, download_btn]
        ).then(
            fn=read_logs, outputs=log_output
        ).then(
            fn=lambda path: gr.update(visible=path is not None),
            inputs=download_btn,
            outputs=download_btn,
        )

    with gr.Tab("π Test AI"):
        gr.Markdown("## π‘ Try your trained Godot AI below")
        prompt_input = gr.Textbox(label="Enter Prompt")
        test_btn = gr.Button("π Test AI")
        response_output = gr.Textbox(label="AI Response", lines=10)
        test_btn.click(fn=generate_response, inputs=prompt_input, outputs=response_output)

if __name__ == "__main__":
    app.launch()