Update app.py: auto-start training, file-based logging
Browse files
app.py
CHANGED
|
@@ -2,39 +2,41 @@
|
|
| 2 |
"""
|
| 3 |
Gradio App for EeshaAI/Zeeb Training Space
|
| 4 |
==========================================
|
| 5 |
-
|
| 6 |
-
|
| 7 |
"""
|
| 8 |
|
| 9 |
import os
|
| 10 |
-
import
|
| 11 |
-
import
|
| 12 |
import gradio as gr
|
| 13 |
|
|
|
|
| 14 |
|
| 15 |
-
def run_training():
|
| 16 |
-
"""Run the training pipeline, capturing all output."""
|
| 17 |
-
# Capture all prints and logs
|
| 18 |
-
old_stdout = sys.stdout
|
| 19 |
-
old_stderr = sys.stderr
|
| 20 |
-
sys.stdout = buffer = io.StringIO()
|
| 21 |
-
sys.stderr = buffer
|
| 22 |
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
try:
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
sys.stderr = old_stderr
|
| 36 |
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
with gr.Blocks(
|
|
@@ -45,36 +47,30 @@ with gr.Blocks(
|
|
| 45 |
gr.Markdown(
|
| 46 |
"""
|
| 47 |
# 🎬 Zeeb — Video-LLM Trainer
|
| 48 |
-
Fine-
|
| 49 |
-
Trained model
|
|
|
|
|
|
|
| 50 |
"""
|
| 51 |
)
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
| 54 |
|
| 55 |
logbox = gr.Textbox(
|
| 56 |
label="Training Log",
|
|
|
|
| 57 |
lines=30,
|
| 58 |
max_lines=200,
|
| 59 |
interactive=False,
|
| 60 |
show_copy_button=True,
|
| 61 |
)
|
| 62 |
|
| 63 |
-
|
| 64 |
-
"""
|
| 65 |
-
### What happens when you click "Start Training"?
|
| 66 |
-
1. π¦ Downloads **OLMo 2 1B Instruct** from HuggingFace
|
| 67 |
-
2. π€ Expands vocabulary with **1,024 visual tokens** (`<v_0>` ... `<v_1023>`)
|
| 68 |
-
3. π§ Applies **LoRA r=4** to q_proj & v_proj (minimal memory)
|
| 69 |
-
4. π₯ Trains for **3 epochs** on the tokenized video dataset
|
| 70 |
-
5. π Merges LoRA weights back into the base model
|
| 71 |
-
6. π Pushes the merged model to **EeshaAI/zeeb**
|
| 72 |
-
|
| 73 |
-
⚠️ Training on CPU takes time (~10-30 min depending on dataset size).
|
| 74 |
-
"""
|
| 75 |
-
)
|
| 76 |
|
| 77 |
-
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
if __name__ == "__main__":
|
|
|
|
| 2 |
"""
|
| 3 |
Gradio App for EeshaAI/Zeeb Training Space
|
| 4 |
==========================================
|
| 5 |
+
Auto-starts LoRA fine-tuning on Space boot.
|
| 6 |
+
The UI shows real-time training progress from the log file.
|
| 7 |
"""
|
| 8 |
|
| 9 |
import os
|
| 10 |
+
import time
|
| 11 |
+
import threading
|
| 12 |
import gradio as gr
|
| 13 |
|
| 14 |
+
LOG_FILE = "/tmp/training_log.txt"
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
def start_training_background():
    """Run the training pipeline; used as the target of the startup thread.

    Imports ``run_training_to_file`` from ``train_on_hf_spaces`` and calls it
    with ``LOG_FILE``.

    Raises:
        Exception: Whatever the import or the training call raises is
            re-raised unchanged, after its traceback has been appended to
            ``LOG_FILE`` so the failure is visible to anything reading
            that file.
    """
    try:
        # The import stays inside the try so that an import failure is
        # recorded in the log file as well.
        from train_on_hf_spaces import run_training_to_file

        run_training_to_file(LOG_FILE)
    except Exception:
        import traceback

        details = traceback.format_exc()
        try:
            with open(LOG_FILE, "a", encoding="utf-8") as log:
                log.write("\nTraining thread crashed:\n")
                log.write(details)
        except OSError:
            # Best effort only: the original exception is re-raised below,
            # so nothing is lost if the log file cannot be written.
            pass
        raise
|
| 21 |
|
| 22 |
+
|
| 23 |
+
def get_log():
    """Return the current contents of the training log.

    Returns:
        str: The full text of ``LOG_FILE``, or a placeholder message when
        the file does not exist yet.
    """
    try:
        # Decode explicitly as UTF-8 instead of the locale default, and
        # replace undecodable bytes rather than raising. Presumably the file
        # is still being appended to while it is read, so a multi-byte
        # character may be cut off at the end -- TODO confirm against the
        # writer in train_on_hf_spaces.
        with open(LOG_FILE, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    except FileNotFoundError:
        return "⏳ Training has not started yet. Please wait..."
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def refresh_log():
    """Return the latest training-log text for the UI callbacks."""
    latest = get_log()
    return latest
|
|
|
|
| 35 |
|
| 36 |
+
|
| 37 |
+
# Launch training as soon as this module is loaded (i.e. on Space boot).
# daemon=True: the thread must not keep the process alive at shutdown.
training_thread = threading.Thread(
    target=start_training_background,
    daemon=True,
)
training_thread.start()
|
| 40 |
|
| 41 |
|
| 42 |
with gr.Blocks(
|
|
|
|
| 47 |
gr.Markdown(
|
| 48 |
"""
|
| 49 |
# 🎬 Zeeb — Video-LLM Trainer
|
| 50 |
+
Fine-tuning **OLMo 2 1B Instruct** with **LoRA (r=4)** to generate video tokens.
|
| 51 |
+
Trained model will be pushed to [EeshaAI/zeeb](https://huggingface.co/EeshaAI/zeeb).
|
| 52 |
+
|
| 53 |
+
Training **starts automatically** when this Space boots.
|
| 54 |
"""
|
| 55 |
)
|
| 56 |
|
| 57 |
+
with gr.Row():
|
| 58 |
+
refresh_btn = gr.Button("π Refresh Log", variant="primary")
|
| 59 |
+
auto_refresh = gr.Checkbox(label="Auto-refresh (every 10s)", value=True)
|
| 60 |
|
| 61 |
logbox = gr.Textbox(
|
| 62 |
label="Training Log",
|
| 63 |
+
value=lambda: get_log(),
|
| 64 |
lines=30,
|
| 65 |
max_lines=200,
|
| 66 |
interactive=False,
|
| 67 |
show_copy_button=True,
|
| 68 |
)
|
| 69 |
|
| 70 |
+
refresh_btn.click(fn=refresh_log, outputs=logbox)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
# Auto-refresh every 10 seconds
|
| 73 |
+
demo.load(fn=refresh_log, outputs=logbox, every=10)
|
| 74 |
|
| 75 |
|
| 76 |
if __name__ == "__main__":
|