import os
import shutil
import subprocess
import threading
import uuid
import time
import zipfile
import gzip
import glob

import gradio as gr
from transformers import pipeline

LOG_FILE = "train.log"
GEN_LOG_FILE = "dataset_gen.log"
MODEL_DIR = "trained_model"
ZIP_FILE = "trained_model.zip"
ZIP_TEMP = ZIP_FILE + ".part"
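
# Companion scripts (assumed to live next to this file; they are launched via
# subprocess and their stdout/stderr is streamed into the log files above):
#   make_json_dataset.py  -- invoked below as:
#       python make_json_dataset.py --total N --shard_size M --out_dir DIR --prefix PFX
#   train.py              -- invoked below as:
#       python train.py --dataset FILE --output trained_model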

def _human_size(nbytes: int) -> str:
    """Format a byte count as a human-readable string (e.g. '12.3 MB')."""
    units = ["B", "KB", "MB", "GB", "TB"]
    i = 0
    x = float(nbytes)
    while x >= 1024 and i < len(units) - 1:
        x /= 1024.0
        i += 1
    return f"{x:.1f} {units[i]}"

def _read_file_safely(path: str, fallback: str) -> str:
    """Return the file's contents, or `fallback` if it is missing or unreadable."""
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
        except Exception:
            return fallback
    return fallback

def _zip_folder_atomic(src_dir: str, zip_path: str, tmp_path: str):
    """Zip `src_dir` into `zip_path` by writing to `tmp_path` first.

    The finished archive is moved into place with os.replace(), so readers of
    `zip_path` never see a partially written file.
    """
    if os.path.exists(tmp_path):
        os.remove(tmp_path)
    with zipfile.ZipFile(tmp_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(src_dir):
            for fn in files:
                full = os.path.join(root, fn)
                arc = os.path.relpath(full, src_dir)
                zf.write(full, arcname=arc)
    if os.path.exists(zip_path):
        os.remove(zip_path)
    os.replace(tmp_path, zip_path)
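
# Example call (the same constants used by the training worker below):
#   _zip_folder_atomic(MODEL_DIR, ZIP_FILE, ZIP_TEMP)
# i.e. zip trained_model/ into trained_model.zip via trained_model.zip.part.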

def _download_info_text() -> str:
    """Markdown summary of the current model zip (size and mtime), if any."""
    if not os.path.exists(ZIP_FILE):
        return "No trained model yet."
    size = _human_size(os.path.getsize(ZIP_FILE))
    mtime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(ZIP_FILE)))
    return f"*Model ready:* {ZIP_FILE}  \n*Size:* {size}  \n*Last modified:* {mtime}"

def ensure_clean_zip():
    """Remove any stale zip or partial zip so a new training run starts fresh."""
    for p in (ZIP_FILE, ZIP_TEMP):
        if os.path.exists(p):
            try:
                os.remove(p)
            except OSError:
                pass

def start_generation(total, shard_size, out_dir, prefix):
    """Kick off dataset generation in a background thread and return a status string."""
    total = int(total or 1_000_000)
    shard_size = int(shard_size or 10_000)
    out_dir = (out_dir or "json_dataset_v1").strip()
    prefix = (prefix or "json").strip()

    with open(GEN_LOG_FILE, "w") as log:
        log.write(f"🔧 Generating dataset: total={total}, shard_size={shard_size}, "
                  f"out_dir={out_dir}, prefix={prefix}\n")

    def _worker():
        with open(GEN_LOG_FILE, "a") as log:
            if not os.path.exists("make_json_dataset.py"):
                log.write("❌ make_json_dataset.py not found.\n")
                return
            try:
                p = subprocess.Popen(
                    ["python", "make_json_dataset.py",
                     "--total", str(total),
                     "--shard_size", str(shard_size),
                     "--out_dir", out_dir,
                     "--prefix", prefix],
                    stdout=log, stderr=subprocess.STDOUT
                )
                p.wait()
                log.write(f"\n🏁 Generator exited with code {p.returncode}\n")
                if p.returncode == 0:
                    files = sorted(glob.glob(os.path.join(out_dir, "*.jsonl.gz")))
                    log.write(f"✅ Done. Shards: {len(files)} in {out_dir}\n")
                else:
                    log.write("❌ Generation failed.\n")
            except Exception as e:
                log.write(f"\n❌ Exception: {e}\n")

    threading.Thread(target=_worker, daemon=True).start()
    return f"🚀 Dataset generation started. Output folder: {out_dir}"
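
# For the default UI values, the worker above is equivalent to running:
#   python make_json_dataset.py --total 1000000 --shard_size 10000 \
#       --out_dir json_dataset_v1 --prefix json
# with its output appended to dataset_gen.log.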

def read_gen_logs():
    return _read_file_safely(GEN_LOG_FILE, "Waiting for generator logs...")

def list_shards(folder):
    """Preview the shard files (*.jsonl / *.jsonl.gz) found in `folder`."""
    if not folder or not os.path.isdir(folder):
        return "❌ Provide a valid folder path."

    jsonl = sorted(glob.glob(os.path.join(folder, "*.jsonl")))
    gz = sorted(glob.glob(os.path.join(folder, "*.jsonl.gz")))
    files = [p for p in (jsonl + gz) if "manifest" not in os.path.basename(p).lower()]
    total = len(files)
    if total == 0:
        return "No shards found (*.jsonl / *.jsonl.gz)."
    preview = files[:10]
    lines = [f"Found {total} shard(s). Showing first {len(preview)}:"] + \
            [f"- {os.path.basename(p)}" for p in preview]
    return "\n".join(lines)

def upload_file(file):
    """Copy an uploaded dataset file into uploads/ and return (status, stored path)."""
    if file is None:
        return "❌ No file uploaded.", ""
    os.makedirs("uploads", exist_ok=True)
    dst = os.path.join("uploads", f"dataset_{uuid.uuid4().hex}.jsonl")
    shutil.copy(file.name, dst)
    return f"✅ Uploaded: {os.path.basename(file.name)} → {dst}", dst

def _train_single_file(dataset_path: str, log) -> bool:
    """Run train.py on one dataset file; return True if it exited cleanly."""
    p = subprocess.Popen(
        ["python", "train.py", "--dataset", dataset_path, "--output", MODEL_DIR],
        stdout=log, stderr=subprocess.STDOUT)
    p.wait()
    log.write(f"\n  ↳ train.py exited {p.returncode} for {os.path.basename(dataset_path)}\n")
    return p.returncode == 0
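
# For a single uploaded file this amounts to (paths illustrative):
#   python train.py --dataset uploads/dataset_<hex>.jsonl --output trained_model
# with train.py's own output streamed into train.log.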

def _train_worker(dataset_path: str, shards_folder: str):
    """Background training job: folder-of-shards mode or single-file mode, then zip the model."""
    with open(LOG_FILE, "w") as log:
        log.write("🔥 Starting training (JSON AI)…\n")
    ok = True

    with open(LOG_FILE, "a") as log:
        if shards_folder:
            log.write(f"📁 Folder mode: {shards_folder}\n")

            paths = sorted(glob.glob(os.path.join(shards_folder, "*.jsonl"))) + \
                    sorted(glob.glob(os.path.join(shards_folder, "*.jsonl.gz")))
            paths = [p for p in paths if "manifest" not in os.path.basename(p).lower()]
            if not paths:
                log.write("❌ No shards found (*.jsonl / *.jsonl.gz). Aborting.\n")
                ok = False
            else:
                tmp = "tmp_train.jsonl"
                for i, pth in enumerate(paths, 1):
                    log.write(f"\n[{i}/{len(paths)}] Training on shard: {os.path.basename(pth)}\n")
                    if pth.endswith(".gz"):
                        # Decompress gzipped shards to a temporary plain-text file for train.py.
                        try:
                            with gzip.open(pth, "rt", encoding="utf-8") as rf, \
                                 open(tmp, "w", encoding="utf-8") as wf:
                                for line in rf:
                                    wf.write(line)
                            shard = tmp
                        except Exception as e:
                            log.write(f"❌ Failed to read gz shard: {e}\n")
                            ok = False
                            break
                    else:
                        shard = pth
                    if not _train_single_file(shard, log):
                        ok = False
                        break
                if os.path.exists(tmp):
                    try:
                        os.remove(tmp)
                    except OSError:
                        pass
        else:
            if not dataset_path or not os.path.exists(dataset_path):
                log.write("❌ Please upload a valid dataset first.\n")
                ok = False
            else:
                ok = _train_single_file(dataset_path, log)

        if ok and os.path.isdir(MODEL_DIR):
            try:
                time.sleep(0.5)
                _zip_folder_atomic(MODEL_DIR, ZIP_FILE, ZIP_TEMP)
                sz = _human_size(os.path.getsize(ZIP_FILE))
                log.write(f"\n✅ Model zipped → {ZIP_FILE} ({sz})\n")
            except Exception as e:
                log.write(f"\n❌ Zipping failed: {e}\n")
        else:
            log.write("\n❌ Training failed; no zip created.\n")

def start_training(dataset_path: str, shards_folder: str):
    """Clear any previous zip, then launch _train_worker in a background thread."""
    ensure_clean_zip()
    threading.Thread(target=_train_worker, args=(dataset_path, shards_folder), daemon=True).start()
    return "🚀 Training started in the background. Use the Refresh buttons to update."

def read_logs_once():
    return _read_file_safely(LOG_FILE, "Waiting for logs...")

def check_download():
    """Show or hide the download button depending on whether the model zip exists."""
    if os.path.exists(ZIP_FILE):
        return gr.update(visible=True, value=ZIP_FILE), _download_info_text()
    else:
        return gr.update(visible=False, value=None), "No trained model yet."

def upload_test_model_zip(zip_file):
    """Extract an uploaded model ZIP into models/test_<uuid>/ and return (status, path)."""
    if zip_file is None:
        return "❌ No file uploaded.", ""
    extract_root = os.path.join("models", f"test_{uuid.uuid4().hex}")
    os.makedirs(extract_root, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_file.name, "r") as zf:
            zf.extractall(extract_root)
        return f"✅ Model ZIP extracted to: {extract_root}", extract_root
    except Exception as e:
        return f"❌ Failed to extract: {e}", ""

def clear_uploaded_model():
    return "Model cleared. Will use trained_model/ if available.", ""

def generate_response(prompt, uploaded_model_path):
    """Generate text with the uploaded model, the freshly trained model, or distilgpt2 as fallback."""
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        if uploaded_model_path and os.path.isdir(uploaded_model_path):
            model_path, src = uploaded_model_path, "(uploaded model)"
        elif os.path.isdir(MODEL_DIR):
            model_path, src = MODEL_DIR, "(trained_model/)"
        else:
            model_path, src = "distilgpt2", "(fallback: distilgpt2)"
        # Note: the tokenizer is always distilgpt2, so custom models are expected
        # to be distilgpt2 fine-tunes.
        gen = pipeline("text-generation", model=model_path, tokenizer="distilgpt2")
        out = gen(prompt, max_length=256, do_sample=True, temperature=0.7, truncation=True)[0]["generated_text"]
        return f"{out}\n\n→ using {src}"
    except Exception as e:
        return f"❌ Error: {e}"
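
# Possible optimisation (not wired in): the pipeline above is rebuilt on every
# request. A small module-level cache keyed by model_path would avoid reloading
# the weights each time, e.g.:
#   _PIPELINES = {}
#   def _get_pipeline(model_path):
#       if model_path not in _PIPELINES:
#           _PIPELINES[model_path] = pipeline("text-generation", model=model_path,
#                                             tokenizer="distilgpt2")
#       return _PIPELINES[model_path]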

with gr.Blocks(title="JSON AI Trainer (with Dataset Generator)") as app:
    gr.Markdown("## 🧩 JSON AI Trainer\nGenerate a large JSON dataset, train (single file or folder of shards), download the model, and test.")

    dataset_state = gr.State(value="")
    shard_folder_state = gr.State(value="")
    test_model_state = gr.State(value="")

    with gr.Tab("🧪 Generate Dataset"):
        with gr.Row():
            total_in = gr.Number(value=1_000_000, label="Total samples")
            shard_in = gr.Number(value=10_000, label="Rows per shard")
        with gr.Row():
            out_dir_in = gr.Textbox(value="json_dataset_v1", label="Output folder")
            prefix_in = gr.Textbox(value="json", label="File prefix")
        with gr.Row():
            gen_btn = gr.Button("🚀 Start Generation")
            gen_refresh_btn = gr.Button("🔄 Refresh Logs")
        gen_status = gr.Textbox(label="Generator Status", interactive=False)
        gen_logs = gr.Textbox(label="Generator Logs", lines=16)
        with gr.Row():
            list_folder = gr.Textbox(value="json_dataset_v1", label="Preview shards in folder")
            list_btn = gr.Button("📂 List Shards")
        list_out = gr.Textbox(label="Shard Preview", lines=8)

        gen_btn.click(fn=start_generation, inputs=[total_in, shard_in, out_dir_in, prefix_in], outputs=gen_status
                      ).then(fn=read_gen_logs, outputs=gen_logs)
        gen_refresh_btn.click(fn=read_gen_logs, outputs=gen_logs)
        list_btn.click(fn=list_shards, inputs=list_folder, outputs=list_out)

    with gr.Tab("🧠 Train"):
        gr.Markdown("Upload a single JSON/JSONL file *or* train on a folder of shards (.jsonl, .jsonl.gz). Manifests are ignored.")
        with gr.Row():
            file_input = gr.File(label="Upload single dataset file", file_types=[".json", ".jsonl"])
            upload_btn = gr.Button("📤 Upload (single file)")
        with gr.Row():
            shards_folder = gr.Textbox(value="", label="Folder with shards (optional)")
            use_folder_btn = gr.Button("📁 Use Folder For Training")
        status_box = gr.Textbox(label="Status", interactive=False)

        with gr.Row():
            start_btn = gr.Button("🚀 Start Training")
            refresh_btn = gr.Button("🔄 Refresh Logs")
            refresh_dl_btn = gr.Button("📦 Refresh Download Area")
        log_output = gr.Textbox(label="📜 Training Logs", lines=18)

        with gr.Group():
            gr.Markdown("### 📦 Trained Model")
            download_info = gr.Markdown(value="No trained model yet.")
            download_btn = gr.DownloadButton(label="📥 Download Trained Model (.zip)", visible=False, value=None)

        upload_btn.click(fn=upload_file, inputs=file_input, outputs=[status_box, dataset_state])
        use_folder_btn.click(fn=lambda p: ("✅ Using folder for training." if p.strip() else "❌ Provide a valid folder path.", p.strip()),
                             inputs=shards_folder, outputs=[status_box, shard_folder_state])
        start_btn.click(fn=start_training, inputs=[dataset_state, shard_folder_state], outputs=status_box
                        ).then(fn=read_logs_once, outputs=log_output
                        ).then(fn=check_download, outputs=[download_btn, download_info])

        refresh_btn.click(fn=read_logs_once, outputs=log_output)
        refresh_dl_btn.click(fn=check_download, outputs=[download_btn, download_info])

    with gr.Tab("🔍 Test"):
        gr.Markdown("Upload a model ZIP or use the just-trained model.")
        with gr.Row():
            test_zip = gr.File(label="Upload Model ZIP", file_types=[".zip"])
            load_test_btn = gr.Button("📦 Load Uploaded Model ZIP")
            clear_test_btn = gr.Button("🧹 Clear Uploaded Model")
        test_status = gr.Textbox(label="Test Model Status", interactive=False)
        prompt_input = gr.Textbox(label="Prompt", placeholder='e.g., "Generate JSON Schema for an invoice" or "Fix this JSON: {\'a\':1,}"')
        test_btn = gr.Button("🚀 Generate")
        response_output = gr.Textbox(label="AI Response", lines=12)

        load_test_btn.click(fn=upload_test_model_zip, inputs=test_zip, outputs=[test_status, test_model_state])
        clear_test_btn.click(fn=clear_uploaded_model, outputs=[test_status, test_model_state])
        test_btn.click(fn=generate_response, inputs=[prompt_input, test_model_state], outputs=response_output)

# Optional autostart: if AUTOSTART_TRAIN=1, kick off training once at launch,
# using AUTOSTART_DATASET (single file) and/or AUTOSTART_SHARDS (folder).
AUTOSTART = os.getenv("AUTOSTART_TRAIN", "0") == "1"
AUTOSTART_DATASET = os.getenv("AUTOSTART_DATASET", "").strip()
AUTOSTART_SHARDS = os.getenv("AUTOSTART_SHARDS", "").strip()
if AUTOSTART and not os.path.exists(".autostart.started"):
    # Marker file prevents re-triggering training on every restart.
    open(".autostart.started", "w").close()
    try:
        _ = start_training(AUTOSTART_DATASET if AUTOSTART_DATASET else "", AUTOSTART_SHARDS if AUTOSTART_SHARDS else "")
        _ = read_logs_once()
    except Exception as e:
        with open(LOG_FILE, "a") as log:
            log.write(f"\n❌ Autostart failed: {e}\n")

app.launch()