Percy3822 committed on
Commit
7850368
·
verified ·
1 Parent(s): 1f320b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -119
app.py CHANGED
@@ -10,128 +10,148 @@ import glob
10
  import gradio as gr
11
  from transformers import pipeline
12
 
13
- # ---- constants / paths ----
14
  LOG_FILE = "train.log"
 
15
  MODEL_DIR = "trained_model"
16
  ZIP_FILE = "trained_model.zip"
17
- ZIP_TEMP = ZIP_FILE + ".part" # atomic write to avoid half-written zips
18
 
19
- # ---- utils ----
20
  def _human_size(nbytes: int) -> str:
21
- units = ["B", "KB", "MB", "GB", "TB"]
22
- i, x = 0, float(nbytes)
23
- while x >= 1024 and i < len(units) - 1:
24
- x /= 1024.0
25
- i += 1
26
  return f"{x:.1f} {units[i]}"
27
 
28
  def _read_file_safely(path: str, fallback: str):
29
  if os.path.exists(path):
30
  try:
31
- with open(path, "r", encoding="utf-8", errors="ignore") as f:
32
- return f.read()
33
- except:
34
- return fallback
35
  return fallback
36
 
37
  def _zip_folder_atomic(src_dir: str, zip_path: str, tmp_path: str):
38
- if os.path.exists(tmp_path):
39
- os.remove(tmp_path)
40
- with zipfile.ZipFile(tmp_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
41
- for root, _, files in os.walk(src_dir):
42
  for fn in files:
43
- full = os.path.join(root, fn)
44
- arc = os.path.relpath(full, src_dir)
45
- zf.write(full, arcname=arc)
46
- if os.path.exists(zip_path):
47
- os.remove(zip_path)
48
- os.replace(tmp_path, zip_path)
49
 
50
  def _download_info_text() -> str:
51
- if not os.path.exists(ZIP_FILE):
52
- return "No trained model yet."
53
- size = _human_size(os.path.getsize(ZIP_FILE))
54
- mtime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(ZIP_FILE)))
55
  return f"*Model ready:* {ZIP_FILE} \n*Size:* {size} \n*Last modified:* {mtime}"
56
 
57
  def ensure_clean_zip():
58
  for p in (ZIP_FILE, ZIP_TEMP):
59
  if os.path.exists(p):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  try:
61
- os.remove(p)
62
- except:
63
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # ---- training helpers ----
 
 
 
 
 
 
 
 
 
 
66
  def upload_file(file):
67
- """Save uploaded dataset to uploads/ and return stable path."""
68
- if file is None:
69
- return "❌ No file uploaded.", ""
70
  os.makedirs("uploads", exist_ok=True)
71
  dst = os.path.join("uploads", f"dataset_{uuid.uuid4().hex}.jsonl")
72
  shutil.copy(file.name, dst)
73
  return f"βœ… Uploaded: {os.path.basename(file.name)} β†’ {dst}", dst
74
 
75
  def _train_single_file(dataset_path: str, log):
76
- """Call train.py once for a single json/jsonl (or .gz expanded) file."""
77
- proc = subprocess.Popen(
78
- ["python", "train.py", "--dataset", dataset_path, "--output", MODEL_DIR],
79
- stdout=log,
80
- stderr=subprocess.STDOUT,
81
- )
82
- proc.wait()
83
- log.write(f"\n ↳ train.py exited {proc.returncode} for {os.path.basename(dataset_path)}\n")
84
- return proc.returncode == 0
85
 
86
  def _train_worker(dataset_path: str, shards_folder: str):
87
- with open(LOG_FILE, "w") as log:
88
- log.write("πŸ”₯ Starting training (JSON AI)…\n")
89
-
90
- ok = True
91
- with open(LOG_FILE, "a") as log:
92
  if shards_folder:
93
  log.write(f"πŸ“‚ Folder mode: {shards_folder}\n")
94
- paths = sorted(glob.glob(os.path.join(shards_folder, "*.jsonl"))) + \
95
- sorted(glob.glob(os.path.join(shards_folder, "*.json"))) + \
96
- sorted(glob.glob(os.path.join(shards_folder, "*.jsonl.gz"))) + \
97
- sorted(glob.glob(os.path.join(shards_folder, "*.json.gz")))
98
  if not paths:
99
- log.write("❌ No shards found (*.jsonl / *.json / *.jsonl.gz / *.json.gz). Aborting.\n")
100
- ok = False
101
  else:
102
- tmp = "tmp_train.jsonl"
103
- for i, p in enumerate(paths, 1):
104
- log.write(f"\n[{i}/{len(paths)}] Training on shard: {os.path.basename(p)}\n")
105
- if p.endswith(".gz"):
106
  try:
107
- with gzip.open(p, "rt", encoding="utf-8") as rf, open(tmp, "w", encoding="utf-8") as wf:
108
- for line in rf:
109
- wf.write(line)
110
- shard_path = tmp
111
  except Exception as e:
112
- log.write(f"❌ Failed to read gz shard: {e}\n")
113
- ok = False
114
- break
115
  else:
116
- shard_path = p
117
- if not _train_single_file(shard_path, log):
118
- ok = False
119
- break
120
  if os.path.exists(tmp):
121
  try: os.remove(tmp)
122
  except: pass
123
  else:
124
  if not dataset_path or not os.path.exists(dataset_path):
125
- log.write("❌ Please upload a valid dataset first.\n")
126
- ok = False
127
  else:
128
- ok = _train_single_file(dataset_path, log)
129
 
130
  if ok and os.path.isdir(MODEL_DIR):
131
  try:
132
- time.sleep(0.5) # settle delay for FS
133
  _zip_folder_atomic(MODEL_DIR, ZIP_FILE, ZIP_TEMP)
134
- sz = _human_size(os.path.getsize(ZIP_FILE))
135
  log.write(f"\nβœ… Model zipped β†’ {ZIP_FILE} ({sz})\n")
136
  except Exception as e:
137
  log.write(f"\n❌ Zipping failed: {e}\n")
@@ -140,12 +160,11 @@ def _train_worker(dataset_path: str, shards_folder: str):
140
 
141
  def start_training(dataset_path: str, shards_folder: str):
142
  ensure_clean_zip()
143
- t = threading.Thread(target=_train_worker, args=(dataset_path, shards_folder), daemon=True)
144
- t.start()
145
  return "πŸš€ Training started in the background. Use the Refresh buttons to update."
146
 
147
  def read_logs_once():
148
- return _read_file_safely(LOG_FILE, "Waiting for logs...")
149
 
150
  def check_download():
151
  if os.path.exists(ZIP_FILE):
@@ -153,16 +172,13 @@ def check_download():
153
  else:
154
  return gr.update(visible=False, value=None), "No trained model yet."
155
 
156
- # ---- test helpers ----
157
  def upload_test_model_zip(zip_file):
158
- """Upload a model ZIP and extract it to models/test_<uuid>/ for testing."""
159
- if zip_file is None:
160
- return "❌ No file uploaded.", ""
161
  extract_root = os.path.join("models", f"test_{uuid.uuid4().hex}")
162
  os.makedirs(extract_root, exist_ok=True)
163
  try:
164
- with zipfile.ZipFile(zip_file.name, "r") as zf:
165
- zf.extractall(extract_root)
166
  return f"βœ… Model ZIP extracted to: {extract_root}", extract_root
167
  except Exception as e:
168
  return f"❌ Failed to extract: {e}", ""
@@ -171,36 +187,54 @@ def clear_uploaded_model():
171
  return "Model cleared. Will use trained_model/ if available.", ""
172
 
173
  def generate_response(prompt, uploaded_model_path):
174
- if not prompt or not prompt.strip():
175
- return "Please enter a prompt."
176
  try:
177
  if uploaded_model_path and os.path.isdir(uploaded_model_path):
178
- model_path = uploaded_model_path
179
- src = "(uploaded model)"
180
  elif os.path.isdir(MODEL_DIR):
181
- model_path = MODEL_DIR
182
- src = "(trained_model/)"
183
  else:
184
- model_path = "distilgpt2"
185
- src = "(fallback: distilgpt2)"
186
  gen = pipeline("text-generation", model=model_path, tokenizer="distilgpt2")
187
  out = gen(prompt, max_length=256, do_sample=True, temperature=0.7, truncation=True)[0]["generated_text"]
188
  return f"{out}\n\nβ€” using {src}"
189
  except Exception as e:
190
  return f"❌ Error: {e}"
191
 
192
- # ---- UI ----
193
- with gr.Blocks(title="JSON AI Trainer") as app:
194
- gr.Markdown("## 🧩 JSON AI Trainer\nUpload a dataset (JSONL/JSON), train in background, download the model, and test JSON-focused prompts.")
195
 
196
  dataset_state = gr.State(value="")
197
  shard_folder_state = gr.State(value="")
198
  test_model_state = gr.State(value="")
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  with gr.Tab("🧠 Train"):
201
- gr.Markdown("Upload a single JSONL/JSON *or* provide a folder of shards (.jsonl, .json, .jsonl.gz, .json.gz).")
202
  with gr.Row():
203
- file_input = gr.File(label="Upload single dataset file", file_types=[".jsonl", ".json"])
204
  upload_btn = gr.Button("πŸ“€ Upload (single file)")
205
  with gr.Row():
206
  shards_folder = gr.Textbox(value="", label="Folder with shards (optional)")
@@ -219,15 +253,9 @@ with gr.Blocks(title="JSON AI Trainer") as app:
219
  download_btn = gr.DownloadButton(label="πŸ“₯ Download Trained Model (.zip)", visible=False, value=None)
220
 
221
  upload_btn.click(fn=upload_file, inputs=file_input, outputs=[status_box, dataset_state])
222
- use_folder_btn.click(
223
- fn=lambda p: ("βœ… Using folder for training." if p.strip() else "❌ Provide a valid folder path.", p.strip()),
224
- inputs=shards_folder,
225
- outputs=[status_box, shard_folder_state]
226
- )
227
- start_btn.click(
228
- fn=start_training,
229
- inputs=[dataset_state, shard_folder_state],
230
- outputs=status_box
231
  ).then(fn=read_logs_once, outputs=log_output
232
  ).then(fn=check_download, outputs=[download_btn, download_info])
233
 
@@ -235,17 +263,13 @@ with gr.Blocks(title="JSON AI Trainer") as app:
235
  refresh_dl_btn.click(fn=check_download, outputs=[download_btn, download_info])
236
 
237
  with gr.Tab("πŸš€ Test"):
238
- gr.Markdown("Upload a model ZIP or use the just-trained model. Try prompts like β€œFix this JSON”, β€œGenerate JSON schema for …”")
239
  with gr.Row():
240
  test_zip = gr.File(label="Upload Model ZIP", file_types=[".zip"])
241
  load_test_btn = gr.Button("πŸ“¦ Load Uploaded Model ZIP")
242
  clear_test_btn = gr.Button("🧹 Clear Uploaded Model")
243
  test_status = gr.Textbox(label="Test Model Status", interactive=False)
244
-
245
- prompt_input = gr.Textbox(
246
- label="Prompt",
247
- placeholder='e.g., "Generate valid JSON for a product with id, name, price, tags (array of strings)"'
248
- )
249
  test_btn = gr.Button("πŸ” Generate")
250
  response_output = gr.Textbox(label="AI Response", lines=12)
251
 
@@ -253,18 +277,16 @@ with gr.Blocks(title="JSON AI Trainer") as app:
253
  clear_test_btn.click(fn=clear_uploaded_model, outputs=[test_status, test_model_state])
254
  test_btn.click(fn=generate_response, inputs=[prompt_input, test_model_state], outputs=response_output)
255
 
256
- # Optional auto-start via env vars
257
- AUTOSTART = os.getenv("AUTOSTART_TRAIN", "0") == "1"
258
- AUTOSTART_DATASET = os.getenv("AUTOSTART_DATASET", "").strip()
259
- AUTOSTART_SHARDS = os.getenv("AUTOSTART_SHARDS", "").strip()
260
  if AUTOSTART and not os.path.exists(".autostart.started"):
261
- open(".autostart.started", "w").close()
262
  try:
263
- _ = start_training(AUTOSTART_DATASET if AUTOSTART_DATASET else "",
264
- AUTOSTART_SHARDS if AUTOSTART_SHARDS else "")
265
  _ = read_logs_once()
266
  except Exception as e:
267
- with open(LOG_FILE, "a") as log:
268
- log.write(f"\n❌ Autostart failed: {e}\n")
269
 
270
  app.launch()
 
10
  import gradio as gr
11
  from transformers import pipeline
12
 
 
# Text log written by the training worker and shown in the Train tab.
LOG_FILE = "train.log"
# Text log written by the dataset generator worker.
GEN_LOG_FILE = "dataset_gen.log"
# Directory handed to train.py as --output.
MODEL_DIR = "trained_model"
# Archive of MODEL_DIR offered for download.
ZIP_FILE = "trained_model.zip"
# Scratch name the archive is written under before being moved to ZIP_FILE.
ZIP_TEMP = f"{ZIP_FILE}.part"
18
 
 
19
  def _human_size(nbytes: int) -> str:
20
+ units = ["B","KB","MB","GB","TB"]; i=0; x=float(nbytes)
21
+ while x>=1024 and i<len(units)-1: x/=1024.0; i+=1
 
 
 
22
  return f"{x:.1f} {units[i]}"
23
 
24
  def _read_file_safely(path: str, fallback: str):
25
  if os.path.exists(path):
26
  try:
27
+ with open(path,"r",encoding="utf-8",errors="ignore") as f: return f.read()
28
+ except: return fallback
 
 
29
  return fallback
30
 
31
  def _zip_folder_atomic(src_dir: str, zip_path: str, tmp_path: str):
32
+ if os.path.exists(tmp_path): os.remove(tmp_path)
33
+ with zipfile.ZipFile(tmp_path,"w",compression=zipfile.ZIP_DEFLATED) as zf:
34
+ for root,_,files in os.walk(src_dir):
 
35
  for fn in files:
36
+ full=os.path.join(root,fn); arc=os.path.relpath(full,src_dir)
37
+ zf.write(full,arcname=arc)
38
+ if os.path.exists(zip_path): os.remove(zip_path)
39
+ os.replace(tmp_path,zip_path)
 
 
40
 
def _download_info_text() -> str:
    """Describe the downloadable model archive for display in the UI.

    Returns:
        "No trained model yet." when ZIP_FILE cannot be stat'ed, otherwise a
        short summary with the archive name, human-readable size and local
        modification time.
    """
    try:
        # One stat call instead of exists()/getsize()/getmtime(): the archive
        # can be deleted or replaced by a training run between those calls.
        info = os.stat(ZIP_FILE)
    except OSError:
        return "No trained model yet."
    size = _human_size(info.st_size)
    mtime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(info.st_mtime))
    return f"*Model ready:* {ZIP_FILE} \n*Size:* {size} \n*Last modified:* {mtime}"
46
 
def ensure_clean_zip():
    """Remove a previous model archive and its scratch file, if present.

    Deletion is best-effort: a file that is missing or cannot be removed is
    skipped so that starting a new training run is never blocked by cleanup.
    """
    for stale in (ZIP_FILE, ZIP_TEMP):
        try:
            os.remove(stale)
        except OSError:
            # Covers "not there" as well as permission/in-use errors.
            pass
52
+
53
+ # --------- Dataset Generator ----------
def start_generation(total, shard_size, out_dir, prefix):
    """Run make_json_dataset.py in a background thread and return at once.

    Args:
        total: Number of samples to generate; a falsy value means 1,000,000.
        shard_size: Rows per shard; a falsy value means 10,000.
        out_dir: Output folder; empty or whitespace-only means
            "json_dataset_v1".
        prefix: Shard file prefix; empty or whitespace-only means "json".

    Returns:
        A status string naming the output folder. Progress and the
        generator's own output are written to GEN_LOG_FILE.
    """
    total = int(total or 1_000_000)
    shard_size = int(shard_size or 10_000)
    # Strip before applying the default so "   " does not become "".
    out_dir = (out_dir or "").strip() or "json_dataset_v1"
    prefix = (prefix or "").strip() or "json"

    # The log carries non-ASCII status markers and is read back as UTF-8, so
    # the encoding is pinned instead of depending on the platform default.
    with open(GEN_LOG_FILE, "w", encoding="utf-8") as log:
        log.write(f"🚧 Generating dataset: total={total}, shard_size={shard_size}, out_dir={out_dir}, prefix={prefix}\n")

    def _worker():
        with open(GEN_LOG_FILE, "a", encoding="utf-8") as log:
            if not os.path.exists("make_json_dataset.py"):
                log.write("❌ make_json_dataset.py not found.\n")
                return
            try:
                p = subprocess.Popen(
                    [
                        "python", "make_json_dataset.py",
                        "--total", str(total),
                        "--shard_size", str(shard_size),
                        "--out_dir", out_dir,
                        "--prefix", prefix,
                    ],
                    stdout=log,
                    stderr=subprocess.STDOUT,
                )
                p.wait()
                log.write(f"\nπŸ”š Generator exited with code {p.returncode}\n")
                if p.returncode == 0:
                    files = sorted(glob.glob(os.path.join(out_dir, "*.jsonl.gz")))
                    log.write(f"βœ… Done. Shards: {len(files)} in {out_dir}\n")
                else:
                    log.write("❌ Generation failed.\n")
            except Exception as e:
                # Thread boundary: record the failure in the log rather than
                # letting the worker die silently.
                log.write(f"\n❌ Exception: {e}\n")

    threading.Thread(target=_worker, daemon=True).start()
    return f"πŸš€ Dataset generation started. Output folder: {out_dir}"
85
+
def read_gen_logs():
    """Return the generator log text, or a placeholder if it cannot be read."""
    placeholder = "Waiting for generator logs..."
    return _read_file_safely(GEN_LOG_FILE, placeholder)
88
 
def list_shards(folder):
    """Summarise the dataset shards found directly inside *folder*.

    Looks for the same file types the folder-training mode consumes
    (.jsonl, .json, .jsonl.gz, .json.gz), so a folder that would train is
    never reported as empty.

    Args:
        folder: Directory to inspect.

    Returns:
        An error message for a missing/invalid folder, "No shards found."
        when nothing matches, otherwise a count followed by the names of up
        to the first ten shards, one per line.
    """
    if not folder or not os.path.isdir(folder):
        return "❌ Provide a valid folder path."
    patterns = ("*.jsonl", "*.json", "*.jsonl.gz", "*.json.gz")
    shards = []
    for pattern in patterns:
        shards.extend(sorted(glob.glob(os.path.join(folder, pattern))))
    if not shards:
        return "No shards found."
    preview = shards[:10]
    lines = [f"Found {len(shards)} shard(s). Showing first {len(preview)}:"]
    lines.extend(f"- {os.path.basename(p)}" for p in preview)
    return "\n".join(lines)
98
+
99
+ # --------- Training ----------
def upload_file(file):
    """Copy an uploaded dataset into uploads/ under a unique name.

    Args:
        file: The upload, either an object exposing the temporary file path
            as ``.name`` or the path itself as a string. ``None`` means
            nothing was uploaded.

    Returns:
        A ``(status message, stored path)`` tuple; the path is "" when no
        file was provided.
    """
    if file is None:
        return "❌ No file uploaded.", ""
    # Accept both a file-like wrapper with .name and a bare path string.
    src = getattr(file, "name", file)
    os.makedirs("uploads", exist_ok=True)
    # A random name keeps repeated uploads from overwriting each other.
    dst = os.path.join("uploads", f"dataset_{uuid.uuid4().hex}.jsonl")
    shutil.copy(src, dst)
    return f"βœ… Uploaded: {os.path.basename(src)} β†’ {dst}", dst
106
 
def _train_single_file(dataset_path: str, log):
    """Run train.py once on *dataset_path*, sending its output to *log*.

    Args:
        dataset_path: File passed to train.py as ``--dataset``.
        log: Open, writable text file; the child's stdout and stderr are
            redirected into it.

    Returns:
        True when train.py exits with status 0, False otherwise.
    """
    # The child writes straight to the underlying file descriptor. Flush
    # whatever the caller has buffered first, otherwise earlier log lines
    # would land after the child's output.
    log.flush()
    proc = subprocess.Popen(
        ["python", "train.py", "--dataset", dataset_path, "--output", MODEL_DIR],
        stdout=log,
        stderr=subprocess.STDOUT,
    )
    proc.wait()
    log.write(f"\n ↳ train.py exited {proc.returncode} for {os.path.basename(dataset_path)}\n")
    return proc.returncode == 0
 
 
 
 
113
 
114
  def _train_worker(dataset_path: str, shards_folder: str):
115
+ with open(LOG_FILE,"w") as log: log.write("πŸ”₯ Starting training (JSON AI)…\n")
116
+ ok=True
117
+ with open(LOG_FILE,"a") as log:
 
 
118
  if shards_folder:
119
  log.write(f"πŸ“‚ Folder mode: {shards_folder}\n")
120
+ paths = sorted(glob.glob(os.path.join(shards_folder,"*.jsonl"))) + \
121
+ sorted(glob.glob(os.path.join(shards_folder,"*.json"))) + \
122
+ sorted(glob.glob(os.path.join(shards_folder,"*.jsonl.gz"))) + \
123
+ sorted(glob.glob(os.path.join(shards_folder,"*.json.gz")))
124
  if not paths:
125
+ log.write("❌ No shards found. Aborting.\n"); ok=False
 
126
  else:
127
+ tmp="tmp_train.jsonl"
128
+ for i,pth in enumerate(paths,1):
129
+ log.write(f"\n[{i}/{len(paths)}] Training on shard: {os.path.basename(pth)}\n")
130
+ if pth.endswith(".gz"):
131
  try:
132
+ with gzip.open(pth,"rt",encoding="utf-8") as rf, open(tmp,"w",encoding="utf-8") as wf:
133
+ for line in rf: wf.write(line)
134
+ shard=tmp
 
135
  except Exception as e:
136
+ log.write(f"❌ Failed to read gz shard: {e}\n"); ok=False; break
 
 
137
  else:
138
+ shard=pth
139
+ if not _train_single_file(shard, log):
140
+ ok=False; break
 
141
  if os.path.exists(tmp):
142
  try: os.remove(tmp)
143
  except: pass
144
  else:
145
  if not dataset_path or not os.path.exists(dataset_path):
146
+ log.write("❌ Please upload a valid dataset first.\n"); ok=False
 
147
  else:
148
+ ok=_train_single_file(dataset_path, log)
149
 
150
  if ok and os.path.isdir(MODEL_DIR):
151
  try:
152
+ time.sleep(0.5)
153
  _zip_folder_atomic(MODEL_DIR, ZIP_FILE, ZIP_TEMP)
154
+ sz=_human_size(os.path.getsize(ZIP_FILE))
155
  log.write(f"\nβœ… Model zipped β†’ {ZIP_FILE} ({sz})\n")
156
  except Exception as e:
157
  log.write(f"\n❌ Zipping failed: {e}\n")
 
160
 
def start_training(dataset_path: str, shards_folder: str):
    """Clear any old archive, launch the training worker thread and return a status line.

    Args:
        dataset_path: Path of a single uploaded dataset file (may be empty).
        shards_folder: Folder of shards to train on (may be empty).
    """
    ensure_clean_zip()
    worker = threading.Thread(
        target=_train_worker,
        args=(dataset_path, shards_folder),
        daemon=True,
    )
    worker.start()
    return "πŸš€ Training started in the background. Use the Refresh buttons to update."
165
 
def read_logs_once():
    """Return the training log text, or a placeholder if it cannot be read."""
    placeholder = "Waiting for logs..."
    return _read_file_safely(LOG_FILE, placeholder)
168
 
169
  def check_download():
170
  if os.path.exists(ZIP_FILE):
 
172
  else:
173
  return gr.update(visible=False, value=None), "No trained model yet."
174
 
175
+ # --------- Test ----------
def upload_test_model_zip(zip_file):
    """Extract an uploaded model archive into a fresh models/test_<id>/ folder.

    Args:
        zip_file: The upload, either an object exposing the temporary file
            path as ``.name`` or the path itself as a string. ``None`` means
            nothing was uploaded.

    Returns:
        A ``(status message, extracted folder)`` tuple; the folder is "" when
        nothing was uploaded or extraction failed.
    """
    if zip_file is None:
        return "❌ No file uploaded.", ""
    # Accept both a file-like wrapper with .name and a bare path string.
    src = getattr(zip_file, "name", zip_file)
    extract_root = os.path.join("models", f"test_{uuid.uuid4().hex}")
    os.makedirs(extract_root, exist_ok=True)
    try:
        with zipfile.ZipFile(src, "r") as zf:
            zf.extractall(extract_root)
    except Exception as e:
        # Do not accumulate empty or half-extracted folders for bad uploads.
        shutil.rmtree(extract_root, ignore_errors=True)
        return f"❌ Failed to extract: {e}", ""
    return f"βœ… Model ZIP extracted to: {extract_root}", extract_root
 
187
  return "Model cleared. Will use trained_model/ if available.", ""
188
 
def generate_response(prompt, uploaded_model_path):
    """Generate text for *prompt* with the best available model.

    Model preference: the uploaded/extracted model folder, then MODEL_DIR,
    then the stock "distilgpt2" checkpoint.

    Args:
        prompt: User prompt; blank input is rejected with a hint.
        uploaded_model_path: Folder of an uploaded model, or "" for none.

    Returns:
        The generated text followed by a note naming the model source, or an
        error message if loading or generation fails.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        if uploaded_model_path and os.path.isdir(uploaded_model_path):
            model_path, src = uploaded_model_path, "(uploaded model)"
        elif os.path.isdir(MODEL_DIR):
            model_path, src = MODEL_DIR, "(trained_model/)"
        else:
            model_path, src = "distilgpt2", "(fallback: distilgpt2)"

        # Use the model's own tokenizer when the folder ships one; otherwise
        # keep the previous behaviour of borrowing distilgpt2's.
        tokenizer = "distilgpt2"
        tokenizer_files = ("tokenizer.json", "tokenizer_config.json")
        if any(os.path.isfile(os.path.join(model_path, name)) for name in tokenizer_files):
            tokenizer = model_path

        gen = pipeline("text-generation", model=model_path, tokenizer=tokenizer)
        out = gen(prompt, max_length=256, do_sample=True, temperature=0.7, truncation=True)[0]["generated_text"]
        return f"{out}\n\nβ€” using {src}"
    except Exception as e:
        # UI boundary: show the failure in the response box instead of raising.
        return f"❌ Error: {e}"
203
 
204
+ # --------- UI ----------
205
+ with gr.Blocks(title="JSON AI Trainer (with Dataset Generator)") as app:
206
+ gr.Markdown("## 🧩 JSON AI Trainer\nGenerate a large JSON dataset, train (single file or folder of shards), download the model, and test.")
207
 
208
  dataset_state = gr.State(value="")
209
  shard_folder_state = gr.State(value="")
210
  test_model_state = gr.State(value="")
211
 
212
+ with gr.Tab("πŸ§ͺ Generate Dataset"):
213
+ with gr.Row():
214
+ total_in = gr.Number(value=1_000_000, label="Total samples")
215
+ shard_in = gr.Number(value=10_000, label="Rows per shard")
216
+ with gr.Row():
217
+ out_dir_in = gr.Textbox(value="json_dataset_v1", label="Output folder")
218
+ prefix_in = gr.Textbox(value="json", label="File prefix")
219
+ with gr.Row():
220
+ gen_btn = gr.Button("πŸš€ Start Generation")
221
+ gen_refresh_btn = gr.Button("πŸ” Refresh Logs")
222
+ gen_status = gr.Textbox(label="Generator Status", interactive=False)
223
+ gen_logs = gr.Textbox(label="Generator Logs", lines=16)
224
+ with gr.Row():
225
+ list_folder = gr.Textbox(value="json_dataset_v1", label="Preview shards in folder")
226
+ list_btn = gr.Button("πŸ‘€ List Shards")
227
+ list_out = gr.Textbox(label="Shard Preview", lines=8)
228
+
229
+ gen_btn.click(fn=start_generation, inputs=[total_in, shard_in, out_dir_in, prefix_in], outputs=gen_status
230
+ ).then(fn=read_gen_logs, outputs=gen_logs)
231
+ gen_refresh_btn.click(fn=read_gen_logs, outputs=gen_logs)
232
+ list_btn.click(fn=list_shards, inputs=list_folder, outputs=list_out)
233
+
234
  with gr.Tab("🧠 Train"):
235
+ gr.Markdown("Upload a single JSON/JSONL file *or* train on a folder of shards (.json, .jsonl, .jsonl.gz, .json.gz).")
236
  with gr.Row():
237
+ file_input = gr.File(label="Upload single dataset file", file_types=[".json",".jsonl"])
238
  upload_btn = gr.Button("πŸ“€ Upload (single file)")
239
  with gr.Row():
240
  shards_folder = gr.Textbox(value="", label="Folder with shards (optional)")
 
253
  download_btn = gr.DownloadButton(label="πŸ“₯ Download Trained Model (.zip)", visible=False, value=None)
254
 
255
  upload_btn.click(fn=upload_file, inputs=file_input, outputs=[status_box, dataset_state])
256
+ use_folder_btn.click(fn=lambda p: ("βœ… Using folder for training." if p.strip() else "❌ Provide a valid folder path.", p.strip()),
257
+ inputs=shards_folder, outputs=[status_box, shard_folder_state])
258
+ start_btn.click(fn=start_training, inputs=[dataset_state, shard_folder_state], outputs=status_box
 
 
 
 
 
 
259
  ).then(fn=read_logs_once, outputs=log_output
260
  ).then(fn=check_download, outputs=[download_btn, download_info])
261
 
 
263
  refresh_dl_btn.click(fn=check_download, outputs=[download_btn, download_info])
264
 
265
  with gr.Tab("πŸš€ Test"):
266
+ gr.Markdown("Upload a model ZIP or use the just-trained model.")
267
  with gr.Row():
268
  test_zip = gr.File(label="Upload Model ZIP", file_types=[".zip"])
269
  load_test_btn = gr.Button("πŸ“¦ Load Uploaded Model ZIP")
270
  clear_test_btn = gr.Button("🧹 Clear Uploaded Model")
271
  test_status = gr.Textbox(label="Test Model Status", interactive=False)
272
+ prompt_input = gr.Textbox(label="Prompt", placeholder='e.g., "Generate JSON Schema for an invoice" or "Fix this JSON: {\'a\':1,}"')
 
 
 
 
273
  test_btn = gr.Button("πŸ” Generate")
274
  response_output = gr.Textbox(label="AI Response", lines=12)
275
 
 
277
  clear_test_btn.click(fn=clear_uploaded_model, outputs=[test_status, test_model_state])
278
  test_btn.click(fn=generate_response, inputs=[prompt_input, test_model_state], outputs=response_output)
279
 
280
+ # Optional: autostart on boot via Space variables
281
+ AUTOSTART = os.getenv("AUTOSTART_TRAIN","0") == "1"
282
+ AUTOSTART_DATASET = os.getenv("AUTOSTART_DATASET","").strip()
283
+ AUTOSTART_SHARDS = os.getenv("AUTOSTART_SHARDS","").strip()
284
  if AUTOSTART and not os.path.exists(".autostart.started"):
285
+ open(".autostart.started","w").close()
286
  try:
287
+ _ = start_training(AUTOSTART_DATASET if AUTOSTART_DATASET else "", AUTOSTART_SHARDS if AUTOSTART_SHARDS else "")
 
288
  _ = read_logs_once()
289
  except Exception as e:
290
+ with open(LOG_FILE,"a") as log: log.write(f"\n❌ Autostart failed: {e}\n")
 
291
 
292
  app.launch()