Spaces:
Sleeping
Sleeping
import glob
import gzip
import os
import shutil
import subprocess
import sys
import threading
import time
import uuid
import zipfile

import gradio as gr
from transformers import pipeline
# ---- Paths / constants ----
LOG_FILE = "train.log"            # training subprocess output (truncated per run)
GEN_LOG_FILE = "dataset_gen.log"  # dataset-generator subprocess output
MODEL_DIR = "trained_model"       # directory train.py writes the model into
ZIP_FILE = "trained_model.zip"    # downloadable archive of MODEL_DIR
ZIP_TEMP = ZIP_FILE + ".part"  # atomic write to avoid corrupt downloads
| # ---- Helpers ---- | |
| def _human_size(nbytes: int) -> str: | |
| units = ["B", "KB", "MB", "GB", "TB"] | |
| i, x = 0, float(nbytes) | |
| while x >= 1024 and i < len(units) - 1: | |
| x /= 1024.0 | |
| i += 1 | |
| return f"{x:.1f} {units[i]}" | |
def _download_info_text() -> str:
    """Markdown summary (name/size/mtime) of the model zip, or a placeholder."""
    if not os.path.exists(ZIP_FILE):
        return "No trained model yet."
    nbytes = os.path.getsize(ZIP_FILE)
    stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(ZIP_FILE)))
    return f"*Model ready:* {ZIP_FILE} \n*Size:* {_human_size(nbytes)} \n*Last modified:* {stamp}"
| def _read_file_safely(path: str, fallback: str): | |
| if os.path.exists(path): | |
| try: | |
| with open(path, "r", encoding="utf-8", errors="ignore") as f: | |
| return f.read() | |
| except: | |
| return fallback | |
| return fallback | |
def ensure_clean():
    """Best-effort removal of the previous model zip and its temp file.

    Called before a new training run so a stale download cannot be served.
    EAFP: attempt the remove directly (the old exists()-then-remove had a
    race window) and ignore only OSError -- the original bare ``except``
    would have hidden any bug.
    """
    for path in (ZIP_FILE, ZIP_TEMP):
        try:
            os.remove(path)
        except OSError:
            # File absent or not removable; either way training proceeds.
            pass
| def _zip_folder_atomic(src_dir: str, zip_path: str, tmp_path: str): | |
| """Write to .part then rename β avoids corrupt/half-written zips.""" | |
| if os.path.exists(tmp_path): | |
| os.remove(tmp_path) | |
| with zipfile.ZipFile(tmp_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: | |
| for root, _, files in os.walk(src_dir): | |
| for fn in files: | |
| full = os.path.join(root, fn) | |
| arc = os.path.relpath(full, src_dir) | |
| zf.write(full, arcname=arc) | |
| if os.path.exists(zip_path): | |
| os.remove(zip_path) | |
| os.replace(tmp_path, zip_path) | |
| # ============================================================ | |
| # DATASET GENERATOR (PYTHON) | |
| # ============================================================ | |
def start_generation(total, shard_size, out_dir, prefix):
    """Kick off Python dataset generation in a background thread.

    Args mirror make_python_dataset.py's CLI; falsy/blank inputs fall back
    to defaults. Returns a status string immediately; progress is appended
    to GEN_LOG_FILE by the daemon worker.
    """
    total = int(total or 1_000_000)
    shard_size = int(shard_size or 10_000)
    out_dir = (out_dir or "python_dataset_v1").strip()
    prefix = (prefix or "python").strip()
    # Truncate the log so the UI shows only this run.
    with open(GEN_LOG_FILE, "w") as log:
        log.write(f"π§ Generating dataset: total={total}, shard_size={shard_size}, out_dir={out_dir}, prefix={prefix}\n")

    def _worker():
        with open(GEN_LOG_FILE, "a") as log:
            if not os.path.exists("make_python_dataset.py"):
                log.write("β make_python_dataset.py not found in repo root.\n")
                return
            try:
                # sys.executable instead of bare "python": survives venvs and
                # hosts where "python" is missing or points elsewhere.
                proc = subprocess.Popen(
                    [
                        sys.executable,
                        "make_python_dataset.py",
                        "--total", str(total),
                        "--shard_size", str(shard_size),
                        "--out_dir", out_dir,
                        "--prefix", prefix,
                    ],
                    stdout=log,
                    stderr=subprocess.STDOUT,
                )
                proc.wait()
                log.write(f"\nπ Generator exited with code {proc.returncode}\n")
                if proc.returncode == 0:
                    files = sorted(glob.glob(os.path.join(out_dir, "*.jsonl.gz")))
                    log.write(f"β Done. Shards: {len(files)} in {out_dir}\n")
                else:
                    log.write("β Generation failed.\n")
            except Exception as e:
                # Boundary catch: surface any failure in the log rather than
                # killing the daemon thread silently.
                log.write(f"\nβ Exception: {e}\n")

    threading.Thread(target=_worker, daemon=True).start()
    return f"π Dataset generation started. Output folder: {out_dir}"
def read_gen_logs():
    """Snapshot the dataset-generator log for the UI refresh button."""
    placeholder = "Waiting for generator logs..."
    return _read_file_safely(GEN_LOG_FILE, placeholder)
def list_shards(folder):
    """Return a short text preview of the shard files in *folder* (sanity check)."""
    if not folder or not os.path.isdir(folder):
        return "β Provide a valid folder path that contains .jsonl or .jsonl.gz shards."
    # Plain shards first, then gzipped -- same ordering the trainer uses.
    found = sorted(glob.glob(os.path.join(folder, "*.jsonl")))
    found += sorted(glob.glob(os.path.join(folder, "*.jsonl.gz")))
    if not found:
        return "No shards found (*.jsonl or *.jsonl.gz)."
    head = found[:10]
    report = [f"Found {len(found)} shard(s). Showing first {len(head)}:"]
    report.extend(f"- {os.path.basename(p)}" for p in head)
    return "\n".join(report)
| # ============================================================ | |
| # TRAINING | |
| # ============================================================ | |
def upload_file(file):
    """Copy an uploaded dataset into uploads/ under a unique name.

    Returns (status message, saved path); the path is "" when nothing
    was uploaded.
    """
    if file is None:
        return "β No file uploaded.", ""
    os.makedirs("uploads", exist_ok=True)
    # Unique name so repeated uploads never clobber each other.
    target = os.path.join("uploads", f"dataset_{uuid.uuid4().hex}.jsonl")
    shutil.copy(file.name, target)
    status = f"β Uploaded: {os.path.basename(file.name)} β {target}"
    return status, target
def _train_single_file(dataset_path: str, log):
    """Run train.py once on a single JSON/JSONL file, streaming output to *log*.

    Returns True when the subprocess exits with code 0.
    """
    # sys.executable keeps the child on the same interpreter/venv as this
    # app (a bare "python" may be missing or a different install).
    proc = subprocess.Popen(
        [sys.executable, "train.py", "--dataset", dataset_path, "--output", MODEL_DIR],
        stdout=log,
        stderr=subprocess.STDOUT,
    )
    proc.wait()
    log.write(f"\n β³ train.py exited {proc.returncode} for {os.path.basename(dataset_path)}\n")
    return proc.returncode == 0
def _train_worker(dataset_path: str, shards_folder: str) -> bool:
    """Background training job: single-file mode or folder-of-shards mode.

    Folder mode (shards_folder non-empty) trains sequentially on every
    *.jsonl / *.jsonl.gz shard, decompressing gz shards to a temp file
    first. Single-file mode trains once on dataset_path. On success the
    model directory is zipped for download. Returns True on success.
    """
    # Truncate the log so the UI shows only this run.
    with open(LOG_FILE, "w") as log:
        log.write("π₯ Starting training...\n")
    ok = True
    with open(LOG_FILE, "a") as log:
        if shards_folder:
            log.write(f"π Folder mode: {shards_folder}\n")
            paths = sorted(glob.glob(os.path.join(shards_folder, "*.jsonl"))) + \
                    sorted(glob.glob(os.path.join(shards_folder, "*.jsonl.gz")))
            if not paths:
                log.write("β No shards found. Aborting.\n")
                ok = False
            else:
                # Reused scratch file for decompressed gz shards.
                tmp = "tmp_train.jsonl"
                for i, p in enumerate(paths, 1):
                    log.write(f"\n[{i}/{len(paths)}] Training on shard: {os.path.basename(p)}\n")
                    # if gz, stream to tmp jsonl
                    if p.endswith(".gz"):
                        try:
                            with gzip.open(p, "rt", encoding="utf-8") as rf, open(tmp, "w", encoding="utf-8") as wf:
                                for line in rf:
                                    wf.write(line)
                            shard_path = tmp
                        except Exception as e:
                            log.write(f"β Failed to read gz shard: {e}\n")
                            ok = False
                            break
                    else:
                        shard_path = p
                    # Stop at the first failing shard.
                    if not _train_single_file(shard_path, log):
                        ok = False
                        break
                # Best-effort scratch cleanup (bare except kept as-is).
                if os.path.exists(tmp):
                    try: os.remove(tmp)
                    except: pass
        else:
            if not dataset_path or not os.path.exists(dataset_path):
                log.write("β Please upload a valid dataset first.\n")
                ok = False
            else:
                ok = _train_single_file(dataset_path, log)
        if ok and os.path.isdir(MODEL_DIR):
            try:
                time.sleep(0.5)  # settle delay
                _zip_folder_atomic(MODEL_DIR, ZIP_FILE, ZIP_TEMP)
                sz = _human_size(os.path.getsize(ZIP_FILE))
                log.write(f"\nβ Model zipped β {ZIP_FILE} ({sz})\n")
            except Exception as e:
                log.write(f"\nβ Zipping failed: {e}\n")
        else:
            # Also reached when training "succeeded" but MODEL_DIR is absent.
            log.write("\nβ Training failed; no zip created.\n")
    return ok
def start_training(dataset_path: str, shards_folder: str):
    """Clean stale artifacts, then launch training on a daemon thread."""
    ensure_clean()
    worker = threading.Thread(
        target=_train_worker,
        args=(dataset_path, shards_folder),
        daemon=True,
    )
    worker.start()
    return "π Training started in the background. Use the Refresh buttons to update."
def read_logs_once():
    """Snapshot the training log for the UI refresh button."""
    placeholder = "Waiting for logs..."
    return _read_file_safely(LOG_FILE, placeholder)
def check_download():
    """Return (download-button update, info markdown) for the current zip state."""
    # Guard clause: nothing to offer until a zip exists.
    if not os.path.exists(ZIP_FILE):
        return gr.update(visible=False, value=None), "No trained model yet."
    return gr.update(visible=True, value=ZIP_FILE), _download_info_text()
| # ============================================================ | |
| # TEST | |
| # ============================================================ | |
def upload_test_model_zip(zip_file):
    """
    Accept a model ZIP, extract to models/test_<uuid>/, return status + extracted path.
    ZIP should contain a HF model folder (config.json + tokenizer + weights).
    """
    if zip_file is None:
        return "β No file uploaded.", ""
    dest = os.path.join("models", f"test_{uuid.uuid4().hex}")
    os.makedirs(dest, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_file.name, "r") as archive:
            archive.extractall(dest)
    except Exception as e:
        return f"β Failed to extract: {e}", ""
    return f"β Model ZIP extracted to: {dest}", dest
def clear_uploaded_model():
    """Reset the uploaded-model state so testing falls back to trained_model/."""
    status = "Model cleared. Will use trained_model/ if available."
    return status, ""
def generate_response(prompt, uploaded_model_path):
    """Generate text for *prompt* with the uploaded model, the freshly trained
    model, or a distilgpt2 fallback (in that priority order).

    Fix: the original always forced tokenizer="distilgpt2", even for custom
    model directories that ship their own tokenizer (the upload flow requires
    one) -- a mismatched tokenizer silently mis-encodes the model's vocab.
    Now the model's own tokenizer is used when the directory provides one.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        if uploaded_model_path and os.path.isdir(uploaded_model_path):
            model_path = uploaded_model_path
            src = "(uploaded model)"
        elif os.path.isdir(MODEL_DIR):
            model_path = MODEL_DIR
            src = "(trained_model/)"
        else:
            model_path = "distilgpt2"
            src = "(fallback: distilgpt2)"
        # Prefer the tokenizer bundled with the selected model directory;
        # fall back to distilgpt2's tokenizer only when none is present.
        if os.path.isdir(model_path) and os.path.exists(os.path.join(model_path, "tokenizer_config.json")):
            tokenizer_path = model_path
        else:
            tokenizer_path = "distilgpt2"
        gen = pipeline("text-generation", model=model_path, tokenizer=tokenizer_path)
        out = gen(prompt, max_length=256, do_sample=True, temperature=0.7, truncation=True)[0]["generated_text"]
        return f"{out}\n\nβ using {src}"
    except Exception as e:
        # UI boundary: show the error instead of crashing the event handler.
        return f"β Error: {e}"
# ------------- UI -------------
# Three-tab Gradio app; all handlers above are non-streaming, so each tab
# pairs its action button with a manual "Refresh" button.
with gr.Blocks(title="Python AI Trainer (with Dataset Generator)") as app:
    gr.Markdown("## π Python AI Trainer\nGenerate a large Python dataset, train (single file or folder of shards), download the model, and test any model (uploaded or trained).")
    # Cross-tab state: file paths are threaded between tabs via gr.State.
    dataset_state = gr.State(value="")  # path to single dataset file
    shard_folder_state = gr.State(value="")  # folder containing shards
    test_model_state = gr.State(value="")  # extracted uploaded-model dir (Test tab)
    # =============== Generate Dataset ===============
    with gr.Tab("π§ͺ Generate Dataset"):
        gr.Markdown("Generate a large Python dataset in shards (no streaming; use Refresh to see logs).")
        with gr.Row():
            total_in = gr.Number(value=1_000_000, label="Total samples")
            shard_in = gr.Number(value=10_000, label="Rows per shard")
        with gr.Row():
            out_dir_in = gr.Textbox(value="python_dataset_v1", label="Output folder")
            prefix_in = gr.Textbox(value="python", label="File prefix")
        with gr.Row():
            gen_btn = gr.Button("π Start Generation")
            gen_refresh_btn = gr.Button("π Refresh Logs")
        gen_status = gr.Textbox(label="Generator Status", interactive=False)
        gen_logs = gr.Textbox(label="Generator Logs", lines=16)
        with gr.Row():
            list_folder = gr.Textbox(value="python_dataset_v1", label="Preview shards in folder")
            list_btn = gr.Button("π List Shards")
        list_out = gr.Textbox(label="Shard Preview", lines=8)
        # Start generation, then immediately show the first log snapshot.
        gen_btn.click(
            fn=start_generation,
            inputs=[total_in, shard_in, out_dir_in, prefix_in],
            outputs=gen_status
        ).then(fn=read_gen_logs, outputs=gen_logs)
        gen_refresh_btn.click(fn=read_gen_logs, outputs=gen_logs)
        list_btn.click(fn=list_shards, inputs=list_folder, outputs=list_out)
    # ==================== Train ====================
    with gr.Tab("π§ Train"):
        gr.Markdown("Upload a single JSONL *or* provide a folder with shards (.jsonl / .jsonl.gz).")
        with gr.Row():
            file_input = gr.File(label="Upload single JSONL dataset", file_types=[".jsonl"])
            upload_btn = gr.Button("π€ Upload (single file)")
        with gr.Row():
            shards_folder = gr.Textbox(value="", label="Folder with shards (optional)")
            use_folder_btn = gr.Button("π Use Folder For Training")
        status_box = gr.Textbox(label="Status", interactive=False)
        with gr.Row():
            start_btn = gr.Button("π Start Training")
            refresh_btn = gr.Button("π Refresh Logs")
            refresh_dl_btn = gr.Button("π¦ Refresh Download Area")
        log_output = gr.Textbox(label="π Training Logs", lines=18)
        with gr.Group():
            gr.Markdown("### π¦ Trained Model")
            download_info = gr.Markdown(value="No trained model yet.")
            download_btn = gr.DownloadButton(label="π₯ Download Trained Model (.zip)", visible=False, value=None)
        upload_btn.click(fn=upload_file, inputs=file_input, outputs=[status_box, dataset_state])
        # Inline validator: a non-blank folder path is accepted as-is
        # (existence is checked later by the training worker).
        use_folder_btn.click(
            fn=lambda p: ("β Using folder for training." if p.strip() else "β Provide a valid folder path.", p.strip()),
            inputs=shards_folder,
            outputs=[status_box, shard_folder_state]
        )
        # Kick off training, then refresh logs and the download area once.
        start_btn.click(
            fn=start_training,
            inputs=[dataset_state, shard_folder_state],
            outputs=status_box
        ).then(fn=read_logs_once, outputs=log_output
        ).then(fn=check_download, outputs=[download_btn, download_info])
        refresh_btn.click(fn=read_logs_once, outputs=log_output)
        refresh_dl_btn.click(fn=check_download, outputs=[download_btn, download_info])
    # ===================== Test =====================
    with gr.Tab("π Test"):
        gr.Markdown("Use an uploaded model ZIP or the just-trained model.")
        with gr.Row():
            test_zip = gr.File(label="Upload Model ZIP", file_types=[".zip"])
            load_test_btn = gr.Button("π¦ Load Uploaded Model ZIP")
            clear_test_btn = gr.Button("π§Ή Clear Uploaded Model")
        test_status = gr.Textbox(label="Test Model Status", interactive=False)
        prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., Write a Python function that parses CSV and computes average")
        test_btn = gr.Button("π Generate")
        response_output = gr.Textbox(label="AI Response", lines=12)
        load_test_btn.click(fn=upload_test_model_zip, inputs=test_zip, outputs=[test_status, test_model_state])
        clear_test_btn.click(fn=clear_uploaded_model, outputs=[test_status, test_model_state])
        test_btn.click(fn=generate_response, inputs=[prompt_input, test_model_state], outputs=response_output)
# ---- Optional: auto-start on boot via env vars ----
# AUTOSTART_TRAIN=1 triggers a training run at import time, using either
# AUTOSTART_DATASET (single file) or AUTOSTART_SHARDS (shard folder).
AUTOSTART = os.getenv("AUTOSTART_TRAIN", "0") == "1"
AUTOSTART_SINGLE_DATASET = os.getenv("AUTOSTART_DATASET", "").strip()
AUTOSTART_SHARDS_FOLDER = os.getenv("AUTOSTART_SHARDS", "").strip()
# Marker file prevents re-triggering training on every process restart.
if AUTOSTART and not os.path.exists(".autostart.started"):
    open(".autostart.started", "w").close()
    try:
        _ = start_training(AUTOSTART_SINGLE_DATASET if AUTOSTART_SINGLE_DATASET else "",
                           AUTOSTART_SHARDS_FOLDER if AUTOSTART_SHARDS_FOLDER else "")
        _ = read_logs_once()
    except Exception as e:
        with open(LOG_FILE, "a") as log:
            log.write(f"\nβ Autostart failed: {e}\n")

app.launch()