Spaces:
Sleeping
Sleeping
File size: 10,582 Bytes
c7a125e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
import glob
import os
import shutil
import subprocess
import sys
import tempfile
from datetime import datetime

import gradio as gr
# --- Configuration for the hybrid LayoutLMv3 + BiLSTM + CRF model ---
# Training entry point; train_model() launches it as a child process.
TRAINING_SCRIPT = "train_hybrid.py"

# Where this app looks for the finished checkpoint once training exits.
# The script is not told this location on its command line, so it is
# assumed to write here on its own — keep the two in sync.
MODEL_OUTPUT_DIR = "checkpoints"
MODEL_FILE_NAME = "layoutlmv3_bilstm_crf_hybrid.pth"
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
# ----------------------------------------------------------------
def retrieve_model():
    """Check for the trained checkpoint and stage a copy of it for download.

    Useful when the training job finished server-side but the client
    connection timed out before the streamed log reported the result.

    Returns:
        A ``(log_text, download_path, download_button_update)`` tuple for the
        ``[log_output, model_path_state, download_btn]`` outputs:

        * model missing  -> ``download_path`` is ``None`` and the button is hidden;
        * copy succeeded -> path of a timestamped copy in the system temp dir;
        * copy failed    -> ``MODEL_FILE_PATH`` itself (the same fallback
          ``train_model`` uses), so an existing model is still downloadable.
    """
    # One clock read so the log header and the copy's file name agree.
    now = datetime.now()
    header = f"--- Model Status Check: {now.strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    if not os.path.exists(MODEL_FILE_PATH):
        log_output = (
            header
            + f"❌ Model file not found at {MODEL_FILE_PATH}.\n"
            + "Training may still be running or it failed. Check back later."
        )
        return log_output, None, gr.Button(visible=False)

    file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)  # Size in MB

    # Copy to a simple location that Gradio can reliably serve. The timestamp
    # keeps repeated checks from overwriting a copy that may be mid-download.
    temp_model_path = os.path.join(
        tempfile.gettempdir(),
        f"hybrid_model_recovered_{now.strftime('%Y%m%d_%H%M%S')}.pth",
    )
    try:
        shutil.copy2(MODEL_FILE_PATH, temp_model_path)
    except OSError as e:
        # The model exists, so offer it directly rather than hiding it.
        log_output = (
            header
            + f"⚠️ Trained model found, but could not prepare a download copy: {e}\n"
            + f"📂 Offering the original file instead: {MODEL_FILE_PATH}\n\n"
            + "⬇️ Click the '📥 Download Hybrid Model' button below to save your model."
        )
        return log_output, MODEL_FILE_PATH, gr.Button(visible=True)

    log_output = (
        header
        + "🎉 SUCCESS! The Hybrid LayoutLMv3+BiLSTM+CRF model was found.\n"
        + f"📦 Model file: {MODEL_FILE_PATH}\n"
        + f"📊 Model size: {file_size:.2f} MB\n"
        + f"📂 Download path prepared: {temp_model_path}\n\n"
        + "⬇️ Click the '📥 Download Hybrid Model' button below to save your model."
    )
    return log_output, temp_model_path, gr.Button(visible=True)
def clear_memory(dataset_file: gr.File):
    """Delete training artifacts so the next run starts from a clean slate.

    Runs three independent best-effort steps. A failure in one is reported in
    the log and does not stop the others:

    1. remove the checkpoint directory ``MODEL_OUTPUT_DIR``;
    2. remove download copies of the model staged in the system temp dir
       (files matching ``hybrid_model_*.pth``);
    3. remove the uploaded dataset file, if one is tracked.

    Args:
        dataset_file: Upload component value — an object with a ``.name`` path
            attribute or a plain path; ``None`` when nothing is uploaded.

    Returns:
        ``(log_text, None, hidden_button, None)`` for the ``[log_output,
        model_path_state, download_btn, model_download]`` outputs. This also
        forgets the stored model path and empties the download component.
    """
    log_output = f"--- Memory Clear Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    # 1. Clear Model Checkpoints Directory
    if os.path.exists(MODEL_OUTPUT_DIR):
        try:
            shutil.rmtree(MODEL_OUTPUT_DIR)
            log_output += f"✅ Successfully deleted model directory: {MODEL_OUTPUT_DIR}\n"
        except OSError as e:
            log_output += f"❌ ERROR deleting model directory {MODEL_OUTPUT_DIR}: {e}\n"
    else:
        log_output += f"ℹ️ Model directory not found: {MODEL_OUTPUT_DIR} (Nothing to delete)\n"

    # 2. Clear staged download copies. retrieve_model/train_model copy the whole
    #    checkpoint into the temp dir under a fresh timestamped name each time
    #    (hybrid_model_<ts>.pth / hybrid_model_recovered_<ts>.pth). Without
    #    this step those full-size copies pile up after the checkpoints are gone.
    temp_dir = tempfile.gettempdir()
    staged_copies = sorted(glob.glob(os.path.join(temp_dir, "hybrid_model_*.pth")))
    removed = 0
    for staged_path in staged_copies:
        try:
            os.remove(staged_path)
            removed += 1
        except OSError as e:
            log_output += f"❌ ERROR deleting staged model copy {staged_path}: {e}\n"
    if removed:
        log_output += f"✅ Removed {removed} staged model copy/copies from {temp_dir}\n"
    elif not staged_copies:
        log_output += f"ℹ️ No staged model copies found in {temp_dir}\n"

    # 3. Clear Uploaded Dataset File (temporary upload cleanup)
    if dataset_file is None:
        log_output += "ℹ️ No dataset file currently tracked for deletion.\n"
    else:
        input_path = dataset_file.name if hasattr(dataset_file, "name") else str(dataset_file)
        if not os.path.exists(input_path):
            log_output += f"ℹ️ Uploaded dataset file not found at {input_path}.\n"
        else:
            try:
                os.remove(input_path)
                log_output += f"✅ Successfully deleted uploaded dataset file: {input_path}\n"
            except OSError as e:
                log_output += f"❌ ERROR deleting dataset file {input_path}: {e}\n"

    log_output += f"--- Memory Clear Complete: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
    log_output += "✨ Files and checkpoints have been removed. You can now start a fresh training run."
    return log_output, None, gr.Button(visible=False), None
def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """Launch ``TRAINING_SCRIPT`` as a subprocess and stream its log to the UI.

    This is a generator. Every ``yield`` is a ``(log_text, download_path,
    download_button_update)`` tuple for the ``[log_output, model_path_state,
    download_btn]`` outputs. ``download_path`` stays ``None`` and the download
    button stays hidden until a run exits with code 0 and ``MODEL_FILE_PATH``
    exists.

    Args:
        dataset_file: Upload component value — an object with a ``.name`` path
            attribute or a plain path; ``None`` when nothing was uploaded.
        batch_size: Forwarded to the script as ``--batch_size``.
        epochs: Forwarded as ``--epochs``.
        lr: Forwarded as ``--lr``.
        max_len: Forwarded as ``--max_len``.
        progress: Progress tracker injected by Gradio.
    """
    # 1. Validate inputs before creating directories or spawning anything.
    if dataset_file is None:
        yield "❌ ERROR: Please upload a file.", None, gr.Button(visible=False)
        return
    input_path = dataset_file.name if hasattr(dataset_file, "name") else str(dataset_file)
    if not os.path.exists(input_path):
        yield f"❌ ERROR: Uploaded file not found at {input_path}.", None, gr.Button(visible=False)
        return
    # A missing script does NOT make Popen raise FileNotFoundError: the
    # interpreter starts, prints "can't open file" and exits non-zero. Check up
    # front so the user gets a clear message instead of a bare return code.
    if not os.path.isfile(TRAINING_SCRIPT):
        yield f"❌ ERROR: '{TRAINING_SCRIPT}' not found.", None, gr.Button(visible=False)
        return

    # 2. Setup: create output directory and start the log.
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
    progress(0.1, desc="Initializing Hybrid Model Training...")
    started_at = datetime.now()
    log_output = f"--- Training Started: {started_at.strftime('%Y-%m-%d %H:%M:%S')} ---\n"
    log_output += "🤖 Architecture: LayoutLMv3 + BiLSTM + CRF\n"

    # 3. Construct the subprocess command (argument list, no shell involved).
    command = [
        sys.executable,
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", str(batch_size),
        "--epochs", str(epochs),
        "--lr", str(lr),
        "--max_len", str(max_len),
    ]
    log_output += f"Executing command: {' '.join(command)}\n\n"
    yield log_output, None, gr.Button(visible=False)

    # 4. Run the training script.
    # When stdout is a pipe, Python block-buffers it, so the log would arrive in
    # large bursts. PYTHONUNBUFFERED makes the child flush every print().
    # PYTHONIOENCODING pins the child's output to the UTF-8 decoded below.
    child_env = {**os.environ, "PYTHONUNBUFFERED": "1", "PYTHONIOENCODING": "utf-8"}
    try:
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",  # a stray undecodable byte must not abort the stream
            bufsize=1,
            env=child_env,
        )
    except OSError as e:
        yield log_output + f"\n❌ ERROR: could not start the training process: {e}", None, gr.Button(visible=False)
        return

    try:
        # NOTE(review): if Gradio closes this generator early (e.g. on client
        # disconnect), the pipe's read end is released and the child's next
        # write may fail. Confirm whether training is meant to survive that;
        # retrieve_model's docstring suggests it is.
        for line in process.stdout:
            log_output += line
            print(line, end="")  # mirror to the server-side log
            yield log_output, None, gr.Button(visible=False)
        process.stdout.close()
        return_code = process.wait()
    except Exception as e:
        # Nothing drains the pipe any more. Stop the child rather than leave it
        # blocked forever on a full pipe buffer.
        process.kill()
        process.wait()
        yield log_output + f"\n❌ Unexpected Error: {e}", None, gr.Button(visible=False)
        return

    # 5. Check completion
    if return_code != 0:
        log_output += f"\n❌ TRAINING FAILED with return code {return_code}\n"
        yield log_output, None, gr.Button(visible=False)
        return

    log_output += "\n" + "=" * 60 + "\n"
    log_output += "✅ HYBRID TRAINING COMPLETE!\n"
    log_output += "=" * 60 + "\n"

    if not os.path.exists(MODEL_FILE_PATH):
        log_output += f"\n❌ Error: Training finished but {MODEL_FILE_PATH} was not found."
        yield log_output, None, gr.Button(visible=False)
        return

    file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)
    log_output += f"\n📦 Model file found: {MODEL_FILE_PATH} ({file_size:.2f} MB)"
    # A checkpoint older than this run was not written by it. It is probably
    # left over from an earlier training, so flag it instead of silently
    # presenting it as the new model.
    if os.path.getmtime(MODEL_FILE_PATH) < started_at.timestamp():
        log_output += (
            "\n⚠️ Warning: this file is older than the current run and may be a "
            "stale checkpoint from a previous training."
        )

    # Copy for download to a location Gradio can reliably serve.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    temp_model_path = os.path.join(tempfile.gettempdir(), f"hybrid_model_{timestamp}.pth")
    try:
        shutil.copy2(MODEL_FILE_PATH, temp_model_path)
        download_path = temp_model_path
    except OSError as e:
        log_output += f"\n⚠️ Copy failed: {e}, using original path"
        download_path = MODEL_FILE_PATH

    log_output += "\n\n⬇️ Click the '📥 Download Hybrid Model' button below."
    yield log_output, download_path, gr.Button(visible=True)
# --- Gradio Interface Setup ---
# Kept at column 0 with blank lines between sections. Without the blank lines,
# Markdown treats the "Instructions"/"Note" lines as lazy continuations of
# list item 3 and renders them inside the list.
APP_DESCRIPTION = """
**Architecture:** This app trains a state-of-the-art stack:

1. **LayoutLMv3** (Visual & Textual Embeddings)
2. **Bi-LSTM** (Sequence Context Modeling)
3. **CRF** (Label Consistency Enforcement)

**Instructions:** Upload your Label Studio JSON, set parameters, and train.

**Note:** This model is slower to train than standard LayoutLM but typically achieves higher accuracy on complex layouts.
"""

with gr.Blocks(title="Hybrid LayoutLM Training", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧬 Hybrid LayoutLMv3 + BiLSTM + CRF Training")
    gr.Markdown(APP_DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Dataset")
            file_input = gr.File(label="Upload Label Studio JSON", file_types=[".json"])

            gr.Markdown("### ⚙️ Hyperparameters")
            batch_size_input = gr.Slider(1, 16, value=4, step=1, label="Batch Size")
            epochs_input = gr.Slider(1, 10, value=5, step=1, label="Epochs")
            lr_input = gr.Number(value=2e-5, label="Learning Rate (Backbone)", info="LSTM/CRF head uses 1e-4")
            max_len_input = gr.Slider(128, 512, value=512, step=128, label="Max Seq Len")

            train_button = gr.Button("🔥 Start Hybrid Training", variant="primary", size="lg")
            check_button = gr.Button("🔍 Check Status / Recover Model", variant="secondary")
            clear_button = gr.Button("🧹 Clear Files", variant="stop")

        with gr.Column(scale=2):
            log_output = gr.Textbox(
                label="Training Logs", lines=25, autoscroll=True, show_copy_button=True,
                placeholder="Logs will appear here..."
            )
            download_btn = gr.Button("📥 Download Hybrid Model", variant="primary", size="lg", visible=False)

            # Path of the staged checkpoint, plus the File component that serves
            # it. The File component stays visible: clicking download_btn only
            # fills it with the path, and the user downloads from it.
            model_path_state = gr.State(value=None)
            model_download = gr.File(label="Download", interactive=False, visible=True)

    # Actions
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_path_state, download_btn]
    )
    check_button.click(
        fn=retrieve_model,
        inputs=[],
        outputs=[log_output, model_path_state, download_btn]
    )
    clear_button.click(
        fn=clear_memory,
        inputs=[file_input],
        outputs=[log_output, model_path_state, download_btn, model_download]
    )
    download_btn.click(
        fn=lambda path: path,
        inputs=[model_path_state],
        outputs=[model_download]
    )

if __name__ == "__main__":
    demo.launch()