Percy3822 committed on
Commit
10bd12a
·
verified ·
1 Parent(s): 6c1e8b3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -0
app.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import subprocess
4
+ import threading
5
+ import uuid
6
+ import time
7
+ import zipfile
8
+ import gzip
9
+ import glob
10
+ import gradio as gr
11
+ from transformers import pipeline
12
+
13
# ---- constants / paths ----
LOG_FILE = "train.log"          # live training transcript shown in the UI
MODEL_DIR = "trained_model"     # output directory written by train.py
ZIP_FILE = "trained_model.zip"  # final downloadable artifact
ZIP_TEMP = ZIP_FILE + ".part" # atomic write to avoid half-written zips
18
+
19
+ # ---- utils ----
20
+ def _human_size(nbytes: int) -> str:
21
+ units = ["B", "KB", "MB", "GB", "TB"]
22
+ i, x = 0, float(nbytes)
23
+ while x >= 1024 and i < len(units) - 1:
24
+ x /= 1024.0
25
+ i += 1
26
+ return f"{x:.1f} {units[i]}"
27
+
28
+ def _read_file_safely(path: str, fallback: str):
29
+ if os.path.exists(path):
30
+ try:
31
+ with open(path, "r", encoding="utf-8", errors="ignore") as f:
32
+ return f.read()
33
+ except:
34
+ return fallback
35
+ return fallback
36
+
37
+ def _zip_folder_atomic(src_dir: str, zip_path: str, tmp_path: str):
38
+ if os.path.exists(tmp_path):
39
+ os.remove(tmp_path)
40
+ with zipfile.ZipFile(tmp_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
41
+ for root, _, files in os.walk(src_dir):
42
+ for fn in files:
43
+ full = os.path.join(root, fn)
44
+ arc = os.path.relpath(full, src_dir)
45
+ zf.write(full, arcname=arc)
46
+ if os.path.exists(zip_path):
47
+ os.remove(zip_path)
48
+ os.replace(tmp_path, zip_path)
49
+
50
def _download_info_text() -> str:
    """Markdown summary of the downloadable model zip (name, size, mtime)."""
    if not os.path.exists(ZIP_FILE):
        return "No trained model yet."
    stat = os.stat(ZIP_FILE)
    size_txt = _human_size(stat.st_size)
    when = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(stat.st_mtime))
    return f"*Model ready:* {ZIP_FILE} \n*Size:* {size_txt} \n*Last modified:* {when}"
56
+
57
def ensure_clean_zip():
    """Best-effort delete of the final and partial zip so a new training run
    never serves a stale artifact.

    Fix: EAFP remove with ``except OSError`` replaces the exists-check plus
    bare ``except:`` (which would also have hidden KeyboardInterrupt etc.).
    Deletion stays best-effort — failures are intentionally ignored.
    """
    for p in (ZIP_FILE, ZIP_TEMP):
        try:
            os.remove(p)
        except OSError:
            # Missing file or transient FS error: nothing actionable here.
            pass
64
+
65
+ # ---- training helpers ----
66
def upload_file(file):
    """Persist an uploaded dataset under uploads/ with a unique name.

    Returns (status message, saved path); the path is "" when nothing was given.
    """
    if file is None:
        return "❌ No file uploaded.", ""
    os.makedirs("uploads", exist_ok=True)
    # uuid4 hex keeps repeated uploads of the same filename from colliding.
    target = os.path.join("uploads", f"dataset_{uuid.uuid4().hex}.jsonl")
    shutil.copy(file.name, target)
    return f"βœ… Uploaded: {os.path.basename(file.name)} β†’ {target}", target
74
+
75
def _train_single_file(dataset_path: str, log):
    """Run train.py once on *dataset_path*, streaming its output into *log*.

    Returns True iff the subprocess exited with status 0.
    """
    cmd = ["python", "train.py", "--dataset", dataset_path, "--output", MODEL_DIR]
    proc = subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT)
    rc = proc.wait()
    log.write(f"\n ↳ train.py exited {rc} for {os.path.basename(dataset_path)}\n")
    return rc == 0
85
+
86
def _train_worker(dataset_path: str, shards_folder: str):
    """Background-thread body: train on one dataset file or a folder of shards,
    then zip MODEL_DIR into ZIP_FILE on success.

    All progress and errors are appended to LOG_FILE so the UI can poll it.
    """
    # Truncate the log so each run starts with a fresh transcript.
    with open(LOG_FILE, "w") as log:
        log.write("πŸ”₯ Starting training (JSON AI)…\n")

    ok = True
    with open(LOG_FILE, "a") as log:
        if shards_folder:
            # Folder mode: train sequentially on every matching shard.
            log.write(f"πŸ“‚ Folder mode: {shards_folder}\n")
            paths = sorted(glob.glob(os.path.join(shards_folder, "*.jsonl"))) + \
                    sorted(glob.glob(os.path.join(shards_folder, "*.json"))) + \
                    sorted(glob.glob(os.path.join(shards_folder, "*.jsonl.gz"))) + \
                    sorted(glob.glob(os.path.join(shards_folder, "*.json.gz")))
            if not paths:
                log.write("❌ No shards found (*.jsonl / *.json / *.jsonl.gz / *.json.gz). Aborting.\n")
                ok = False
            else:
                tmp = "tmp_train.jsonl"  # scratch file holding a decompressed .gz shard
                for i, p in enumerate(paths, 1):
                    log.write(f"\n[{i}/{len(paths)}] Training on shard: {os.path.basename(p)}\n")
                    if p.endswith(".gz"):
                        # Decompress the shard so train.py always reads plain text.
                        try:
                            with gzip.open(p, "rt", encoding="utf-8") as rf, open(tmp, "w", encoding="utf-8") as wf:
                                for line in rf:
                                    wf.write(line)
                            shard_path = tmp
                        except Exception as e:
                            log.write(f"❌ Failed to read gz shard: {e}\n")
                            ok = False
                            break
                    else:
                        shard_path = p
                    # Abort the whole run at the first failing shard.
                    if not _train_single_file(shard_path, log):
                        ok = False
                        break
                # Best-effort cleanup of the scratch file.
                if os.path.exists(tmp):
                    try: os.remove(tmp)
                    except: pass
        else:
            # Single-file mode: require a previously uploaded dataset.
            if not dataset_path or not os.path.exists(dataset_path):
                log.write("❌ Please upload a valid dataset first.\n")
                ok = False
            else:
                ok = _train_single_file(dataset_path, log)

        if ok and os.path.isdir(MODEL_DIR):
            try:
                time.sleep(0.5) # settle delay for FS
                _zip_folder_atomic(MODEL_DIR, ZIP_FILE, ZIP_TEMP)
                sz = _human_size(os.path.getsize(ZIP_FILE))
                log.write(f"\nβœ… Model zipped β†’ {ZIP_FILE} ({sz})\n")
            except Exception as e:
                log.write(f"\n❌ Zipping failed: {e}\n")
        else:
            log.write("\n❌ Training failed; no zip created.\n")
141
def start_training(dataset_path: str, shards_folder: str):
    """Launch the training worker in the background and return a status line."""
    ensure_clean_zip()
    worker = threading.Thread(
        target=_train_worker,
        args=(dataset_path, shards_folder),
        daemon=True,  # do not block interpreter shutdown
    )
    worker.start()
    return "πŸš€ Training started in the background. Use the Refresh buttons to update."
146
+
147
def read_logs_once():
    """Return the current training log text, or a placeholder before it exists."""
    placeholder = "Waiting for logs..."
    return _read_file_safely(LOG_FILE, placeholder)
149
+
150
def check_download():
    """Return (DownloadButton update, info markdown) reflecting zip availability."""
    if not os.path.exists(ZIP_FILE):
        return gr.update(visible=False, value=None), "No trained model yet."
    return gr.update(visible=True, value=ZIP_FILE), _download_info_text()
155
+
156
+ # ---- test helpers ----
157
def upload_test_model_zip(zip_file):
    """Extract an uploaded model ZIP into models/test_<uuid>/ for the Test tab.

    Returns (status message, extraction dir); the dir is "" on failure.
    """
    if zip_file is None:
        return "❌ No file uploaded.", ""
    extract_root = os.path.join("models", f"test_{uuid.uuid4().hex}")
    os.makedirs(extract_root, exist_ok=True)
    try:
        # NOTE(review): extractall trusts member names — a crafted zip could
        # escape extract_root (zip-slip). Consider validating paths; TODO confirm.
        with zipfile.ZipFile(zip_file.name, "r") as zf:
            zf.extractall(extract_root)
    except Exception as e:
        return f"❌ Failed to extract: {e}", ""
    return f"βœ… Model ZIP extracted to: {extract_root}", extract_root
169
+
170
def clear_uploaded_model():
    """Reset the uploaded-model state so generation falls back to trained_model/."""
    status = "Model cleared. Will use trained_model/ if available."
    return status, ""
172
+
173
def generate_response(prompt, uploaded_model_path):
    """Generate text for *prompt*, preferring the uploaded model, then
    trained_model/, then the distilgpt2 fallback. Returns the text plus a
    source tag, or an error string.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    try:
        if uploaded_model_path and os.path.isdir(uploaded_model_path):
            model_path, src = uploaded_model_path, "(uploaded model)"
        elif os.path.isdir(MODEL_DIR):
            model_path, src = MODEL_DIR, "(trained_model/)"
        else:
            model_path, src = "distilgpt2", "(fallback: distilgpt2)"
        generator = pipeline("text-generation", model=model_path, tokenizer="distilgpt2")
        samples = generator(prompt, max_length=256, do_sample=True, temperature=0.7, truncation=True)
        text = samples[0]["generated_text"]
        return f"{text}\n\nβ€” using {src}"
    except Exception as e:
        return f"❌ Error: {e}"
191
+
192
# ---- UI ----
# Gradio app: a Train tab (upload/train/logs/download) and a Test tab
# (load a model zip and run prompts).
with gr.Blocks(title="JSON AI Trainer") as app:
    gr.Markdown("## 🧩 JSON AI Trainer\nUpload a dataset (JSONL/JSON), train in background, download the model, and test JSON-focused prompts.")

    # Cross-callback state: uploaded dataset path, shard folder, extracted test-model dir.
    dataset_state = gr.State(value="")
    shard_folder_state = gr.State(value="")
    test_model_state = gr.State(value="")

    with gr.Tab("🧠 Train"):
        gr.Markdown("Upload a single JSONL/JSON *or* provide a folder of shards (.jsonl, .json, .jsonl.gz, .json.gz).")
        with gr.Row():
            file_input = gr.File(label="Upload single dataset file", file_types=[".jsonl", ".json"])
            upload_btn = gr.Button("πŸ“€ Upload (single file)")
        with gr.Row():
            shards_folder = gr.Textbox(value="", label="Folder with shards (optional)")
            use_folder_btn = gr.Button("πŸ“‚ Use Folder For Training")
        status_box = gr.Textbox(label="Status", interactive=False)

        with gr.Row():
            start_btn = gr.Button("πŸš€ Start Training")
            refresh_btn = gr.Button("πŸ” Refresh Logs")
            refresh_dl_btn = gr.Button("πŸ“¦ Refresh Download Area")
        log_output = gr.Textbox(label="πŸ“œ Training Logs", lines=18)

        with gr.Group():
            gr.Markdown("### πŸ“¦ Trained Model")
            download_info = gr.Markdown(value="No trained model yet.")
            download_btn = gr.DownloadButton(label="πŸ“₯ Download Trained Model (.zip)", visible=False, value=None)

        upload_btn.click(fn=upload_file, inputs=file_input, outputs=[status_box, dataset_state])
        use_folder_btn.click(
            fn=lambda p: ("βœ… Using folder for training." if p.strip() else "❌ Provide a valid folder path.", p.strip()),
            inputs=shards_folder,
            outputs=[status_box, shard_folder_state]
        )
        # Chain: start training, then refresh the logs and download area once.
        start_btn.click(
            fn=start_training,
            inputs=[dataset_state, shard_folder_state],
            outputs=status_box
        ).then(fn=read_logs_once, outputs=log_output
        ).then(fn=check_download, outputs=[download_btn, download_info])

        refresh_btn.click(fn=read_logs_once, outputs=log_output)
        refresh_dl_btn.click(fn=check_download, outputs=[download_btn, download_info])

    with gr.Tab("πŸš€ Test"):
        gr.Markdown("Upload a model ZIP or use the just-trained model. Try prompts like β€œFix this JSON”, β€œGenerate JSON schema for …”")
        with gr.Row():
            test_zip = gr.File(label="Upload Model ZIP", file_types=[".zip"])
            load_test_btn = gr.Button("πŸ“¦ Load Uploaded Model ZIP")
            clear_test_btn = gr.Button("🧹 Clear Uploaded Model")
        test_status = gr.Textbox(label="Test Model Status", interactive=False)

        prompt_input = gr.Textbox(
            label="Prompt",
            placeholder='e.g., "Generate valid JSON for a product with id, name, price, tags (array of strings)"'
        )
        test_btn = gr.Button("πŸ” Generate")
        response_output = gr.Textbox(label="AI Response", lines=12)

        load_test_btn.click(fn=upload_test_model_zip, inputs=test_zip, outputs=[test_status, test_model_state])
        clear_test_btn.click(fn=clear_uploaded_model, outputs=[test_status, test_model_state])
        test_btn.click(fn=generate_response, inputs=[prompt_input, test_model_state], outputs=response_output)

# Optional auto-start via env vars
AUTOSTART = os.getenv("AUTOSTART_TRAIN", "0") == "1"
AUTOSTART_DATASET = os.getenv("AUTOSTART_DATASET", "").strip()
AUTOSTART_SHARDS = os.getenv("AUTOSTART_SHARDS", "").strip()
# Marker file prevents re-triggering training on every process restart.
if AUTOSTART and not os.path.exists(".autostart.started"):
    open(".autostart.started", "w").close()
    try:
        _ = start_training(AUTOSTART_DATASET if AUTOSTART_DATASET else "",
                           AUTOSTART_SHARDS if AUTOSTART_SHARDS else "")
        _ = read_logs_once()
    except Exception as e:
        with open(LOG_FILE, "a") as log:
            log.write(f"\n❌ Autostart failed: {e}\n")

app.launch()