# Hugging Face Spaces status header (page-scrape artifact): Spaces: Paused
"""Gradio control panel for launching and monitoring Phi-4 Unsloth training."""
import os
import sys
import json
import logging
import gradio as gr
from pathlib import Path
import subprocess
import time
from datetime import datetime

# Configure logging: INFO-level messages go to stdout so they appear in the
# Space's container logs.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Configuration paths: all JSON config files are expected to sit next to
# this script (CONFIG_DIR is the current working directory).
CONFIG_DIR = "."
TRANSFORMERS_CONFIG = os.path.join(CONFIG_DIR, "transformers_config.json")
HARDWARE_CONFIG = os.path.join(CONFIG_DIR, "hardware_config.json")
DATASET_CONFIG = os.path.join(CONFIG_DIR, "dataset_config.json")
def load_config(config_path):
    """Load configuration from JSON file.

    Args:
        config_path: Path to the JSON file to read.

    Returns:
        The parsed configuration (typically a dict), or None when the file
        is missing, unreadable, or not valid JSON.
    """
    try:
        if os.path.exists(config_path):
            # Explicit encoding so parsing does not depend on the platform default.
            with open(config_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        logger.error(f"Config file not found: {config_path}")
        return None
    except (OSError, json.JSONDecodeError) as e:
        # Narrowed from a bare Exception: only I/O and parse failures are
        # expected here; anything else should surface as a real bug.
        logger.error(f"Error loading config: {str(e)}")
        return None
def display_config():
    """Render the current training configuration as a markdown summary.

    Returns:
        A markdown string describing model, dataset, training parameters and
        hardware, or an error message when any config file fails to load.
    """
    transformers_config = load_config(TRANSFORMERS_CONFIG)
    hardware_config = load_config(HARDWARE_CONFIG)
    dataset_config = load_config(DATASET_CONFIG)
    if not all([transformers_config, hardware_config, dataset_config]):
        return "Error loading configuration files."
    # Pull each sub-section out once instead of chaining .get() repeatedly.
    model_section = transformers_config.get("model", {})
    training_section = transformers_config.get("training", {})
    specs_section = hardware_config.get("specs", {})
    memory_section = hardware_config.get("memory_optimization", {})
    model_name = model_section.get("name", "")
    dataset_name = dataset_config.get("dataset", {}).get("name", "")
    batch_size = training_section.get("per_device_train_batch_size", 0)
    gradient_accum = training_section.get("gradient_accumulation_steps", 0)
    lr = training_section.get("learning_rate", 0)
    epochs = training_section.get("num_train_epochs", 0)
    gpu_count = specs_section.get("gpu_count", 0)
    gpu_type = specs_section.get("gpu_type", "")
    effective_batch = batch_size * gradient_accum * gpu_count
    flash_attention = memory_section.get("use_flash_attention", False)
    grad_checkpointing = memory_section.get("use_gradient_checkpointing", False)
    return f"""
## Current Training Configuration
**Model**: {model_name}
**Dataset**: {dataset_name}
**Training Parameters**:
- Learning Rate: {lr}
- Epochs: {epochs}
- Batch Size/GPU: {batch_size}
- Gradient Accumulation: {gradient_accum}
- Effective Batch Size: {effective_batch}
**Hardware**:
- GPUs: {gpu_count}x {gpu_type}
- Flash Attention: {flash_attention}
- Gradient Checkpointing: {grad_checkpointing}
**Pre-quantized 4-bit Training**: Enabled
"""
def start_training():
    """Launch run_transformers_training.py as a background process.

    Refuses to start when the process recorded in training.pid is still
    alive; otherwise spawns the trainer (appending its output to
    training.log), records the new PID in training.pid and appends a start
    entry to training_history.log.

    Returns:
        A human-readable status string (success or error).
    """
    try:
        # Check if a previous run is still alive via its PID file.
        if os.path.exists("training.pid"):
            with open("training.pid", "r") as f:
                pid = f.read().strip()
            try:
                # Signal 0 only probes the process; it sends nothing.
                os.kill(int(pid), 0)
                return f"Training is already running with PID {pid}"
            except (OSError, ValueError):
                # Dead process, or a corrupt PID file that int() rejected:
                # either way the file is stale, so remove it and continue.
                os.remove("training.pid")
        # Start training in the background. An argv list with shell=False
        # avoids shell quoting/injection issues, and sys.executable runs
        # the trainer under the same interpreter as this UI.
        log_file = open("training.log", "a")
        try:
            process = subprocess.Popen(
                [sys.executable, "run_transformers_training.py"],
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
        finally:
            # The child holds its own copy of the descriptor; close ours so
            # the parent does not leak an open file handle per launch.
            log_file.close()
        # Save PID so later calls can detect the running process.
        with open("training.pid", "w") as f:
            f.write(str(process.pid))
        # Log start time for the audit trail.
        with open("training_history.log", "a") as f:
            f.write(f"{datetime.now().isoformat()}: Training started (PID: {process.pid})\n")
        return f"Training started with PID {process.pid}. Check status for updates."
    except Exception as e:
        return f"Error starting training: {str(e)}"
def check_training_status():
    """Report whether training is running, plus the tail of training.log.

    Returns:
        A markdown string with the process status and the last 20 log lines.
    """
    try:
        status = "No training process is currently running"
        if os.path.exists("training.pid"):
            with open("training.pid", "r") as pid_file:
                recorded_pid = pid_file.read().strip()
            try:
                # Signal 0 probes liveness without affecting the process.
                os.kill(int(recorded_pid), 0)
                status = f"Training is running with PID {recorded_pid}"
            except OSError:
                status = "Training process has stopped"
                os.remove("training.pid")
        # Tail the training log (last 20 lines) if one exists.
        log_content = "No training log available"
        if os.path.exists("training.log"):
            with open("training.log", "r") as log_file:
                tail = log_file.readlines()
            log_content = "".join(tail[-20:]) if tail else "Log file is empty"
        return f"{status}\n\n**Recent Log:**\n```\n{log_content}\n```"
    except Exception as e:
        return f"Error checking status: {str(e)}"
# Create the Gradio interface: three tabs (configuration view, training
# control, help), built at import time so the Space serves it on launch.
with gr.Blocks(title="Phi-4 Unsloth Training", theme=gr.themes.Soft(primary_hue="blue")) as app:
    gr.Markdown("# Phi-4 Unsloth 4-bit Training Interface")
    with gr.Tabs():
        with gr.TabItem("Configuration"):
            # Rendered once at build time; the button re-renders on demand.
            config_output = gr.Markdown(display_config())
            refresh_btn = gr.Button("Refresh Configuration")
            refresh_btn.click(fn=display_config, outputs=config_output)
        with gr.TabItem("Training Control"):
            gr.Markdown("## Training Management")
            with gr.Row():
                start_btn = gr.Button("Start Training", variant="primary")
                check_btn = gr.Button("Check Status")
            status_output = gr.Markdown("Click 'Check Status' to see training progress")
            # Both buttons write their result string into the same markdown area.
            start_btn.click(fn=start_training, outputs=status_output)
            check_btn.click(fn=check_training_status, outputs=status_output)
            # Auto-refresh status: client-side JS clicks "Check Status" every
            # 30 seconds. NOTE(review): raw <script> tags injected via gr.HTML
            # may be sandboxed or stripped by newer Gradio releases — confirm
            # this auto-refresh actually fires in the deployed Gradio version.
            gr.HTML('''
            <script>
                let intervalId;
                document.addEventListener('DOMContentLoaded', function() {
                    // Find the "Check Status" button
                    const buttons = Array.from(document.querySelectorAll('button'));
                    const checkBtn = buttons.find(btn => btn.textContent.includes('Check Status'));
                    // Set up interval to click the button every 30 seconds
                    if (checkBtn) {
                        intervalId = setInterval(() => {
                            checkBtn.click();
                        }, 30000);
                    }
                });
                // Clean up on tab/window close
                window.addEventListener('beforeunload', function() {
                    clearInterval(intervalId);
                });
            </script>
            ''')
        with gr.TabItem("Help"):
            gr.Markdown("""
            ## Phi-4 Unsloth Training Help
            This interface allows you to manage training of the Phi-4 model with Unsloth 4-bit optimizations.
            ### Installation
            Before starting training, ensure all dependencies are installed:
            ```bash
            pip install -r requirements.txt
            ```
            Critical packages:
            - unsloth (>=2024.3)
            - peft (>=0.9.0)
            - transformers (>=4.36.0)
            ### Quick Start
            1. Review the configuration in the Configuration tab
            2. Click "Start Training" to begin the process
            3. Use "Check Status" to monitor progress
            ### Notes
            - Training uses the pre-quantized model `unsloth/phi-4-unsloth-bnb-4bit`
            - The process maintains paper order and handles metadata appropriately
            - Training progress will be regularly saved to HuggingFace Hub
            ### Troubleshooting
            If training stops unexpectedly:
            - Check the logs for out-of-memory errors
            - Verify the VRAM usage on each GPU
            - Check for CUDA version compatibility
            - If you see "Unsloth not available" error, run: `pip install unsloth>=2024.3 peft>=0.9.0`
            """)

# Launch the app only when run as a script (Spaces imports this module too).
if __name__ == "__main__":
    app.launch()