# app.py — added in commit 071f38f ("Create app.py") by hallelu.
import os
import re
import subprocess
import sys
import threading
import time
import zipfile
from datetime import datetime

import gradio as gr
# Shared module-level state for the background training run: the training
# worker thread writes it and the UI reads it through periodic polling.
training_status = dict(
    running=False,  # True while a training run is in progress
    output="",      # accumulated log text shown in the "Training Log" box
    progress=0,     # percentage shown on the progress slider
)
def install_dependencies(requirements_path="hf_requirements.txt"):
    """Install the packages listed in a pip requirements file.

    pip is invoked as ``sys.executable -m pip`` so packages are installed into
    the same interpreter that runs this app; a bare ``pip`` on PATH may belong
    to a different Python.

    Args:
        requirements_path: Path to the requirements file. Defaults to
            ``hf_requirements.txt`` in the current working directory.

    Returns:
        A user-facing status string: a success message, or an error message.
        When pip fails, the message includes the tail of pip's stderr.
    """
    if not os.path.exists(requirements_path):
        return f"❌ Error installing dependencies: {requirements_path} not found"
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-r", requirements_path],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # str(e) only reports the exit status; pip's actual reason is on stderr.
        detail = (e.stderr or "").strip()[-1000:] or str(e)
        return f"❌ Error installing dependencies: {detail}"
    except Exception as e:  # UI boundary: report any failure as a status string
        return f"❌ Error installing dependencies: {str(e)}"
    return "βœ… Dependencies installed successfully!"
def extract_data(zip_path="processed_data.zip", dest_dir="."):
    """Extract the training-data archive.

    Uses the standard-library ``zipfile`` module instead of shelling out to
    ``unzip``, which is not guaranteed to be installed in the container.
    Existing files are overwritten, matching the previous ``unzip -o`` call.

    Args:
        zip_path: Archive to extract. Defaults to ``processed_data.zip`` in
            the current working directory.
        dest_dir: Directory to extract into (created if needed). Defaults to
            the current working directory.

    Returns:
        A user-facing status string describing success or the failure reason.
    """
    if not os.path.exists(zip_path):
        return f"❌ {zip_path} not found! Please upload it first."
    try:
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(dest_dir)
    except Exception as e:  # UI boundary: report any failure as a status string
        return f"❌ Error extracting data: {str(e)}"
    return "βœ… Data extracted successfully!"
# Serializes the check-and-set of training_status["running"] so that two
# near-simultaneous starts cannot both launch a training run.
_training_lock = threading.Lock()

# Matches progress lines such as "Epoch 7/50" or "Epoch [7/50]" and captures
# (current_epoch, total_epochs).
_EPOCH_RE = re.compile(r"Epoch\s*\[?\s*(\d+)\s*/\s*(\d+)")


def run_training():
    """Install dependencies, extract data, then run ``hf_train.py`` to completion.

    Runs synchronously and is meant to be called on a worker thread. The
    script's combined stdout/stderr is streamed line by line into
    ``training_status["output"]``. ``training_status["progress"]`` is updated
    from "Epoch N/M" lines and is set to 100 only when the script exits with
    status 0.

    Returns:
        A warning string if a run is already in progress, otherwise ``None``.
    """
    global training_status

    with _training_lock:
        if training_status["running"]:
            return "⚠️ Training is already running!"
        training_status["running"] = True
        training_status["output"] = ""
        training_status["progress"] = 0

    try:
        # Install dependencies (failures are logged but do not abort the run)
        training_status["output"] += "πŸ“¦ Installing dependencies...\n"
        training_status["output"] += install_dependencies() + "\n"

        # Extract data (a previously extracted directory may still be usable)
        training_status["output"] += "πŸ“ Extracting data...\n"
        training_status["output"] += extract_data() + "\n"

        # Start training
        training_status["output"] += "πŸš€ Starting training...\n"
        training_status["output"] += f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"

        # sys.executable runs the script with the same interpreter as this app.
        # errors="replace" keeps non-UTF-8 output from aborting the log stream.
        with subprocess.Popen(
            [sys.executable, "hf_train.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",
            bufsize=1,
        ) as process:
            try:
                for line in process.stdout:
                    training_status["output"] += line
                    match = _EPOCH_RE.search(line)
                    if match:
                        current_epoch, total_epochs = (int(g) for g in match.groups())
                        if total_epochs > 0:
                            training_status["progress"] = min(current_epoch / total_epochs, 1.0) * 100
            except BaseException:
                # Do not leave the child running unread; it could block on a full pipe.
                process.kill()
                raise
        # Leaving the Popen context closes the pipe and waits, so returncode is set.

        if process.returncode == 0:
            training_status["progress"] = 100
            training_status["output"] += "\nπŸŽ‰ Training completed successfully!"
        else:
            training_status["output"] += "\n❌ Training failed!"
    except Exception as e:
        training_status["output"] += f"\n❌ Error during training: {str(e)}"
    finally:
        training_status["running"] = False
def start_training():
    """Launch ``run_training`` on a background thread and return a status message.

    If a run is already in progress, no thread is started. Instead, the same
    warning that ``run_training`` produces is returned, so the UI does not
    claim that a new run has begun. ``run_training`` re-checks the flag itself
    before doing any work.
    """
    if training_status["running"]:
        return "⚠️ Training is already running!"
    worker = threading.Thread(target=run_training, name="training-worker")
    worker.start()
    return "πŸš€ Training started! Check the output below for progress."
def get_training_output():
    """Return the training log text accumulated so far (may be empty)."""
    current_log = training_status["output"]
    return current_log
def get_progress():
    """Return the current training progress value, as a percentage."""
    percent_done = training_status["progress"]
    return percent_done
def check_files():
    """Build a checklist of the files the training pipeline depends on.

    Paths are resolved relative to the current working directory. The result
    has one line per item, in this order: training script, requirements file,
    zipped dataset (with its size in MB when present), and the extracted
    ``processed_data`` directory.
    """
    checklist = [
        f"βœ… {name}" if os.path.exists(name) else f"❌ {name} (missing)"
        for name in ("hf_train.py", "hf_requirements.txt")
    ]

    archive = "processed_data.zip"
    if os.path.exists(archive):
        megabytes = os.path.getsize(archive) / (1024 * 1024)
        checklist.append(f"βœ… {archive} ({megabytes:.1f} MB)")
    else:
        checklist.append(f"❌ {archive} (missing)")

    checklist.append(
        "βœ… processed_data directory"
        if os.path.exists("processed_data")
        else "⚠️ processed_data directory (will be created)"
    )
    return "\n".join(checklist)
def download_model():
    """Describe whether a trained checkpoint is available for download.

    Looks for ``best_model.pth`` in the current working directory.

    Returns:
        A user-facing message with the file name and size in MB, or a hint
        to run training first when the file does not exist.
    """
    checkpoint = "best_model.pth"
    if not os.path.exists(checkpoint):
        return "❌ No trained model found. Please run training first."
    size_mb = os.path.getsize(checkpoint) / (1024 * 1024)
    return (
        "βœ… Model ready for download!\n"
        f"πŸ“ File: {checkpoint}\n"
        f"πŸ“ Size: {size_mb:.1f} MB\n"
        "πŸ’‘ Download from the Files tab on the right."
    )
# Create Gradio interface
with gr.Blocks(title="🏠 Floorplan Segmentation Training", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏠 Floorplan Segmentation Model Training")
    gr.Markdown("Train a deep learning model to segment floorplan images into walls, doors, windows, rooms, and background.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## πŸ“‹ File Status")
            # The function itself (not its result) is passed, so Gradio calls it
            # when the app loads to compute the initial value.
            file_status = gr.Textbox(label="Required Files", value=check_files, lines=6, interactive=False)

            gr.Markdown("## πŸš€ Training Controls")
            start_btn = gr.Button("Start Training", variant="primary", size="lg")
            status_text = gr.Textbox(label="Status", value="Ready to train", interactive=False)

            gr.Markdown("## πŸ“Š Progress")
            progress_bar = gr.Slider(minimum=0, maximum=100, value=0, label="Training Progress (%)", interactive=False)

            gr.Markdown("## πŸ’Ύ Download Model")
            download_btn = gr.Button("Check Model Status")
            download_status = gr.Textbox(label="Model Status", value="No model trained yet", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("## πŸ“ Training Output")
            output_text = gr.Textbox(label="Training Log", value="Training output will appear here...", lines=20, interactive=False)

    # Event handlers
    start_btn.click(
        fn=start_training,
        outputs=status_text,
    )
    download_btn.click(
        fn=download_model,
        outputs=download_status,
    )

    def update_output():
        """Return the current (log text, progress) pair for the polling event."""
        return get_training_output(), get_progress()

    # Poll the shared training state every 2 seconds. The training worker
    # thread mutates training_status in the background; this load event
    # re-reads it.
    demo.load(update_output, outputs=[output_text, progress_bar], every=2)

# Launch the app
if __name__ == "__main__":
    demo.launch()