import os
import shutil
import subprocess
import threading
import uuid
import time
import zipfile
import gzip
import glob
import gradio as gr
from transformers import pipeline
LOG_FILE = "train.log"
GEN_LOG_FILE = "dataset_gen.log"
MODEL_DIR = "trained_model"
ZIP_FILE = "trained_model.zip"
ZIP_TEMP = ZIP_FILE + ".part"
def _human_size(nbytes: int) -> str:
units = ["B","KB","MB","GB","TB"]; i=0; x=float(nbytes)
while x>=1024 and i<len(units)-1: x/=1024.0; i+=1
return f"{x:.1f} {units[i]}"
def _read_file_safely(path: str, fallback: str):
if os.path.exists(path):
try:
with open(path,"r",encoding="utf-8",errors="ignore") as f: return f.read()
        except Exception: return fallback
return fallback
def _zip_folder_atomic(src_dir: str, zip_path: str, tmp_path: str):
if os.path.exists(tmp_path): os.remove(tmp_path)
with zipfile.ZipFile(tmp_path,"w",compression=zipfile.ZIP_DEFLATED) as zf:
for root,_,files in os.walk(src_dir):
for fn in files:
full=os.path.join(root,fn); arc=os.path.relpath(full,src_dir)
zf.write(full,arcname=arc)
    os.replace(tmp_path, zip_path)  # os.replace overwrites atomically, so no pre-delete is needed
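# Hedged note: writing to ZIP_TEMP first and swapping it in with os.replace() (an
# atomic rename on the same filesystem) means a concurrent download request never
# sees a half-written archive.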
def _download_info_text() -> str:
if not os.path.exists(ZIP_FILE): return "No trained model yet."
size=_human_size(os.path.getsize(ZIP_FILE))
mtime=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(ZIP_FILE)))
return f"*Model ready:* {ZIP_FILE} \n*Size:* {size} \n*Last modified:* {mtime}"
def ensure_clean_zip():
for p in (ZIP_FILE, ZIP_TEMP):
if os.path.exists(p):
try: os.remove(p)
            except OSError: pass
# --------- Dataset Generator ----------
def start_generation(total, shard_size, out_dir, prefix):
total=int(total or 1_000_000)
shard_size=int(shard_size or 10_000)
out_dir=(out_dir or "json_dataset_v1").strip()
prefix=(prefix or "json").strip()
with open(GEN_LOG_FILE,"w") as log:
log.write(f"🚧 Generating dataset: total={total}, shard_size={shard_size}, out_dir={out_dir}, prefix={prefix}\n")
def _worker():
with open(GEN_LOG_FILE,"a") as log:
if not os.path.exists("make_json_dataset.py"):
log.write("❌ make_json_dataset.py not found.\n"); return
try:
p = subprocess.Popen(
["python","make_json_dataset.py",
"--total",str(total),
"--shard_size",str(shard_size),
"--out_dir",out_dir,
"--prefix",prefix],
stdout=log, stderr=subprocess.STDOUT
)
p.wait()
log.write(f"\nπŸ”š Generator exited with code {p.returncode}\n")
if p.returncode==0:
files = sorted(glob.glob(os.path.join(out_dir,"*.jsonl.gz")))
log.write(f"βœ… Done. Shards: {len(files)} in {out_dir}\n")
else:
log.write("❌ Generation failed.\n")
except Exception as e:
log.write(f"\n❌ Exception: {e}\n")
threading.Thread(target=_worker, daemon=True).start()
return f"πŸš€ Dataset generation started. Output folder: {out_dir}"
def read_gen_logs():
return _read_file_safely(GEN_LOG_FILE,"Waiting for generator logs...")
def list_shards(folder):
if not folder or not os.path.isdir(folder): return "❌ Provide a valid folder path."
# ⬇ Only JSONL shards; ignore manifest files
jsonl = sorted(glob.glob(os.path.join(folder,"*.jsonl")))
gz = sorted(glob.glob(os.path.join(folder,"*.jsonl.gz")))
files = [p for p in (jsonl+gz) if "manifest" not in os.path.basename(p).lower()]
total = len(files)
if total==0: return "No shards found (*.jsonl / *.jsonl.gz)."
preview=files[:10]
lines=[f"Found {total} shard(s). Showing first {len(preview)}:"]+[f"- {os.path.basename(p)}" for p in preview]
return "\n".join(lines)
# --------- Training ----------
def upload_file(file):
if file is None: return "❌ No file uploaded.", ""
os.makedirs("uploads", exist_ok=True)
dst = os.path.join("uploads", f"dataset_{uuid.uuid4().hex}.jsonl")
shutil.copy(file.name, dst)
return f"βœ… Uploaded: {os.path.basename(file.name)} β†’ {dst}", dst
def _train_single_file(dataset_path: str, log):
p = subprocess.Popen(["python","train.py","--dataset",dataset_path,"--output",MODEL_DIR],
stdout=log, stderr=subprocess.STDOUT)
p.wait()
log.write(f"\n ↳ train.py exited {p.returncode} for {os.path.basename(dataset_path)}\n")
return p.returncode==0
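# Assumed CLI contract, inferred from the call above: train.py takes --dataset
# (a .jsonl file) and --output (checkpoint directory), streams progress to stdout,
# and exits 0 on success.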
def _train_worker(dataset_path: str, shards_folder: str):
with open(LOG_FILE,"w") as log: log.write("πŸ”₯ Starting training (JSON AI)…\n")
ok=True
with open(LOG_FILE,"a") as log:
if shards_folder:
log.write(f"πŸ“‚ Folder mode: {shards_folder}\n")
# ⬇ Only JSONL shards; ignore manifest files
paths = sorted(glob.glob(os.path.join(shards_folder,"*.jsonl"))) + \
sorted(glob.glob(os.path.join(shards_folder,"*.jsonl.gz")))
paths = [p for p in paths if "manifest" not in os.path.basename(p).lower()]
if not paths:
log.write("❌ No shards found (*.jsonl / *.jsonl.gz). Aborting.\n"); ok=False
else:
tmp="tmp_train.jsonl"
for i,pth in enumerate(paths,1):
log.write(f"\n[{i}/{len(paths)}] Training on shard: {os.path.basename(pth)}\n")
if pth.endswith(".gz"):
try:
with gzip.open(pth,"rt",encoding="utf-8") as rf, open(tmp,"w",encoding="utf-8") as wf:
for line in rf: wf.write(line)
shard=tmp
except Exception as e:
log.write(f"❌ Failed to read gz shard: {e}\n"); ok=False; break
else:
shard=pth
if not _train_single_file(shard, log):
ok=False; break
if os.path.exists(tmp):
try: os.remove(tmp)
                    except OSError: pass
else:
if not dataset_path or not os.path.exists(dataset_path):
log.write("❌ Please upload a valid dataset first.\n"); ok=False
else:
ok=_train_single_file(dataset_path, log)
if ok and os.path.isdir(MODEL_DIR):
try:
time.sleep(0.5)
_zip_folder_atomic(MODEL_DIR, ZIP_FILE, ZIP_TEMP)
sz=_human_size(os.path.getsize(ZIP_FILE))
log.write(f"\nβœ… Model zipped β†’ {ZIP_FILE} ({sz})\n")
except Exception as e:
log.write(f"\n❌ Zipping failed: {e}\n")
else:
log.write("\n❌ Training failed; no zip created.\n")
def start_training(dataset_path: str, shards_folder: str):
ensure_clean_zip()
threading.Thread(target=_train_worker, args=(dataset_path, shards_folder), daemon=True).start()
return "πŸš€ Training started in the background. Use the Refresh buttons to update."
def read_logs_once():
return _read_file_safely(LOG_FILE,"Waiting for logs...")
def check_download():
if os.path.exists(ZIP_FILE):
return gr.update(visible=True, value=ZIP_FILE), _download_info_text()
else:
return gr.update(visible=False, value=None), "No trained model yet."
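# gr.update() toggles the DownloadButton in place: hidden while no zip exists,
# visible with value=ZIP_FILE (the file Gradio serves on click) once training ends.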
# --------- Test ----------
def upload_test_model_zip(zip_file):
if zip_file is None: return "❌ No file uploaded.", ""
extract_root = os.path.join("models", f"test_{uuid.uuid4().hex}")
os.makedirs(extract_root, exist_ok=True)
try:
with zipfile.ZipFile(zip_file.name,"r") as zf: zf.extractall(extract_root)
return f"βœ… Model ZIP extracted to: {extract_root}", extract_root
except Exception as e:
return f"❌ Failed to extract: {e}", ""
def clear_uploaded_model():
return "Model cleared. Will use trained_model/ if available.", ""
def generate_response(prompt, uploaded_model_path):
if not prompt or not prompt.strip(): return "Please enter a prompt."
try:
if uploaded_model_path and os.path.isdir(uploaded_model_path):
model_path, src = uploaded_model_path, "(uploaded model)"
elif os.path.isdir(MODEL_DIR):
model_path, src = MODEL_DIR, "(trained_model/)"
else:
model_path, src = "distilgpt2", "(fallback: distilgpt2)"
gen = pipeline("text-generation", model=model_path, tokenizer="distilgpt2")
out = gen(prompt, max_length=256, do_sample=True, temperature=0.7, truncation=True)[0]["generated_text"]
return f"{out}\n\nβ€” using {src}"
except Exception as e:
return f"❌ Error: {e}"
# --------- UI ----------
with gr.Blocks(title="JSON AI Trainer (with Dataset Generator)") as app:
gr.Markdown("## 🧩 JSON AI Trainer\nGenerate a large JSON dataset, train (single file or folder of shards), download the model, and test.")
dataset_state = gr.State(value="")
shard_folder_state = gr.State(value="")
test_model_state = gr.State(value="")
with gr.Tab("πŸ§ͺ Generate Dataset"):
with gr.Row():
total_in = gr.Number(value=1_000_000, label="Total samples")
shard_in = gr.Number(value=10_000, label="Rows per shard")
with gr.Row():
out_dir_in = gr.Textbox(value="json_dataset_v1", label="Output folder")
prefix_in = gr.Textbox(value="json", label="File prefix")
with gr.Row():
gen_btn = gr.Button("πŸš€ Start Generation")
gen_refresh_btn = gr.Button("πŸ” Refresh Logs")
gen_status = gr.Textbox(label="Generator Status", interactive=False)
gen_logs = gr.Textbox(label="Generator Logs", lines=16)
with gr.Row():
list_folder = gr.Textbox(value="json_dataset_v1", label="Preview shards in folder")
list_btn = gr.Button("πŸ‘€ List Shards")
list_out = gr.Textbox(label="Shard Preview", lines=8)
gen_btn.click(fn=start_generation, inputs=[total_in, shard_in, out_dir_in, prefix_in], outputs=gen_status
).then(fn=read_gen_logs, outputs=gen_logs)
gen_refresh_btn.click(fn=read_gen_logs, outputs=gen_logs)
list_btn.click(fn=list_shards, inputs=list_folder, outputs=list_out)
with gr.Tab("🧠 Train"):
gr.Markdown("Upload a single JSON/JSONL file *or* train on a folder of shards (.jsonl, .jsonl.gz). Manifests are ignored.")
with gr.Row():
file_input = gr.File(label="Upload single dataset file", file_types=[".json",".jsonl"])
upload_btn = gr.Button("πŸ“€ Upload (single file)")
with gr.Row():
shards_folder = gr.Textbox(value="", label="Folder with shards (optional)")
use_folder_btn = gr.Button("πŸ“‚ Use Folder For Training")
status_box = gr.Textbox(label="Status", interactive=False)
with gr.Row():
start_btn = gr.Button("πŸš€ Start Training")
refresh_btn = gr.Button("πŸ” Refresh Logs")
refresh_dl_btn = gr.Button("πŸ“¦ Refresh Download Area")
log_output = gr.Textbox(label="πŸ“œ Training Logs", lines=18)
with gr.Group():
gr.Markdown("### πŸ“¦ Trained Model")
download_info = gr.Markdown(value="No trained model yet.")
download_btn = gr.DownloadButton(label="πŸ“₯ Download Trained Model (.zip)", visible=False, value=None)
upload_btn.click(fn=upload_file, inputs=file_input, outputs=[status_box, dataset_state])
        use_folder_btn.click(fn=lambda p: (("βœ… Using folder for training.", p.strip()) if os.path.isdir(p.strip())
                                           else ("❌ Provide a valid folder path.", "")),
                             inputs=shards_folder, outputs=[status_box, shard_folder_state])
start_btn.click(fn=start_training, inputs=[dataset_state, shard_folder_state], outputs=status_box
).then(fn=read_logs_once, outputs=log_output
).then(fn=check_download, outputs=[download_btn, download_info])
refresh_btn.click(fn=read_logs_once, outputs=log_output)
refresh_dl_btn.click(fn=check_download, outputs=[download_btn, download_info])
with gr.Tab("πŸš€ Test"):
gr.Markdown("Upload a model ZIP or use the just-trained model.")
with gr.Row():
test_zip = gr.File(label="Upload Model ZIP", file_types=[".zip"])
load_test_btn = gr.Button("πŸ“¦ Load Uploaded Model ZIP")
clear_test_btn = gr.Button("🧹 Clear Uploaded Model")
test_status = gr.Textbox(label="Test Model Status", interactive=False)
prompt_input = gr.Textbox(label="Prompt", placeholder='e.g., "Generate JSON Schema for an invoice" or "Fix this JSON: {\'a\':1,}"')
test_btn = gr.Button("πŸ” Generate")
response_output = gr.Textbox(label="AI Response", lines=12)
load_test_btn.click(fn=upload_test_model_zip, inputs=test_zip, outputs=[test_status, test_model_state])
clear_test_btn.click(fn=clear_uploaded_model, outputs=[test_status, test_model_state])
test_btn.click(fn=generate_response, inputs=[prompt_input, test_model_state], outputs=response_output)
# Optional: autostart on boot via Space variables
AUTOSTART = os.getenv("AUTOSTART_TRAIN","0") == "1"
AUTOSTART_DATASET = os.getenv("AUTOSTART_DATASET","").strip()
AUTOSTART_SHARDS = os.getenv("AUTOSTART_SHARDS","").strip()
if AUTOSTART and not os.path.exists(".autostart.started"):
open(".autostart.started","w").close()
try:
        _ = start_training(AUTOSTART_DATASET, AUTOSTART_SHARDS)
_ = read_logs_once()
except Exception as e:
with open(LOG_FILE,"a") as log: log.write(f"\n❌ Autostart failed: {e}\n")
app.launch()