File size: 10,582 Bytes
c7a125e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import gradio as gr
import subprocess
import os
import sys
from datetime import datetime
import shutil

# --- CONFIGURATION UPDATED FOR HYBRID MODEL ---
# Script launched (via the current interpreter) to train the hybrid model.
TRAINING_SCRIPT: str = "train_hybrid.py"
# Directory and file name under which a finished training run is expected
# to leave the model checkpoint.
MODEL_OUTPUT_DIR: str = "checkpoints"
MODEL_FILE_NAME: str = "layoutlmv3_bilstm_crf_hybrid.pth"
# Relative path of the final checkpoint that the app checks for after training.
MODEL_FILE_PATH: str = os.path.join(MODEL_OUTPUT_DIR, MODEL_FILE_NAME)

# ----------------------------------------------------------------

def retrieve_model():
    """
    Check for the final model file and prepare a copy of it for download.

    Useful when the training job finished server-side but the client
    connection timed out before the result was delivered.

    Returns:
        A 3-tuple ``(log_text, download_path, download_button)``.
        ``download_path`` is a timestamped copy of the model in the system
        temp directory, or ``None`` when no model could be prepared.
        ``download_button`` is a ``gr.Button`` that is visible only when a
        download path is returned.
    """
    import tempfile

    checked_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    # isfile rather than exists: a directory with this name is not a model.
    if not os.path.isfile(MODEL_FILE_PATH):
        log_output = (
            f"--- Model Status Check: {checked_at} ---\n"
            f"❌ Model file not found at {MODEL_FILE_PATH}.\n"
            f"Training may still be running or it failed. Check back later."
        )
        return log_output, None, gr.Button(visible=False)

    # Copy to a simple location that Gradio can reliably serve; the timestamp
    # gives every status check its own file name.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    temp_model_path = os.path.join(tempfile.gettempdir(), f"hybrid_model_recovered_{timestamp}.pth")

    try:
        # getsize is inside the try too: the file may vanish between the
        # isfile check and here (e.g. a concurrent "Clear Files").
        file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)  # Size in MB
        shutil.copy2(MODEL_FILE_PATH, temp_model_path)
    except OSError as e:
        log_output = (
            f"--- Model Status Check FAILED: {checked_at} ---\n"
            f"⚠️ Trained model found, but could not prepare for download: {e}\n"
            f"📁 Original Path: {MODEL_FILE_PATH}. Try again or check Space logs."
        )
        return log_output, None, gr.Button(visible=False)

    log_output = (
        f"--- Model Status Check: {checked_at} ---\n"
        f"🎉 SUCCESS! The Hybrid LayoutLMv3+BiLSTM+CRF model was found.\n"
        f"📦 Model file: {MODEL_FILE_PATH}\n"
        f"📊 Model size: {file_size:.2f} MB\n"
        f"🔗 Download path prepared: {temp_model_path}\n\n"
        f"⬇️ Click the '📥 Download Hybrid Model' button below to save your model."
    )
    return log_output, temp_model_path, gr.Button(visible=True)


def clear_memory(dataset_file: gr.File):
    """
    Delete training artifacts so a fresh run can start from a clean slate.

    Removes, best-effort and in this order:
      1. the model checkpoint directory ``MODEL_OUTPUT_DIR``;
      2. the uploaded dataset file, when one is given;
      3. the timestamped ``hybrid_model_*.pth`` copies that the training and
         recovery handlers write into the system temp directory to serve
         downloads. Each run or status check adds a new full-size copy, so
         they accumulate until removed here.

    Failures are reported in the returned log text instead of being raised.

    Args:
        dataset_file: Value of the upload component: an object with a
            ``name`` attribute or a plain path, or ``None`` if nothing is
            uploaded.

    Returns:
        A 4-tuple ``(log_text, None, gr.Button(visible=False), None)``. It
        contains the log text, a cleared model-path state, a hidden download
        button and a cleared download file slot.
    """
    import glob
    import tempfile

    log_output = f"--- Memory Clear Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"

    # 1. Clear Model Checkpoints Directory
    if os.path.exists(MODEL_OUTPUT_DIR):
        try:
            shutil.rmtree(MODEL_OUTPUT_DIR)
            log_output += f"✅ Successfully deleted model directory: {MODEL_OUTPUT_DIR}\n"
        except OSError as e:
            log_output += f"❌ ERROR deleting model directory {MODEL_OUTPUT_DIR}: {e}\n"
    else:
        log_output += f"ℹ️ Model directory not found: {MODEL_OUTPUT_DIR} (Nothing to delete)\n"

    # 2. Clear Uploaded Dataset File (Temporary file cleanup)
    if dataset_file is None:
        log_output += "ℹ️ No dataset file currently tracked for deletion.\n"
    else:
        input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)
        if os.path.exists(input_path):
            try:
                os.remove(input_path)
                log_output += f"✅ Successfully deleted uploaded dataset file: {input_path}\n"
            except OSError as e:
                log_output += f"❌ ERROR deleting dataset file {input_path}: {e}\n"
        else:
            log_output += f"ℹ️ Uploaded dataset file not found at {input_path}.\n"

    # 3. Clear the per-run download copies left in the temp directory
    #    ("hybrid_model_<ts>.pth" and "hybrid_model_recovered_<ts>.pth").
    temp_copies = glob.glob(os.path.join(tempfile.gettempdir(), "hybrid_model_*.pth"))
    if not temp_copies:
        log_output += "ℹ️ No temporary model download copies found.\n"
    for copy_path in temp_copies:
        try:
            os.remove(copy_path)
            log_output += f"✅ Deleted temporary download copy: {copy_path}\n"
        except OSError as e:
            log_output += f"❌ ERROR deleting temporary copy {copy_path}: {e}\n"

    log_output += f"--- Memory Clear Complete: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
    log_output += "✨ Files and checkpoints have been removed. You can now start a fresh training run."

    return log_output, None, gr.Button(visible=False), None


def train_model(dataset_file: gr.File, batch_size: int, epochs: int, lr: float, max_len: int, progress=gr.Progress()):
    """
    Handle a Gradio training submission by running the training script in a subprocess.

    This is a generator. Each ``yield`` produces a 3-tuple
    ``(log_text, download_path, download_button)``, which Gradio uses to
    update the log box, the model-path state and the download button.

    Args:
        dataset_file: Uploaded dataset: an object with a ``name`` attribute
            or a plain path, or ``None`` if nothing was uploaded.
        batch_size: Training batch size (``--batch_size``).
        epochs: Number of training epochs (``--epochs``).
        lr: Learning rate, forwarded unchanged as ``--lr``.
        max_len: Maximum sequence length (``--max_len``).
        progress: Gradio progress tracker.

    ``batch_size``, ``epochs`` and ``max_len`` pass through ``int()`` before
    reaching the command line. Gradio sliders may hand over whole numbers as
    floats (``4.0``), which an ``int``-typed argparse option in the training
    script would presumably reject.
    """
    # 1. Validate inputs before touching the filesystem.
    if dataset_file is None:
        yield "❌ ERROR: Please upload a file.", None, gr.Button(visible=False)
        return

    input_path = dataset_file.name if hasattr(dataset_file, 'name') else str(dataset_file)

    if not os.path.exists(input_path):
        yield f"❌ ERROR: Uploaded file not found at {input_path}.", None, gr.Button(visible=False)
        return

    # `python missing_script.py` launches fine and merely exits with status 2,
    # so Popen never raises FileNotFoundError for a missing script: check first.
    if not os.path.isfile(TRAINING_SCRIPT):
        yield f"❌ ERROR: '{TRAINING_SCRIPT}' not found.", None, gr.Button(visible=False)
        return

    # 2. Setup: make sure the checkpoint directory exists.
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

    progress(0.1, desc="Initializing Hybrid Model Training...")

    log_output = f"--- Training Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ---\n"
    log_output += "🤖 Architecture: LayoutLMv3 + BiLSTM + CRF\n"

    # 3. Construct the subprocess command (argument list, no shell involved).
    command = [
        sys.executable,
        TRAINING_SCRIPT,
        "--mode", "train",
        "--input", input_path,
        "--batch_size", str(int(batch_size)),
        "--epochs", str(int(epochs)),
        "--lr", str(lr),
        "--max_len", str(int(max_len)),
    ]

    log_output += f"Executing command: {' '.join(command)}\n\n"
    yield log_output, None, gr.Button(visible=False)

    # When its stdout is a pipe, the child Python block-buffers output and the
    # log would only update in large bursts; force unbuffered output instead.
    child_env = dict(os.environ, PYTHONUNBUFFERED="1")
    process = None

    try:
        # 4. Run the training script, merging stderr into stdout.
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # an undecodable byte must not abort log streaming
            bufsize=1,
            env=child_env,
        )

        # Stream logs line by line as the child produces them.
        for line in process.stdout:
            log_output += line
            print(line, end='')
            yield log_output, None, gr.Button(visible=False)

        process.stdout.close()
        return_code = process.wait()

        # 5. Check completion.
        if return_code != 0:
            log_output += f"\n❌ TRAINING FAILED with return code {return_code}\n"
            yield log_output, None, gr.Button(visible=False)
            return

        log_output += "\n" + "=" * 60 + "\n"
        log_output += "✅ HYBRID TRAINING COMPLETE!\n"
        log_output += "=" * 60 + "\n"

        if not os.path.exists(MODEL_FILE_PATH):
            log_output += f"\n❌ Error: Training finished but {MODEL_FILE_PATH} was not found."
            yield log_output, None, gr.Button(visible=False)
            return

        file_size = os.path.getsize(MODEL_FILE_PATH) / (1024 * 1024)
        log_output += f"\n📦 Model file found: {MODEL_FILE_PATH} ({file_size:.2f} MB)"

        # Copy to the temp dir so Gradio can serve it; fall back to the original path.
        import tempfile
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        temp_model_path = os.path.join(tempfile.gettempdir(), f"hybrid_model_{timestamp}.pth")

        try:
            shutil.copy2(MODEL_FILE_PATH, temp_model_path)
            download_path = temp_model_path
        except OSError as e:
            log_output += f"\n⚠️ Copy failed: {e}, using original path"
            download_path = MODEL_FILE_PATH

        log_output += "\n\n⬇️ Click the '📥 Download Hybrid Model' button below."
        yield log_output, download_path, gr.Button(visible=True)

    except Exception as e:
        # If streaming broke mid-run, nobody drains the pipe any more and the
        # child would eventually block on a full buffer. Stop it rather than
        # leave a hung training process behind.
        if process is not None:
            if process.poll() is None:
                process.kill()
            process.wait()
            process.stdout.close()
        yield log_output + f"\n❌ Unexpected Error: {e}", None, gr.Button(visible=False)


# --- Gradio Interface Setup ---
with gr.Blocks(title="Hybrid LayoutLM Training", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧬 Hybrid LayoutLMv3 + BiLSTM + CRF Training")
    gr.Markdown(
        """
        **Architecture:** This app trains a state-of-the-art stack:
        1. **LayoutLMv3** (Visual & Textual Embeddings)
        2. **Bi-LSTM** (Sequence Context Modeling)
        3. **CRF** (Label Consistency Enforcement)
        
        **Instructions:** Upload your Label Studio JSON, set parameters, and train. 
        **Note:** This model is slower to train than standard LayoutLM but typically achieves higher accuracy on complex layouts.
        """
    )

    with gr.Row():
        # Left column: dataset upload, hyperparameters and action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Dataset")
            file_input = gr.File(label="Upload Label Studio JSON", file_types=[".json"])

            gr.Markdown("### ⚙️ Hyperparameters")
            batch_size_input = gr.Slider(1, 16, value=4, step=1, label="Batch Size")
            epochs_input = gr.Slider(1, 10, value=5, step=1, label="Epochs")
            lr_input = gr.Number(value=2e-5, label="Learning Rate (Backbone)", info="LSTM/CRF head uses 1e-4")
            max_len_input = gr.Slider(128, 512, value=512, step=128, label="Max Seq Len")

            train_button = gr.Button("🔥 Start Hybrid Training", variant="primary", size="lg")
            check_button = gr.Button("🔍 Check Status / Recover Model", variant="secondary")
            clear_button = gr.Button("🧹 Clear Files", variant="stop")

        # Right column: streamed logs and the download controls.
        with gr.Column(scale=2):
            log_output = gr.Textbox(
                label="Training Logs", lines=25, autoscroll=True, show_copy_button=True,
                placeholder="Logs will appear here..."
            )

            # Starts hidden; the train/check handlers reveal it once a model is ready.
            download_btn = gr.Button("📥 Download Hybrid Model", variant="primary", size="lg", visible=False)

            # Path of the prepared model copy, written by train/check and read
            # by the download button.
            model_path_state = gr.State(value=None)
            # Always-visible file slot that receives the model file when the
            # download button is clicked.
            model_download = gr.File(label="Download", interactive=False, visible=True)

    # Actions
    train_button.click(
        fn=train_model,
        inputs=[file_input, batch_size_input, epochs_input, lr_input, max_len_input],
        outputs=[log_output, model_path_state, download_btn]
    )

    check_button.click(
        fn=retrieve_model,
        inputs=[],
        outputs=[log_output, model_path_state, download_btn]
    )

    clear_button.click(
        fn=clear_memory,
        inputs=[file_input],
        outputs=[log_output, model_path_state, download_btn, model_download]
    )

    # Move the stored path into the File component so the browser can fetch it.
    download_btn.click(
        fn=lambda path: path,
        inputs=[model_path_state],
        outputs=[model_download]
    )

if __name__ == "__main__":
    demo.launch()