File size: 7,058 Bytes
071f38f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import gradio as gr
import subprocess
import os
import time
import threading
from datetime import datetime

# Shared state for the background training run: run_training() writes it from
# its worker thread, and the UI polling callbacks read it.
training_status = dict(running=False, output="", progress=0)

def install_dependencies():
    """Install the packages listed in ``hf_requirements.txt``.

    pip is invoked as ``<current interpreter> -m pip`` so the packages are
    installed into the same environment that later runs the training script,
    rather than into whichever ``pip`` happens to be first on PATH.

    Returns:
        A human-readable status line. It is either a success message or an
        error message. When pip itself fails, the error includes pip's
        stderr. This function reports errors instead of raising them.
    """
    import sys

    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-r", "hf_requirements.txt"],
            capture_output=True,
            text=True,
            errors="replace",
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # str(e) only says "returned non-zero exit status N"; the actual
        # reason (missing file, resolver conflict, ...) is on stderr.
        detail = (e.stderr or "").strip() or str(e)
        return f"❌ Error installing dependencies: {detail}"
    except Exception as e:
        # Status-reporting boundary: surface anything else (e.g. OSError when
        # the interpreter cannot be launched) as a message, as before.
        return f"❌ Error installing dependencies: {e}"
    return "✅ Dependencies installed successfully!"

def extract_data():
    """Extract ``processed_data.zip`` into the current working directory.

    Uses the standard-library ``zipfile`` module instead of shelling out to
    the ``unzip`` binary, which is not guaranteed to be installed. Existing
    files are overwritten, as the previous ``unzip -o`` call did.

    Returns:
        A human-readable status line for one of three outcomes: success, a
        missing archive, or the extraction error. This function reports
        errors instead of raising them.
    """
    import zipfile

    archive = "processed_data.zip"
    if not os.path.exists(archive):
        return "❌ processed_data.zip not found! Please upload it first."
    try:
        with zipfile.ZipFile(archive) as zf:
            # extractall() strips absolute paths and ".." components from
            # member names, so entries cannot escape the working directory.
            zf.extractall()
    except Exception as e:
        # Status-reporting boundary (BadZipFile, OSError, encrypted members, ...).
        return f"❌ Error extracting data: {e}"
    return "✅ Data extracted successfully!"

def _epoch_progress(line):
    """Return percent complete parsed from an ``Epoch X/Y`` log line, or ``None``.

    The total ``Y`` is taken from the line itself instead of assuming 50, and
    the result is clamped to the range 0-100.
    """
    import re

    match = re.search(r"Epoch\s+(\d+)\s*/\s*(\d+)", line)
    if match is None:
        return None
    current, total = int(match.group(1)), int(match.group(2))
    if total <= 0:
        return None
    return min(max((current / total) * 100, 0.0), 100.0)


def run_training():
    """Install dependencies, extract data and run ``hf_train.py``, streaming its log.

    Meant to run on a background thread (``start_training`` launches it that
    way). All reporting goes through the module-level ``training_status`` dict:

    * ``running``  - True while this function is executing.
    * ``output``   - accumulated log text.
    * ``progress`` - percent complete, derived from ``Epoch X/Y`` log lines.
      It is set to 100 only when the training script exits successfully. On
      failure it keeps the last observed value.

    Returns:
        A warning string if a run is already in progress, otherwise ``None``.
    """
    import sys

    if training_status["running"]:
        return "⚠️ Training is already running!"

    training_status["running"] = True
    training_status["output"] = ""
    training_status["progress"] = 0
    succeeded = False

    try:
        training_status["output"] += "📦 Installing dependencies...\n"
        training_status["output"] += install_dependencies() + "\n"

        training_status["output"] += "📁 Extracting data...\n"
        training_status["output"] += extract_data() + "\n"

        training_status["output"] += "🚀 Starting training...\n"
        training_status["output"] += f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"

        # The current interpreter is the one pip just installed into.
        # "-u" disables the child's stdout buffering; a Python process writing
        # to a pipe otherwise block-buffers, and the live log would only
        # arrive in large chunks. The child is told to emit UTF-8 so it
        # matches how we decode the pipe.
        with subprocess.Popen(
            [sys.executable, "-u", "hf_train.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",
            bufsize=1,
            env={**os.environ, "PYTHONIOENCODING": "utf-8"},
        ) as process:
            for line in process.stdout:
                training_status["output"] += line
                progress = _epoch_progress(line)
                if progress is not None:
                    training_status["progress"] = progress
            returncode = process.wait()

        if returncode == 0:
            succeeded = True
            training_status["output"] += "\n🎉 Training completed successfully!"
        else:
            training_status["output"] += f"\n❌ Training failed! (exit code {returncode})"

    except Exception as e:
        # Top-level boundary of the worker thread: report instead of dying silently.
        training_status["output"] += f"\n❌ Error during training: {e}"

    finally:
        training_status["running"] = False
        if succeeded:
            training_status["progress"] = 100

def start_training():
    """Launch ``run_training`` on a background thread and report to the UI.

    The thread's return value is discarded, so the "already running" check is
    done here as well. Without it, a second click would report a fresh start
    while the new thread immediately returned its warning to nobody.
    NOTE(review): two near-simultaneous clicks can still both pass this check
    before the worker sets ``running``; a lock would be needed to close that.

    Returns:
        A status message for the UI.
    """
    if training_status["running"]:
        return "⚠️ Training is already running!"
    threading.Thread(target=run_training, name="training").start()
    return "🚀 Training started! Check the output below for progress."

def get_training_output():
    """Return the log text the training run has accumulated so far."""
    log_text = training_status["output"]
    return log_text

def get_progress():
    """Return the current training progress value, in percent."""
    percent_done = training_status["progress"]
    return percent_done

def check_files():
    """Report whether each file the training workflow needs is present.

    Looks in the current working directory for four items:

    * the training script;
    * the requirements file;
    * the data archive, with its size;
    * the extracted data directory.

    Returns:
        One status line per item, joined with newlines.
    """
    lines = []

    for name in ("hf_train.py", "hf_requirements.txt"):
        if os.path.exists(name):
            lines.append(f"✅ {name}")
        else:
            lines.append(f"❌ {name} (missing)")

    archive = "processed_data.zip"
    if os.path.exists(archive):
        size_mb = os.path.getsize(archive) / (1024 * 1024)
        lines.append(f"✅ {archive} ({size_mb:.1f} MB)")
    else:
        lines.append(f"❌ {archive} (missing)")

    # Must be an actual directory; a stray file with this name does not count.
    if os.path.isdir("processed_data"):
        lines.append("✅ processed_data directory")
    else:
        lines.append("⚠️ processed_data directory (will be created)")

    return "\n".join(lines)

def download_model():
    """Describe whether a trained model checkpoint is available for download.

    Looks for ``best_model.pth`` in the current working directory. It does
    not serve the file itself. The message directs the user to the Files tab.

    Returns:
        One of two status messages. If the checkpoint exists, a multi-line
        message including its size in MB. Otherwise, a "not found" message.
    """
    model_path = "best_model.pth"
    if not os.path.exists(model_path):
        return "❌ No trained model found. Please run training first."
    size_mb = os.path.getsize(model_path) / (1024 * 1024)
    return (
        "✅ Model ready for download!\n"
        f"📁 File: {model_path}\n"
        f"📏 Size: {size_mb:.1f} MB\n"
        "💡 Download from the Files tab on the right."
    )

# Create Gradio interface
with gr.Blocks(title="🏠 Floorplan Segmentation Training", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏠 Floorplan Segmentation Model Training")
    gr.Markdown("Train a deep learning model to segment floorplan images into walls, doors, windows, rooms, and background.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 📋 File Status")
            # The function itself (not its result) is passed as the value;
            # per Gradio's docs a callable value is evaluated when the page loads.
            file_status = gr.Textbox(label="Required Files", value=check_files, lines=6, interactive=False)

            gr.Markdown("## 🚀 Training Controls")
            start_btn = gr.Button("Start Training", variant="primary", size="lg")
            status_text = gr.Textbox(label="Status", value="Ready to train", interactive=False)

            gr.Markdown("## 📊 Progress")
            progress_bar = gr.Slider(minimum=0, maximum=100, value=0, label="Training Progress (%)", interactive=False)

            gr.Markdown("## 💾 Download Model")
            download_btn = gr.Button("Check Model Status")
            download_status = gr.Textbox(label="Model Status", value="No model trained yet", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("## 📝 Training Output")
            output_text = gr.Textbox(label="Training Log", value="Training output will appear here...", lines=20, interactive=False)

    # Event handlers
    start_btn.click(
        fn=start_training,
        outputs=status_text
    )

    download_btn.click(
        fn=download_model,
        outputs=download_status
    )

    # Training runs on a background thread and only mutates training_status,
    # so the log and progress bar are refreshed by polling it every 2 seconds.
    # (The earlier extra `demo.load(lambda: None, ..., every=5)` refreshed
    # nothing and only generated idle server traffic, so it was removed.)
    def update_output():
        return get_training_output(), get_progress()

    demo.load(update_output, outputs=[output_text, progress_bar], every=2)

# Launch the app
if __name__ == "__main__":
    demo.launch()