import glob
import os
import queue
import shutil
import subprocess
import sys
import tempfile
import threading
from datetime import datetime

import gradio as gr

# --- CONFIGURATION UPDATED FOR HYBRID MODEL ---
TRAINING_SCRIPT = "train_hybrid.py"
MODEL_OUTPUT_DIR = "checkpoints"
MODEL_FILE_NAME = "layoutlmv3_bilstm_crf_hybrid.pth"
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)

# Filename prefix of the timestamped model copies placed in the system temp
# directory for download (both the post-training and the recovery copies).
DOWNLOAD_COPY_PREFIX = "hybrid_model_"
# ----------------------------------------------------------------


def _now(fmt="%Y-%m-%d %H:%M:%S"):
    """Return the current local time formatted with *fmt*."""
    return datetime.now().strftime(fmt)


def _upload_path(dataset_file):
    """Return the filesystem path of a Gradio upload value.

    Accepts either an object exposing the path as ``.name`` or a plain path.
    """
    return dataset_file.name if hasattr(dataset_file, "name") else str(dataset_file)


def _prepare_download_copy(prefix):
    """Copy the trained model into the temp dir and return the copy's path.

    The copy is named ``<prefix><YYYYmmdd_HHMMSS>.pth`` so Gradio can serve
    it from a location it reliably has access to. Propagates any error raised
    by ``shutil.copy2``.
    """
    target = os.path.join(tempfile.gettempdir(), f"{prefix}{_now('%Y%m%d_%H%M%S')}.pth")
    shutil.copy2(MODEL_FILE_PATH, target)
    return target


def retrieve_model():
    """Check for the final model file and prepare it for download.

    Useful for when the training job finishes server-side but the client
    connection has timed out.

    Returns:
        ``(log_text, download_path_or_None, gr.Button)``; the button is visible
        only when a download copy was successfully prepared.
    """
    try:
        file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)  # Size in MB
    except FileNotFoundError:
        log_output = (
            f"--- Model Status Check: {_now()} ---\n"
            f"❌ Model file not found at {MODEL_FILE_PATH}.\n"
            "Training may still be running or it failed. Check back later."
        )
        return log_output, None, gr.Button(visible=False)

    try:
        download_path = _prepare_download_copy(f"{DOWNLOAD_COPY_PREFIX}recovered_")
    except Exception as e:
        log_output = (
            "--- Model Status Check FAILED ---\n"
            f"⚠️ Trained model found, but could not prepare for download: {e}\n"
            f"📁 Original Path: {MODEL_FILE_PATH}. Try again or check Space logs."
        )
        return log_output, None, gr.Button(visible=False)

    log_output = (
        f"--- Model Status Check: {_now()} ---\n"
        "🎉 SUCCESS! The Hybrid LayoutLMv3+BiLSTM+CRF model was found.\n"
        f"📦 Model file: {MODEL_FILE_PATH}\n"
        f"📊 Model size: {file_size:.2f} MB\n"
        f"🔗 Download path prepared: {download_path}\n\n"
        "⬇️ Click the '📥 Download Model' button below to save your model."
    )
    return log_output, download_path, gr.Button(visible=True)


def clear_memory(dataset_file: gr.File):
    """Delete the checkpoints dir, the uploaded dataset and temp download copies.

    Args:
        dataset_file: Current value of the upload component, or ``None``.

    Returns:
        ``(log_text, None, hidden gr.Button, None)``; the ``None`` values reset
        the stored model path and the download file component.
    """
    log_output = f"--- Memory Clear Started: {_now()} ---\n"

    # 1. Clear Model Checkpoints Directory
    if os.path.exists(MODEL_OUTPUT_DIR):
        try:
            shutil.rmtree(MODEL_OUTPUT_DIR)
            log_output += f"✅ Successfully deleted model directory: {MODEL_OUTPUT_DIR}\n"
        except Exception as e:
            log_output += f"❌ ERROR deleting model directory {MODEL_OUTPUT_DIR}: {e}\n"
    else:
        log_output += f"ℹ️ Model directory not found: {MODEL_OUTPUT_DIR} (Nothing to delete)\n"

    # 2. Clear Uploaded Dataset File (Temporary file cleanup)
    if dataset_file is None:
        log_output += "ℹ️ No dataset file currently tracked for deletion.\n"
    else:
        input_path = _upload_path(dataset_file)
        if os.path.exists(input_path):
            try:
                os.remove(input_path)
                log_output += f"✅ Successfully deleted uploaded dataset file: {input_path}\n"
            except Exception as e:
                log_output += f"❌ ERROR deleting dataset file {input_path}: {e}\n"
        else:
            log_output += f"ℹ️ Uploaded dataset file not found at {input_path}.\n"

    # 3. Clear temp-dir model copies: every training run and every recovery
    # click writes a new full-size copy, which would otherwise accumulate.
    pattern = os.path.join(glob.escape(tempfile.gettempdir()), f"{DOWNLOAD_COPY_PREFIX}*.pth")
    for copy_path in glob.glob(pattern):
        try:
            os.remove(copy_path)
            log_output += f"✅ Removed temporary download copy: {copy_path}\n"
        except OSError as e:
            log_output += f"❌ ERROR deleting download copy {copy_path}: {e}\n"

    log_output += f"--- Memory Clear Complete: {_now()} ---\n"
    log_output += "✨ Files and checkpoints have been removed. You can now start a fresh training run."
    return log_output, None, gr.Button(visible=False), None


def _stream_output(stream, sink):
    """Echo each line of *stream* to stdout and put it on *sink*; ``None`` marks EOF.

    Runs on a daemon thread so the child's pipe keeps being drained even if
    the UI generator stops being iterated (e.g. the client disconnected);
    otherwise the training process could stall once the pipe buffer fills.
    """
    try:
        for line in iter(stream.readline, ""):
            print(line, end="")
            sink.put(line)
    finally:
        stream.close()
        sink.put(None)


def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float,
                max_len: int, progress=gr.Progress()):
    """Run ``TRAINING_SCRIPT`` in a subprocess and stream its log to the UI.

    Args:
        dataset_file: Uploaded Label Studio JSON (upload component value).
        batch_size, epochs, max_len: Integer hyperparameters from sliders.
        lr: Backbone learning rate.
        progress: Gradio progress tracker (injected by Gradio).

    Yields:
        ``(log_text, download_path_or_None, gr.Button)`` tuples, one per log
        line while training runs, then a final tuple whose button is visible
        only if the model file was produced.
    """
    # 1. Setup: Create output directory
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    # 2. File Handling
    if dataset_file is None:
        yield "❌ ERROR: Please upload a file.", None, gr.Button(visible=False)
        return

    input_path = _upload_path(dataset_file)
    if not os.path.exists(input_path):
        yield f"❌ ERROR: Uploaded file not found at {input_path}.", None, gr.Button(visible=False)
        return

    # A missing script makes the interpreter exit with status 2 rather than
    # raise FileNotFoundError, so check for it explicitly.
    if not os.path.isfile(TRAINING_SCRIPT):
        yield f"❌ ERROR: '{TRAINING_SCRIPT}' not found.", None, gr.Button(visible=False)
        return

    progress(0.1, desc="Initializing Hybrid Model Training...")
    log_output = f"--- Training Started: {_now()} ---\n"
    log_output += "🤖 Architecture: LayoutLMv3 + BiLSTM + CRF\n"

    # 3. Construct the subprocess command. Slider values may arrive as floats
    # (e.g. 4.0), so cast the integer hyperparameters before formatting them.
    command = [
        sys.executable, TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", str(int(batch_size)),
        "--epochs", str(int(epochs)),
        "--lr", str(lr),
        "--max_len", str(int(max_len)),
    ]
    log_output += f"Executing command: {' '.join(command)}\n\n"
    yield log_output, None, gr.Button(visible=False)

    # A piped Python stdout is block-buffered; force unbuffered UTF-8 output
    # so log lines show up as soon as the script prints them.
    child_env = {**os.environ, "PYTHONUNBUFFERED": "1", "PYTHONIOENCODING": "utf-8"}

    try:
        # 4. Run the training script
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",
            bufsize=1,
            env=child_env,
        )

        # Stream logs (drained on a background thread, see _stream_output)
        lines = queue.Queue()
        threading.Thread(
            target=_stream_output, args=(process.stdout, lines), daemon=True
        ).start()
        while True:
            line = lines.get()
            if line is None:
                break
            log_output += line
            yield log_output, None, gr.Button(visible=False)

        return_code = process.wait()

        # 5. Check completion
        if return_code != 0:
            log_output += f"\n❌ TRAINING FAILED with return code {return_code}\n"
            yield log_output, None, gr.Button(visible=False)
            return

        log_output += "\n" + "=" * 60 + "\n"
        log_output += "✅ HYBRID TRAINING COMPLETE!\n"
        log_output += "=" * 60 + "\n"

        if not os.path.exists(MODEL_FILE_PATH):
            log_output += f"\n❌ Error: Training finished but {MODEL_FILE_PATH} was not found."
            yield log_output, None, gr.Button(visible=False)
            return

        file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)
        log_output += f"\n📦 Model file found: {MODEL_FILE_PATH} ({file_size:.2f} MB)"

        # Copy for download; fall back to serving the original file.
        try:
            download_path = _prepare_download_copy(DOWNLOAD_COPY_PREFIX)
        except Exception as e:
            log_output += f"\n⚠️ Copy failed: {e}, using original path"
            download_path = MODEL_FILE_PATH

        log_output += "\n\n⬇️ Click the '📥 Download Model' button below."
        yield log_output, download_path, gr.Button(visible=True)

    except FileNotFoundError as e:
        yield log_output + f"\n❌ ERROR: File not found: {e}", None, gr.Button(visible=False)
    except Exception as e:
        yield log_output + f"\n❌ Unexpected Error: {e}", None, gr.Button(visible=False)


# --- Gradio Interface Setup ---
with gr.Blocks(title="Hybrid LayoutLM Training", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧬 Hybrid LayoutLMv3 + BiLSTM + CRF Training")
    gr.Markdown(
        """
        **Architecture:** This app trains a state-of-the-art stack:
        1. **LayoutLMv3** (Visual & Textual Embeddings)
        2. **Bi-LSTM** (Sequence Context Modeling)
        3. **CRF** (Label Consistency Enforcement)

        **Instructions:** Upload your Label Studio JSON, set parameters, and train.

        **Note:** This model is slower to train than standard LayoutLM but typically achieves higher accuracy on complex layouts.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Dataset")
            file_input = gr.File(label="Upload Label Studio JSON", file_types=[".json"])

            gr.Markdown("### ⚙️ Hyperparameters")
            batch_size_input = gr.Slider(1, 16, value=4, step=1, label="Batch Size")
            epochs_input = gr.Slider(1, 10, value=5, step=1, label="Epochs")
            lr_input = gr.Number(value=2e-5, label="Learning Rate (Backbone)", info="LSTM/CRF head uses 1e-4")
            max_len_input = gr.Slider(128, 512, value=512, step=128, label="Max Seq Len")

            train_button = gr.Button("🔥 Start Hybrid Training", variant="primary", size="lg")
            check_button = gr.Button("🔍 Check Status / Recover Model", variant="secondary")
            clear_button = gr.Button("🧹 Clear Files", variant="stop")

        with gr.Column(scale=2):
            log_output = gr.Textbox(
                label="Training Logs",
                lines=25,
                autoscroll=True,
                show_copy_button=True,
                placeholder="Logs will appear here...",
            )
            download_btn = gr.Button("📥 Download Hybrid Model", variant="primary", size="lg", visible=False)

            # The State holds the prepared model path; the File component
            # serves it once the download button is clicked.
            model_path_state = gr.State(value=None)
            model_download = gr.File(label="Download", interactive=False, visible=True)

    # Actions
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_path_state, download_btn],
    )
    check_button.click(
        fn=retrieve_model,
        inputs=[],
        outputs=[log_output, model_path_state, download_btn],
    )
    clear_button.click(
        fn=clear_memory,
        inputs=[file_input],
        outputs=[log_output, model_path_state, download_btn, model_download],
    )
    download_btn.click(
        fn=lambda path: path,
        inputs=[model_path_state],
        outputs=[model_download],
    )


if __name__ == "__main__":
    demo.launch()