aagamjtdev commited on
Commit
9b410c3
Β·
1 Parent(s): 005126e

Gradio App

Browse files
Files changed (1) hide show
  1. app.py +743 -0
app.py ADDED
@@ -0,0 +1,743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ # import shutil
3
+ # import tempfile
4
+ # import gradio as gr
5
+ # from huggingface_hub import hf_hub_download, upload_file, HfApi
6
+ # import subprocess
7
+ # import sys
8
+ #
9
+ # # Configuration
10
+ # OUTPUT_DIR = "output2_data"
11
+ # MODEL_FILE = "model_enhanced.pt"
12
+ # VOCAB_FILE = "vocabs_enhanced.pkl"
13
+ # REPO_ID = os.environ.get("SPACE_ID", "heerjtdev/LSTM_CRF") # Replace with your repo ID
14
+ # HF_TOKEN = os.environ.get("HF_TOKEN") # Set this as a secret in your Space settings
15
+ #
16
+ #
17
+ # def download_existing_models():
18
+ # """Download existing model files from the Hugging Face Hub if available."""
19
+ # try:
20
+ # api = HfApi()
21
+ # files = api.list_repo_files(REPO_ID, token=HF_TOKEN)
22
+ #
23
+ # os.makedirs(OUTPUT_DIR, exist_ok=True)
24
+ #
25
+ # downloaded_files = []
26
+ # if MODEL_FILE in files:
27
+ # model_path = hf_hub_download(
28
+ # repo_id=REPO_ID,
29
+ # filename=MODEL_FILE,
30
+ # token=HF_TOKEN,
31
+ # local_dir=OUTPUT_DIR,
32
+ # local_dir_use_symlinks=False
33
+ # )
34
+ # downloaded_files.append(MODEL_FILE)
35
+ #
36
+ # if VOCAB_FILE in files:
37
+ # vocab_path = hf_hub_download(
38
+ # repo_id=REPO_ID,
39
+ # filename=VOCAB_FILE,
40
+ # token=HF_TOKEN,
41
+ # local_dir=OUTPUT_DIR,
42
+ # local_dir_use_symlinks=False
43
+ # )
44
+ # downloaded_files.append(VOCAB_FILE)
45
+ #
46
+ # if downloaded_files:
47
+ # return f"βœ… Downloaded existing files: {', '.join(downloaded_files)}"
48
+ # else:
49
+ # return "ℹ️ No existing model files found in repository."
50
+ # except Exception as e:
51
+ # return f"⚠️ Could not download existing models: {str(e)}"
52
+ #
53
+ #
54
+ # def train_model(dataset_file, progress=gr.Progress()):
55
+ # """Train the model with the uploaded dataset."""
56
+ # if dataset_file is None:
57
+ # return "❌ Please upload a dataset file!", None, None
58
+ #
59
+ # try:
60
+ # # Step 1: Download existing models (if any)
61
+ # progress(0.1, desc="Checking for existing models...")
62
+ # download_status = download_existing_models()
63
+ # yield f"πŸ“₯ {download_status}\n", None, None
64
+ #
65
+ # # Step 2: Save uploaded file
66
+ # progress(0.2, desc="Processing dataset...")
67
+ # dataset_path = dataset_file.name
68
+ # yield f"πŸ“₯ {download_status}\nπŸ“‚ Dataset uploaded: {os.path.basename(dataset_path)}\n", None, None
69
+ #
70
+ # # Step 3: Import and run training
71
+ # progress(0.3, desc="Starting training...")
72
+ # yield f"πŸ“₯ {download_status}\nπŸ“‚ Dataset uploaded: {os.path.basename(dataset_path)}\nπŸš€ Training started...\n", None, None
73
+ #
74
+ # # Import the training function
75
+ # try:
76
+ # # Import your training script (assumes it's named train_model.py)
77
+ # import train_model as tm
78
+ #
79
+ # # Run training
80
+ # progress(0.4, desc="Training in progress...")
81
+ # tm.train_from_json(dataset_path)
82
+ #
83
+ # yield f"πŸ“₯ {download_status}\nπŸ“‚ Dataset uploaded: {os.path.basename(dataset_path)}\nβœ… Training completed!\n", None, None
84
+ #
85
+ # except ImportError:
86
+ # # If direct import fails, try running as subprocess
87
+ # progress(0.4, desc="Training in progress...")
88
+ # result = subprocess.run(
89
+ # [sys.executable, "train_model.py", dataset_path],
90
+ # capture_output=True,
91
+ # text=True
92
+ # )
93
+ #
94
+ # if result.returncode != 0:
95
+ # yield f"❌ Training failed:\n{result.stderr}", None, None
96
+ # return
97
+ #
98
+ # yield f"πŸ“₯ {download_status}\nπŸ“‚ Dataset uploaded: {os.path.basename(dataset_path)}\nβœ… Training completed!\n", None, None
99
+ #
100
+ # # Step 4: Upload trained models to Hub
101
+ # progress(0.8, desc="Uploading models to Hub...")
102
+ # model_path = os.path.join(OUTPUT_DIR, MODEL_FILE)
103
+ # vocab_path = os.path.join(OUTPUT_DIR, VOCAB_FILE)
104
+ #
105
+ # upload_status = []
106
+ # if os.path.exists(model_path):
107
+ # upload_file(
108
+ # path_or_fileobj=model_path,
109
+ # path_in_repo=MODEL_FILE,
110
+ # repo_id=REPO_ID,
111
+ # token=HF_TOKEN
112
+ # )
113
+ # upload_status.append(MODEL_FILE)
114
+ #
115
+ # if os.path.exists(vocab_path):
116
+ # upload_file(
117
+ # path_or_fileobj=vocab_path,
118
+ # path_in_repo=VOCAB_FILE,
119
+ # repo_id=REPO_ID,
120
+ # token=HF_TOKEN
121
+ # )
122
+ # upload_status.append(VOCAB_FILE)
123
+ #
124
+ # # Step 5: Copy to temp directory for download
125
+ # progress(0.9, desc="Preparing downloads...")
126
+ # temp_dir = tempfile.mkdtemp()
127
+ #
128
+ # model_download = None
129
+ # vocab_download = None
130
+ #
131
+ # if os.path.exists(model_path):
132
+ # temp_model = os.path.join(temp_dir, MODEL_FILE)
133
+ # shutil.copy2(model_path, temp_model)
134
+ # model_download = temp_model
135
+ #
136
+ # if os.path.exists(vocab_path):
137
+ # temp_vocab = os.path.join(temp_dir, VOCAB_FILE)
138
+ # shutil.copy2(vocab_path, temp_vocab)
139
+ # vocab_download = temp_vocab
140
+ #
141
+ # progress(1.0, desc="Complete!")
142
+ #
143
+ # final_message = (
144
+ # f"πŸ“₯ {download_status}\n"
145
+ # f"πŸ“‚ Dataset uploaded: {os.path.basename(dataset_path)}\n"
146
+ # f"βœ… Training completed!\n"
147
+ # f"☁️ Uploaded to Hub: {', '.join(upload_status)}\n"
148
+ # f"πŸ“¦ Files ready for download!"
149
+ # )
150
+ #
151
+ # yield final_message, model_download, vocab_download
152
+ #
153
+ # except Exception as e:
154
+ # yield f"❌ Error during training: {str(e)}", None, None
155
+ #
156
+ #
157
+ # def download_models_from_hub():
158
+ # """Download the latest models from the Hugging Face Hub."""
159
+ # try:
160
+ # os.makedirs(OUTPUT_DIR, exist_ok=True)
161
+ #
162
+ # # Download model
163
+ # model_path = hf_hub_download(
164
+ # repo_id=REPO_ID,
165
+ # filename=MODEL_FILE,
166
+ # token=HF_TOKEN,
167
+ # local_dir=OUTPUT_DIR,
168
+ # local_dir_use_symlinks=False,
169
+ # force_download=True
170
+ # )
171
+ #
172
+ # # Download vocab
173
+ # vocab_path = hf_hub_download(
174
+ # repo_id=REPO_ID,
175
+ # filename=VOCAB_FILE,
176
+ # token=HF_TOKEN,
177
+ # local_dir=OUTPUT_DIR,
178
+ # local_dir_use_symlinks=False,
179
+ # force_download=True
180
+ # )
181
+ #
182
+ # # Copy to temp for download
183
+ # temp_dir = tempfile.mkdtemp()
184
+ # temp_model = os.path.join(temp_dir, MODEL_FILE)
185
+ # temp_vocab = os.path.join(temp_dir, VOCAB_FILE)
186
+ #
187
+ # shutil.copy2(model_path, temp_model)
188
+ # shutil.copy2(vocab_path, temp_vocab)
189
+ #
190
+ # return (
191
+ # "βœ… Successfully downloaded models from Hugging Face Hub!",
192
+ # temp_model,
193
+ # temp_vocab
194
+ # )
195
+ # except Exception as e:
196
+ # return f"❌ Error downloading models: {str(e)}", None, None
197
+ #
198
+ #
199
+ # # Create Gradio interface
200
+ # with gr.Blocks(title="MCQ Structure Extraction - Model Training", theme=gr.themes.Soft()) as demo:
201
+ # gr.Markdown(
202
+ # """
203
+ # # πŸŽ“ MCQ Structure Extraction - Model Training
204
+ #
205
+ # Train a BiLSTM-CRF model with deep layout understanding for extracting structured information from MCQ documents.
206
+ #
207
+ # ## πŸ“‹ Instructions:
208
+ # 1. **Upload Dataset**: Provide your unified JSON file containing tokens, bounding boxes, and labels
209
+ # 2. **Train Model**: Click "Start Training" and wait for completion (this may take a while)
210
+ # 3. **Download Models**: Once training is complete, download the trained model and vocabulary files
211
+ #
212
+ # ## πŸ“₯ Or Download Existing Models:
213
+ # If you just want to download the latest trained models from the repository, use the "Download from Hub" button.
214
+ # """
215
+ # )
216
+ #
217
+ # with gr.Tab("Train New Model"):
218
+ # with gr.Row():
219
+ # with gr.Column():
220
+ # dataset_input = gr.File(
221
+ # label="Upload Training Dataset (JSON)",
222
+ # file_types=[".json"],
223
+ # type="filepath"
224
+ # )
225
+ # train_button = gr.Button("πŸš€ Start Training", variant="primary", size="lg")
226
+ #
227
+ # with gr.Column():
228
+ # status_output = gr.Textbox(
229
+ # label="Training Status",
230
+ # lines=8,
231
+ # interactive=False
232
+ # )
233
+ #
234
+ # with gr.Row():
235
+ # model_output = gr.File(label="πŸ“₯ Download Trained Model (.pt)")
236
+ # vocab_output = gr.File(label="πŸ“₯ Download Vocabulary (.pkl)")
237
+ #
238
+ # train_button.click(
239
+ # fn=train_model,
240
+ # inputs=[dataset_input],
241
+ # outputs=[status_output, model_output, vocab_output]
242
+ # )
243
+ #
244
+ # with gr.Tab("Download from Hub"):
245
+ # gr.Markdown(
246
+ # """
247
+ # Download the latest trained models directly from the Hugging Face Hub.
248
+ # This is useful if you want to use pre-trained models without training from scratch.
249
+ # """
250
+ # )
251
+ #
252
+ # download_button = gr.Button("☁️ Download from Hugging Face Hub", variant="primary", size="lg")
253
+ #
254
+ # download_status = gr.Textbox(
255
+ # label="Download Status",
256
+ # lines=3,
257
+ # interactive=False
258
+ # )
259
+ #
260
+ # with gr.Row():
261
+ # hub_model_output = gr.File(label="πŸ“₯ Model File (.pt)")
262
+ # hub_vocab_output = gr.File(label="πŸ“₯ Vocabulary File (.pkl)")
263
+ #
264
+ # download_button.click(
265
+ # fn=download_models_from_hub,
266
+ # outputs=[download_status, hub_model_output, hub_vocab_output]
267
+ # )
268
+ #
269
+ # gr.Markdown(
270
+ # """
271
+ # ---
272
+ # ### βš™οΈ Model Configuration:
273
+ # - **Architecture**: BiLSTM-CRF with spatial attention
274
+ # - **Features**: Word embeddings, character CNN, bounding box encoding, spatial & context features
275
+ # - **Output**: 13 entity labels (Questions, Options, Answers, Images, Section Headings, Passages)
276
+ #
277
+ # ### πŸ“Š Training Details:
278
+ # - Batch Size: 8
279
+ # - Epochs: 10 (with early stopping)
280
+ # - Learning Rate: 5e-4 (with OneCycleLR scheduler)
281
+ # - Optimizer: AdamW with weight decay
282
+ #
283
+ # **Note**: Training requires a GPU for reasonable speed. CPU training is supported but will be significantly slower.
284
+ # """
285
+ # )
286
+ #
287
+ # # Launch the app
288
+ # if __name__ == "__main__":
289
+ # demo.launch()
290
+
291
+
292
+ import os
293
+ import shutil
294
+ import tempfile
295
+ import gradio as gr
296
+ from huggingface_hub import hf_hub_download, upload_file, HfApi
297
+ import sys
298
+
299
+ # Add current directory to path to import train_model
300
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
301
+
302
+ # Configuration
303
+ OUTPUT_DIR = "output_data"
304
+ MODEL_FILE = "model_enhanced.pt"
305
+ VOCAB_FILE = "vocabs_enhanced.pkl"
306
+ CHECKPOINT_FILE = "checkpoint_enhanced.pt"
307
+
308
+ # IMPORTANT: Update this with your actual Hugging Face repository ID
309
+ REPO_ID = os.environ.get("SPACE_ID", "heerjtdev/LSTM_CRF") # Replace with your repo ID
310
+ HF_TOKEN = os.environ.get("HF_TOKEN") # Set this as a secret in your Space settings
311
+
312
+
313
+ def download_existing_models():
314
+ """Download existing model files from the Hugging Face Hub if available."""
315
+ try:
316
+ api = HfApi()
317
+ files = api.list_repo_files(REPO_ID, token=HF_TOKEN)
318
+
319
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
320
+
321
+ downloaded_files = []
322
+
323
+ # Download model file
324
+ if MODEL_FILE in files:
325
+ print(f"πŸ“₯ Downloading {MODEL_FILE} from Hub...")
326
+ model_path = hf_hub_download(
327
+ repo_id=REPO_ID,
328
+ filename=MODEL_FILE,
329
+ token=HF_TOKEN,
330
+ local_dir=OUTPUT_DIR,
331
+ force_download=True # Always get latest version
332
+ )
333
+ downloaded_files.append(MODEL_FILE)
334
+ print(f"βœ… Downloaded {MODEL_FILE}")
335
+
336
+ # Download vocab file
337
+ if VOCAB_FILE in files:
338
+ print(f"πŸ“₯ Downloading {VOCAB_FILE} from Hub...")
339
+ vocab_path = hf_hub_download(
340
+ repo_id=REPO_ID,
341
+ filename=VOCAB_FILE,
342
+ token=HF_TOKEN,
343
+ local_dir=OUTPUT_DIR,
344
+ force_download=True # Always get latest version
345
+ )
346
+ downloaded_files.append(VOCAB_FILE)
347
+ print(f"βœ… Downloaded {VOCAB_FILE}")
348
+
349
+ # Download checkpoint file (optional, for resuming training)
350
+ if CHECKPOINT_FILE in files:
351
+ print(f"πŸ“₯ Downloading {CHECKPOINT_FILE} from Hub...")
352
+ checkpoint_path = hf_hub_download(
353
+ repo_id=REPO_ID,
354
+ filename=CHECKPOINT_FILE,
355
+ token=HF_TOKEN,
356
+ local_dir=OUTPUT_DIR,
357
+ force_download=True
358
+ )
359
+ downloaded_files.append(CHECKPOINT_FILE)
360
+ print(f"βœ… Downloaded {CHECKPOINT_FILE}")
361
+
362
+ if downloaded_files:
363
+ return f"βœ… Downloaded from Hub: {', '.join(downloaded_files)}"
364
+ else:
365
+ return "ℹ️ No existing model files found in repository. Starting fresh."
366
+ except Exception as e:
367
+ error_msg = f"⚠️ Could not download existing models: {str(e)}"
368
+ print(error_msg)
369
+ return error_msg
370
+
371
+
372
+ def train_model(dataset_file, progress=gr.Progress()):
373
+ """Train the model with the uploaded dataset."""
374
+ if dataset_file is None:
375
+ return "❌ Please upload a dataset file!", None, None
376
+
377
+ try:
378
+ # Step 1: Download existing models from Hub (if any) BEFORE training starts
379
+ progress(0.05, desc="Checking Hugging Face Hub for existing models...")
380
+ download_status = download_existing_models()
381
+ status_log = f"{download_status}\n\n"
382
+ yield status_log, None, None
383
+
384
+ # Step 2: Save uploaded file
385
+ progress(0.1, desc="Processing uploaded dataset...")
386
+ dataset_path = dataset_file.name
387
+ status_log += f"πŸ“‚ Dataset uploaded: {os.path.basename(dataset_path)}\n\n"
388
+ yield status_log, None, None
389
+
390
+ # Step 3: Import and run training
391
+ progress(0.15, desc="Initializing training...")
392
+ status_log += "πŸš€ Starting training...\n"
393
+ status_log += "πŸ“Š This may take a while. Training progress will appear in the terminal.\n\n"
394
+ yield status_log, None, None
395
+
396
+ # Import the training module
397
+ try:
398
+ import train_model as tm
399
+ print("=" * 80)
400
+ print("TRAINING STARTED")
401
+ print("=" * 80)
402
+
403
+ # Run training - this will handle model loading internally
404
+ progress(0.2, desc="Training in progress... (check terminal for details)")
405
+ tm.train_from_json(dataset_path)
406
+
407
+ print("=" * 80)
408
+ print("TRAINING COMPLETED")
409
+ print("=" * 80)
410
+
411
+ status_log += "βœ… Training completed successfully!\n\n"
412
+ yield status_log, None, None
413
+
414
+ except ImportError as ie:
415
+ error_msg = f"❌ Failed to import training module: {str(ie)}\n"
416
+ error_msg += "Make sure train_model.py is in the same directory as app.py"
417
+ yield status_log + error_msg, None, None
418
+ return
419
+ except Exception as train_error:
420
+ error_msg = f"❌ Training failed with error:\n{str(train_error)}\n"
421
+ yield status_log + error_msg, None, None
422
+ return
423
+
424
+ # Step 4: Verify files exist
425
+ progress(0.85, desc="Verifying trained model files...")
426
+ model_path = os.path.join(OUTPUT_DIR, MODEL_FILE)
427
+ vocab_path = os.path.join(OUTPUT_DIR, VOCAB_FILE)
428
+ checkpoint_path = os.path.join(OUTPUT_DIR, CHECKPOINT_FILE)
429
+
430
+ files_exist = []
431
+ if os.path.exists(model_path):
432
+ files_exist.append(MODEL_FILE)
433
+ if os.path.exists(vocab_path):
434
+ files_exist.append(VOCAB_FILE)
435
+
436
+ if not files_exist:
437
+ error_msg = "❌ Error: Model files were not created. Check training logs."
438
+ yield status_log + error_msg, None, None
439
+ return
440
+
441
+ status_log += f"βœ… Found trained files: {', '.join(files_exist)}\n\n"
442
+ yield status_log, None, None
443
+
444
+ # Step 5: Upload to Hub
445
+ progress(0.9, desc="Uploading models to Hugging Face Hub...")
446
+ status_log += "☁️ Uploading to Hugging Face Hub...\n"
447
+ yield status_log, None, None
448
+
449
+ upload_status = []
450
+
451
+ if os.path.exists(model_path):
452
+ try:
453
+ upload_file(
454
+ path_or_fileobj=model_path,
455
+ path_in_repo=MODEL_FILE,
456
+ repo_id=REPO_ID,
457
+ token=HF_TOKEN,
458
+ commit_message="Update trained model"
459
+ )
460
+ upload_status.append(MODEL_FILE)
461
+ print(f"βœ… Uploaded {MODEL_FILE} to Hub")
462
+ except Exception as e:
463
+ print(f"⚠️ Failed to upload {MODEL_FILE}: {e}")
464
+
465
+ if os.path.exists(vocab_path):
466
+ try:
467
+ upload_file(
468
+ path_or_fileobj=vocab_path,
469
+ path_in_repo=VOCAB_FILE,
470
+ repo_id=REPO_ID,
471
+ token=HF_TOKEN,
472
+ commit_message="Update vocabulary"
473
+ )
474
+ upload_status.append(VOCAB_FILE)
475
+ print(f"βœ… Uploaded {VOCAB_FILE} to Hub")
476
+ except Exception as e:
477
+ print(f"⚠️ Failed to upload {VOCAB_FILE}: {e}")
478
+
479
+ # Also upload checkpoint for future resume capability
480
+ if os.path.exists(checkpoint_path):
481
+ try:
482
+ upload_file(
483
+ path_or_fileobj=checkpoint_path,
484
+ path_in_repo=CHECKPOINT_FILE,
485
+ repo_id=REPO_ID,
486
+ token=HF_TOKEN,
487
+ commit_message="Update checkpoint"
488
+ )
489
+ upload_status.append(CHECKPOINT_FILE)
490
+ print(f"βœ… Uploaded {CHECKPOINT_FILE} to Hub")
491
+ except Exception as e:
492
+ print(f"⚠️ Failed to upload {CHECKPOINT_FILE}: {e}")
493
+
494
+ if upload_status:
495
+ status_log += f"βœ… Uploaded to Hub: {', '.join(upload_status)}\n\n"
496
+ else:
497
+ status_log += "⚠️ Warning: No files were uploaded to Hub\n\n"
498
+
499
+ yield status_log, None, None
500
+
501
+ # Step 6: Copy to temp directory for download
502
+ progress(0.95, desc="Preparing download files...")
503
+ temp_dir = tempfile.mkdtemp()
504
+
505
+ model_download = None
506
+ vocab_download = None
507
+
508
+ if os.path.exists(model_path):
509
+ temp_model = os.path.join(temp_dir, MODEL_FILE)
510
+ shutil.copy2(model_path, temp_model)
511
+ model_download = temp_model
512
+ print(f"πŸ“¦ Prepared {MODEL_FILE} for download")
513
+
514
+ if os.path.exists(vocab_path):
515
+ temp_vocab = os.path.join(temp_dir, VOCAB_FILE)
516
+ shutil.copy2(vocab_path, temp_vocab)
517
+ vocab_download = temp_vocab
518
+ print(f"πŸ“¦ Prepared {VOCAB_FILE} for download")
519
+
520
+ progress(1.0, desc="Complete!")
521
+
522
+ status_log += "πŸ“¦ Files ready for download below!\n"
523
+ status_log += "\n" + "=" * 50 + "\n"
524
+ status_log += "TRAINING COMPLETE - You can now download the model files\n"
525
+ status_log += "=" * 50
526
+
527
+ yield status_log, model_download, vocab_download
528
+
529
+ except Exception as e:
530
+ error_msg = f"❌ Unexpected error: {str(e)}\n"
531
+ import traceback
532
+ error_msg += f"\nTraceback:\n{traceback.format_exc()}"
533
+ yield error_msg, None, None
534
+
535
+
536
+ def download_models_from_hub():
537
+ """Download the latest models from the Hugging Face Hub."""
538
+ try:
539
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
540
+
541
+ api = HfApi()
542
+ files = api.list_repo_files(REPO_ID, token=HF_TOKEN)
543
+
544
+ downloaded_files = []
545
+
546
+ # Download model
547
+ if MODEL_FILE in files:
548
+ print(f"πŸ“₯ Downloading {MODEL_FILE} from Hub...")
549
+ model_path = hf_hub_download(
550
+ repo_id=REPO_ID,
551
+ filename=MODEL_FILE,
552
+ token=HF_TOKEN,
553
+ local_dir=OUTPUT_DIR,
554
+ force_download=True
555
+ )
556
+ downloaded_files.append(MODEL_FILE)
557
+ else:
558
+ return f"❌ {MODEL_FILE} not found in repository", None, None
559
+
560
+ # Download vocab
561
+ if VOCAB_FILE in files:
562
+ print(f"πŸ“₯ Downloading {VOCAB_FILE} from Hub...")
563
+ vocab_path = hf_hub_download(
564
+ repo_id=REPO_ID,
565
+ filename=VOCAB_FILE,
566
+ token=HF_TOKEN,
567
+ local_dir=OUTPUT_DIR,
568
+ force_download=True
569
+ )
570
+ downloaded_files.append(VOCAB_FILE)
571
+ else:
572
+ return f"❌ {VOCAB_FILE} not found in repository", None, None
573
+
574
+ # Copy to temp for download
575
+ temp_dir = tempfile.mkdtemp()
576
+ temp_model = os.path.join(temp_dir, MODEL_FILE)
577
+ temp_vocab = os.path.join(temp_dir, VOCAB_FILE)
578
+
579
+ shutil.copy2(os.path.join(OUTPUT_DIR, MODEL_FILE), temp_model)
580
+ shutil.copy2(os.path.join(OUTPUT_DIR, VOCAB_FILE), temp_vocab)
581
+
582
+ success_msg = f"βœ… Successfully downloaded from Hub:\n"
583
+ success_msg += f" β€’ {MODEL_FILE}\n"
584
+ success_msg += f" β€’ {VOCAB_FILE}\n\n"
585
+ success_msg += "πŸ“¦ Files are ready to download below!"
586
+
587
+ return success_msg, temp_model, temp_vocab
588
+
589
+ except Exception as e:
590
+ error_msg = f"❌ Error downloading models: {str(e)}\n\n"
591
+ error_msg += f"Make sure:\n"
592
+ error_msg += f"1. REPO_ID is set correctly: {REPO_ID}\n"
593
+ error_msg += f"2. HF_TOKEN is set in Space secrets\n"
594
+ error_msg += f"3. Model files exist in the repository"
595
+ return error_msg, None, None
596
+
597
+
598
+ # Create Gradio interface
599
+ with gr.Blocks(title="MCQ Structure Extraction - Model Training", theme=gr.themes.Soft()) as demo:
600
+ gr.Markdown(
601
+ """
602
+ # πŸŽ“ MCQ Structure Extraction - Model Training
603
+
604
+ Train a BiLSTM-CRF model with deep layout understanding for extracting structured information from MCQ documents.
605
+
606
+ ## πŸ“‹ Instructions:
607
+ 1. **Upload Dataset**: Provide your unified JSON file containing tokens, bounding boxes, and labels
608
+ 2. **Train Model**: Click "Start Training" and wait for completion (this may take a while)
609
+ 3. **Download Models**: Once training is complete, download the trained model and vocabulary files
610
+
611
+ ## πŸ“₯ Or Download Existing Models:
612
+ If you just want to download the latest trained models from the repository, use the "Download from Hub" tab.
613
+
614
+ ---
615
+ """
616
+ )
617
+
618
+ with gr.Tab("πŸš€ Train New Model"):
619
+ gr.Markdown(
620
+ """
621
+ ### Training Process:
622
+ The app will automatically:
623
+ 1. βœ… Download any existing models from Hugging Face Hub (for resuming training)
624
+ 2. 🎯 Train the model on your uploaded dataset
625
+ 3. ☁️ Upload the trained models back to the Hub
626
+ 4. πŸ“₯ Provide download links for the trained files
627
+
628
+ **Note**: Training progress details appear in the terminal/logs. The status box shows major milestones.
629
+ """
630
+ )
631
+
632
+ with gr.Row():
633
+ with gr.Column():
634
+ dataset_input = gr.File(
635
+ label="πŸ“‚ Upload Training Dataset (JSON)",
636
+ file_types=[".json"],
637
+ type="filepath"
638
+ )
639
+ train_button = gr.Button("πŸš€ Start Training", variant="primary", size="lg")
640
+
641
+ with gr.Column():
642
+ status_output = gr.Textbox(
643
+ label="πŸ“Š Training Status",
644
+ lines=12,
645
+ interactive=False,
646
+ show_copy_button=True
647
+ )
648
+
649
+ gr.Markdown("### πŸ“¦ Download Trained Models")
650
+ with gr.Row():
651
+ model_output = gr.File(label="πŸ’Ύ Model File (.pt)")
652
+ vocab_output = gr.File(label="πŸ“š Vocabulary File (.pkl)")
653
+
654
+ train_button.click(
655
+ fn=train_model,
656
+ inputs=[dataset_input],
657
+ outputs=[status_output, model_output, vocab_output]
658
+ )
659
+
660
+ with gr.Tab("☁️ Download from Hub"):
661
+ gr.Markdown(
662
+ """
663
+ ### Download Pre-trained Models
664
+
665
+ Download the latest trained models directly from your Hugging Face repository.
666
+ This is useful if:
667
+ - You want to use pre-trained models without training
668
+ - You need to download models trained in a previous session
669
+ - You want to get the latest version from the Hub
670
+
671
+ The downloaded files can be used for inference with your MCQ extraction pipeline.
672
+ """
673
+ )
674
+
675
+ download_button = gr.Button("☁️ Download Latest Models from Hub", variant="primary", size="lg")
676
+
677
+ download_status = gr.Textbox(
678
+ label="Download Status",
679
+ lines=6,
680
+ interactive=False,
681
+ show_copy_button=True
682
+ )
683
+
684
+ gr.Markdown("### πŸ“¦ Downloaded Files")
685
+ with gr.Row():
686
+ hub_model_output = gr.File(label="πŸ’Ύ Model File (.pt)")
687
+ hub_vocab_output = gr.File(label="πŸ“š Vocabulary File (.pkl)")
688
+
689
+ download_button.click(
690
+ fn=download_models_from_hub,
691
+ outputs=[download_status, hub_model_output, hub_vocab_output]
692
+ )
693
+
694
+ gr.Markdown(
695
+ """
696
+ ---
697
+ ### βš™οΈ Model Configuration:
698
+
699
+ **Architecture:**
700
+ - BiLSTM-CRF with spatial attention mechanism
701
+ - Word embeddings + Character-level CNN
702
+ - Bounding box encoding with MLP
703
+ - Spatial & context feature extraction
704
+ - Learnable positional embeddings
705
+
706
+ **Features Used:**
707
+ - Token text (word-level and character-level)
708
+ - Bounding box coordinates (normalized)
709
+ - Spatial features: vertical spacing, alignment, dimensions (11 features)
710
+ - Context features: surrounding question/option markers (8 features)
711
+
712
+ **Output Labels (13 total):**
713
+ - Questions, Options, Answers, Images, Section Headings, Passages (BIO tagging)
714
+
715
+ **Training Parameters:**
716
+ - Batch Size: 8
717
+ - Epochs: 10 (with early stopping after 10 epochs without improvement)
718
+ - Learning Rate: 5e-4 (AdamW optimizer with OneCycleLR scheduler)
719
+ - Hidden Size: 768
720
+ - Total Parameters: ~15.6M
721
+
722
+ **Hardware Requirements:**
723
+ - GPU recommended for reasonable training speed
724
+ - CPU training supported but significantly slower
725
+
726
+ ---
727
+
728
+ ### πŸ”§ Setup Notes:
729
+
730
+ **Environment Variables Required:**
731
+ - `SPACE_ID`: Your Hugging Face Space/Repo ID (auto-set in Spaces)
732
+ - `HF_TOKEN`: Your Hugging Face write token (set as a secret)
733
+
734
+ **Model Persistence:**
735
+ - Models are automatically saved to `output_data/` directory
736
+ - Best model is uploaded to Hugging Face Hub after each improvement
737
+ - Training can be resumed from checkpoints
738
+ """
739
+ )
740
+
741
+ # Launch the app
742
+ if __name__ == "__main__":
743
+ demo.launch()