Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import subprocess | |
| import os | |
| import sys | |
| from datetime import datetime | |
| import shutil | |
# --- CONFIGURATION UPDATED FOR HYBRID MODEL ---
# Script that train_model() launches in a subprocess (resolved relative to
# the current working directory).
TRAINING_SCRIPT = "train_hybrid.py"
# Directory created before training starts and removed by clear_memory().
MODEL_OUTPUT_DIR = "checkpoints"
# File name this app looks for once training has finished.
# NOTE(review): presumably TRAINING_SCRIPT writes its final weights under
# exactly this name -- confirm against the training script.
MODEL_FILE_NAME = "layoutlmv3_bilstm_crf_hybrid.pth"
# Full path checked by retrieve_model() and train_model().
MODEL_FILE_PATH = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)
# ----------------------------------------------------------------
def retrieve_model():
    """
    Look for the finished model on disk and stage a copy of it for download.

    Meant for the situation where training completed on the server after the
    client connection had already timed out.

    Returns a 3-tuple ``(log text, download path or None, gr.Button)``.
    The button is only returned as visible when a copy was staged
    successfully; in every other case the path is ``None`` and the button
    is hidden.
    """
    import tempfile

    # Guard clause: no model on disk means there is nothing to recover yet.
    if not os.path.exists(MODEL_FILE_PATH):
        checked_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        report = (
            f"--- Model Status Check: {checked_at} ---\n"
            f"β Model file not found at {MODEL_FILE_PATH}.\n"
            f"Training may still be running or it failed. Check back later."
        )
        return report, None, gr.Button(visible=False)

    size_mb = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)  # bytes -> MB

    # Stage the copy under a timestamped name in the system temp directory,
    # a simple location that Gradio can reliably serve.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    staged_path = os.path.join(
        tempfile.gettempdir(), f"hybrid_model_recovered_{stamp}.pth"
    )

    try:
        shutil.copy2(MODEL_FILE_PATH, staged_path)
        checked_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        report = (
            f"--- Model Status Check: {checked_at} ---\n"
            f"π SUCCESS! The Hybrid LayoutLMv3+BiLSTM+CRF model was found.\n"
            f"π¦ Model file: {MODEL_FILE_PATH}\n"
            f"π Model size: {size_mb:.2f} MB\n"
            f"π Download path prepared: {staged_path}\n\n"
            f"β¬οΈ Click the 'π₯ Download Model' button below to save your model."
        )
        return report, staged_path, gr.Button(visible=True)
    except Exception as e:
        # The model exists but could not be staged; report it and keep the
        # download button hidden.
        report = (
            f"--- Model Status Check FAILED ---\n"
            f"β οΈ Trained model found, but could not prepare for download: {e}\n"
            f"π Original Path: {MODEL_FILE_PATH}. Try again or check Space logs."
        )
        return report, None, gr.Button(visible=False)
def clear_memory(dataset_file: gr.File):
    """
    Delete the files this app has written so a fresh training run can start.

    Three best-effort steps are performed; a failure in one is logged and
    does not prevent the others from running:

    1. Remove the checkpoint directory ``MODEL_OUTPUT_DIR``.
    2. Remove the uploaded dataset file, when one is supplied.
    3. Remove the timestamped ``hybrid_model_*.pth`` download copies that
       ``retrieve_model`` / ``train_model`` stage in the system temp
       directory. Without this step every training or recovery leaves a
       full copy of the model behind.

    Args:
        dataset_file: Value of the upload component. Either an object with a
            ``name`` attribute holding the path, or a value whose ``str()``
            is the path. ``None`` when nothing has been uploaded.

    Returns:
        A 4-tuple ``(log text, None, gr.Button(visible=False), None)``: the
        log, a cleared model-path state, a hidden download button and a
        cleared download component.
    """
    import glob
    import tempfile

    log_lines = [
        f"--- Memory Clear Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
    ]

    # 1. Clear Model Checkpoints Directory
    if os.path.exists(MODEL_OUTPUT_DIR):
        try:
            shutil.rmtree(MODEL_OUTPUT_DIR)
            log_lines.append(f"β Successfully deleted model directory: {MODEL_OUTPUT_DIR}\n")
        except OSError as e:
            log_lines.append(f"β ERROR deleting model directory {MODEL_OUTPUT_DIR}: {e}\n")
    else:
        log_lines.append(f"βΉοΈ Model directory not found: {MODEL_OUTPUT_DIR} (Nothing to delete)\n")

    # 2. Clear Uploaded Dataset File (Temporary file cleanup)
    if dataset_file is not None:
        input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)
        if os.path.exists(input_path):
            try:
                os.remove(input_path)
                log_lines.append(f"β Successfully deleted uploaded dataset file: {input_path}\n")
            except OSError as e:
                log_lines.append(f"β ERROR deleting dataset file {input_path}: {e}\n")
        else:
            log_lines.append(f"βΉοΈ Uploaded dataset file not found at {input_path}.\n")
    else:
        log_lines.append("βΉοΈ No dataset file currently tracked for deletion.\n")

    # 3. Clear staged download copies. The temp dir is escaped so that glob
    #    metacharacters in its path are matched literally.
    staged_pattern = os.path.join(glob.escape(tempfile.gettempdir()), "hybrid_model_*.pth")
    removed_copies = 0
    for staged_path in glob.glob(staged_pattern):
        try:
            os.remove(staged_path)
            removed_copies += 1
        except OSError as e:
            log_lines.append(f"ERROR deleting staged download copy {staged_path}: {e}\n")
    if removed_copies:
        log_lines.append(f"Deleted {removed_copies} staged download copy/copies from the temp directory.\n")

    log_lines.append(
        f"--- Memory Clear Complete: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
    )
    log_lines.append(
        "β¨ Files and checkpoints have been removed. You can now start a fresh training run."
    )
    return "".join(log_lines), None, gr.Button(visible=False), None
def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """
    Run the training script in a subprocess and stream its output.

    This is a generator: every ``yield`` is a 3-tuple
    ``(log text so far, download path or None, gr.Button)``. The download
    path is only set, and the button only visible, on the final yield of a
    successful run.

    Args:
        dataset_file: Value of the upload component. Either an object with a
            ``name`` attribute holding the path, or a value whose ``str()``
            is the path. ``None`` when nothing has been uploaded.
        batch_size: Forwarded to the script as ``--batch_size``.
        epochs: Forwarded to the script as ``--epochs``.
        lr: Forwarded to the script as ``--lr``.
        max_len: Forwarded to the script as ``--max_len``.
        progress: Gradio progress tracker; updated once before launch.
    """
    import tempfile

    # 1. Setup: Create output directory
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    # 2. File Handling
    if dataset_file is None:
        yield "β ERROR: Please upload a file.", None, gr.Button(visible=False)
        return
    input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)
    if not os.path.exists(input_path):
        yield f"β ERROR: Uploaded file not found at {input_path}.", None, gr.Button(visible=False)
        return

    # A missing script must be detected here: Popen launches sys.executable
    # (which exists), so it never raises FileNotFoundError for the script;
    # the interpreter would merely exit with a non-zero code.
    if not os.path.isfile(TRAINING_SCRIPT):
        yield f"β ERROR: '{TRAINING_SCRIPT}' not found.", None, gr.Button(visible=False)
        return

    progress(0.1, desc="Initializing Hybrid Model Training...")
    log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
    log_output += "π€ Architecture: LayoutLMv3 + BiLSTM + CRF\n"

    # 3. Construct the subprocess command
    command = [
        sys.executable,
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", str(batch_size),
        "--epochs", str(epochs),
        "--lr", str(lr),
        "--max_len", str(max_len),
    ]
    log_output += f"Executing command: {' '.join(command)}\n\n"
    yield log_output, None, gr.Button(visible=False)

    # A Python child block-buffers stdout when it is a pipe, which would
    # delay the log until the buffer fills. PYTHONUNBUFFERED makes each line
    # arrive as soon as it is printed.
    child_env = dict(os.environ, PYTHONUNBUFFERED="1")

    try:
        # 4. Run the training script
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr with stdout
            text=True,
            errors="replace",  # an undecodable byte must not abort streaming
            bufsize=1,  # line-buffered on the reading side
            env=child_env,
        )

        # Stream logs; also mirror each line to this process's own stdout.
        try:
            for line in process.stdout:
                log_output += line
                print(line, end='')
                yield log_output, None, gr.Button(visible=False)
        finally:
            # Release the pipe even when streaming is interrupted.
            process.stdout.close()
        return_code = process.wait()

        # 5. Check completion
        if return_code != 0:
            log_output += f"\nβ TRAINING FAILED with return code {return_code}\n"
            yield log_output, None, gr.Button(visible=False)
            return

        log_output += "\n" + "=" * 60 + "\n"
        log_output += "β HYBRID TRAINING COMPLETE!\n"
        log_output += "=" * 60 + "\n"

        if not os.path.exists(MODEL_FILE_PATH):
            log_output += f"\nβ Error: Training finished but {MODEL_FILE_PATH} was not found."
            yield log_output, None, gr.Button(visible=False)
            return

        file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)  # bytes -> MB
        log_output += f"\nπ¦ Model file found: {MODEL_FILE_PATH} ({file_size:.2f} MB)"

        # Copy for download under a timestamped name in the temp directory;
        # fall back to the original file if the copy fails.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        temp_model_path = os.path.join(tempfile.gettempdir(), f"hybrid_model_{timestamp}.pth")
        try:
            shutil.copy2(MODEL_FILE_PATH, temp_model_path)
            download_path = temp_model_path
        except Exception as e:
            log_output += f"\nβ οΈ Copy failed: {e}, using original path"
            download_path = MODEL_FILE_PATH

        log_output += "\n\nβ¬οΈ Click the 'π₯ Download Model' button below."
        yield log_output, download_path, gr.Button(visible=True)
    except OSError as e:
        # Launch failure (e.g. the interpreter itself cannot be executed) or
        # an I/O error while reading the child's output.
        yield log_output + f"\nβ ERROR: training process failed to run: {e}", None, gr.Button(visible=False)
    except Exception as e:
        yield log_output + f"\nβ Unexpected Error: {e}", None, gr.Button(visible=False)
# --- Gradio Interface Setup ---
# NOTE(review): the nesting of the layout below was reconstructed from a copy
# of this file whose indentation was lost -- confirm which components belong
# to which column against the deployed app.
with gr.Blocks(title="Hybrid LayoutLM Training", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 𧬠Hybrid LayoutLMv3 + BiLSTM + CRF Training")
    gr.Markdown(
        """
        **Architecture:** This app trains a state-of-the-art stack:
        1. **LayoutLMv3** (Visual & Textual Embeddings)
        2. **Bi-LSTM** (Sequence Context Modeling)
        3. **CRF** (Label Consistency Enforcement)
        **Instructions:** Upload your Label Studio JSON, set parameters, and train.
        **Note:** This model is slower to train than standard LayoutLM but typically achieves higher accuracy on complex layouts.
        """
    )
    with gr.Row():
        # Left column: dataset upload, hyperparameters and action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### π Dataset")
            file_input = gr.File(label="Upload Label Studio JSON", file_types=[".json"])
            gr.Markdown("### βοΈ Hyperparameters")
            batch_size_input = gr.Slider(1, 16, value=4, step=1, label="Batch Size")
            epochs_input = gr.Slider(1, 10, value=5, step=1, label="Epochs")
            lr_input = gr.Number(value=2e-5, label="Learning Rate (Backbone)", info="LSTM/CRF head uses 1e-4")
            max_len_input = gr.Slider(128, 512, value=512, step=128, label="Max Seq Len")
            train_button = gr.Button("π₯ Start Hybrid Training", variant="primary", size="lg")
            check_button = gr.Button("π Check Status / Recover Model", variant="secondary")
            clear_button = gr.Button("π§Ή Clear Files", variant="stop")
        # Right column: log view and the download controls.
        with gr.Column(scale=2):
            log_output = gr.Textbox(
                label="Training Logs", lines=25, autoscroll=True, show_copy_button=True,
                placeholder="Logs will appear here..."
            )
            # Created hidden; the callbacks return gr.Button(visible=True)
            # for it once a model path is available.
            download_btn = gr.Button("π₯ Download Hybrid Model", variant="primary", size="lg", visible=False)
            # model_path_state holds the path produced by the callbacks;
            # model_download is the File component that receives that path
            # when download_btn is clicked (it is created with visible=True).
            model_path_state = gr.State(value=None)
            model_download = gr.File(label="Download", interactive=False, visible=True)
    # Actions
    # Training and recovery both write (log, model path, button) to the same
    # three outputs.
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_path_state, download_btn]
    )
    check_button.click(
        fn=retrieve_model,
        inputs=[],
        outputs=[log_output, model_path_state, download_btn]
    )
    # clear_memory returns a fourth value that resets the download component.
    clear_button.click(
        fn=clear_memory,
        inputs=[file_input],
        outputs=[log_output, model_path_state, download_btn, model_download]
    )
    # Pass the stored path through unchanged into the File component.
    download_btn.click(
        fn=lambda path: path,
        inputs=[model_path_state],
        outputs=[model_download]
    )
# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()