hallelu commited on
Commit
69f257e
Β·
verified Β·
1 Parent(s): 6d7f802

Upload 4 files

Browse files
Files changed (4) hide show
  1. hf_README.md +142 -0
  2. hf_requirements.txt +7 -0
  3. hf_train.py +392 -0
  4. processed_data.zip +3 -0
hf_README.md ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🏠 Floorplan Segmentation Model Training
2
+
3
+ This repository contains the training code for a floorplan segmentation model that can identify walls, doors, windows, rooms, and background in architectural floorplans.
4
+
5
+ ## 🎯 Model Architecture
6
+
7
+ - **Type**: Ultra Simple U-Net
8
+ - **Input**: RGB floorplan images (224x224)
9
+ - **Output**: 5-class segmentation (Background, Walls, Doors, Windows, Rooms)
10
+ - **Parameters**: ~258K
11
+
12
+ ## πŸ“Š Training Data
13
+
14
+ The model is trained on the Cubicasa5K dataset:
15
+ - **Training**: 4,200 images
16
+ - **Validation**: 400 images
17
+ - **Test**: 400 images
18
+
19
+ ## πŸš€ Quick Start
20
+
21
+ ### 1. Setup Environment
22
+
23
+ ```bash
24
+ pip install -r hf_requirements.txt
25
+ ```
26
+
27
+ ### 2. Prepare Data
28
+
29
+ 1. Upload `processed_data.zip` to this repository
30
+ 2. Extract the data: `unzip processed_data.zip`
31
+
32
+ ### 3. Start Training
33
+
34
+ ```bash
35
+ python hf_train.py
36
+ ```
37
+
38
+ ## βš™οΈ Training Configuration
39
+
40
+ - **Batch Size**: 4
41
+ - **Image Size**: 224x224
42
+ - **Epochs**: 50
43
+ - **Learning Rate**: 1e-4
44
+ - **Optimizer**: Adam
45
+ - **Loss**: CrossEntropyLoss
46
+ - **Scheduler**: CosineAnnealingLR
47
+
48
+ ## πŸ“ˆ Expected Results
49
+
50
+ After training, you should see:
51
+ - **Wall Coverage**: 40-60% (vs previous 20.6%)
52
+ - **Room Detection**: Multiple rooms detected
53
+ - **Door/Window Classification**: Proper distinction from walls
54
+ - **Overall Quality**: Much better than previous attempts
55
+
56
+ ## πŸ’Ύ Model Outputs
57
+
58
+ - `best_model.pth`: Best trained model
59
+ - `checkpoint_epoch_*.pth`: Checkpoints every 10 epochs
60
+ - `training_history.png`: Training progress visualization
61
+
62
+ ## πŸ”§ Usage
63
+
64
+ ### Load Trained Model
65
+
66
+ ```python
67
+ import torch
68
+ from hf_train import UltraSimpleModel
69
+
70
+ # Load model
71
+ model = UltraSimpleModel(n_channels=3, n_classes=5)
72
+ checkpoint = torch.load('best_model.pth', map_location='cpu')
73
+ model.load_state_dict(checkpoint['model_state_dict'])
74
+ model.eval()
75
+ ```
76
+
77
+ ### Predict on New Image
78
+
79
+ ```python
80
+ import cv2
81
+ import torch
82
+
83
+ # Load and preprocess image
84
+ image = cv2.imread('floorplan.png')
85
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
86
+ image = cv2.resize(image, (224, 224))
87
+ image_tensor = torch.from_numpy(image).float().permute(2, 0, 1) / 255.0
88
+ image_tensor = image_tensor.unsqueeze(0)
89
+
90
+ # Predict
91
+ with torch.no_grad():
92
+ output = model(image_tensor)
93
+ prediction = torch.argmax(output, dim=1).squeeze(0).numpy()
94
+ ```
95
+
96
+ ## πŸ“Š Class Mapping
97
+
98
+ - **0**: Background (Black)
99
+ - **1**: Walls (Red)
100
+ - **2**: Doors (Green)
101
+ - **3**: Windows (Blue)
102
+ - **4**: Rooms (Yellow)
103
+
104
+ ## 🎯 Performance Metrics
105
+
106
+ - **Loss**: CrossEntropyLoss
107
+ - **Validation**: Every epoch
108
+ - **Checkpointing**: Every 10 epochs
109
+ - **Best Model**: Saved when validation loss improves
110
+
111
+ ## πŸ” Troubleshooting
112
+
113
+ ### Common Issues
114
+
115
+ 1. **CUDA Out of Memory**: Reduce batch size to 2
116
+ 2. **Data Not Found**: Ensure `processed_data.zip` is uploaded
117
+ 3. **Slow Training**: Check GPU availability
118
+
119
+ ### Performance Tips
120
+
121
+ - Use GPU for faster training
122
+ - Monitor GPU memory usage
123
+ - Clear cache periodically during training
124
+
125
+ ## πŸ“ž Support
126
+
127
+ If you encounter issues:
128
+ 1. Check the training logs
129
+ 2. Verify data format
130
+ 3. Ensure all dependencies are installed
131
+
132
+ ## πŸ† Results
133
+
134
+ This model should significantly improve upon the previous poor performance:
135
+ - Better wall detection
136
+ - Proper room segmentation
137
+ - Accurate door/window classification
138
+ - Overall higher quality results
139
+
140
+ ---
141
+
142
+ **Happy Training! πŸš€**
hf_requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ opencv-python>=4.8.0
4
+ numpy>=1.24.0
5
+ matplotlib>=3.7.0
6
+ tqdm>=4.65.0
7
+ pillow>=10.0.0
hf_train.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 🏠 Floorplan Segmentation Training on Hugging Face
4
+ Complete training script with proper logging and error handling
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.optim as optim
10
+ from torch.utils.data import Dataset, DataLoader
11
+ import cv2
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+ import os
15
+ import matplotlib.pyplot as plt
16
+ import time
17
+ import gc
18
+ from datetime import datetime
19
+
20
+ print("πŸš€ Starting Floorplan Segmentation Training on Hugging Face...")
21
+ print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
22
+
23
+ # ============================================================================
24
+ # 1. MODEL ARCHITECTURE
25
+ # ============================================================================
26
+
27
+ class UltraSimpleModel(nn.Module):
28
+ def __init__(self, n_channels=3, n_classes=5):
29
+ super().__init__()
30
+
31
+ self.encoder = nn.Sequential(
32
+ nn.Conv2d(n_channels, 32, 3, padding=1),
33
+ nn.ReLU(),
34
+ nn.Conv2d(32, 32, 3, padding=1),
35
+ nn.ReLU(),
36
+ nn.MaxPool2d(2),
37
+
38
+ nn.Conv2d(32, 64, 3, padding=1),
39
+ nn.ReLU(),
40
+ nn.Conv2d(64, 64, 3, padding=1),
41
+ nn.ReLU(),
42
+ nn.MaxPool2d(2),
43
+
44
+ nn.Conv2d(64, 128, 3, padding=1),
45
+ nn.ReLU(),
46
+ nn.Conv2d(128, 128, 3, padding=1),
47
+ nn.ReLU(),
48
+ nn.MaxPool2d(2),
49
+ )
50
+
51
+ self.decoder = nn.Sequential(
52
+ nn.ConvTranspose2d(128, 64, 2, stride=2),
53
+ nn.ReLU(),
54
+ nn.Conv2d(64, 64, 3, padding=1),
55
+ nn.ReLU(),
56
+
57
+ nn.ConvTranspose2d(64, 32, 2, stride=2),
58
+ nn.ReLU(),
59
+ nn.Conv2d(32, 32, 3, padding=1),
60
+ nn.ReLU(),
61
+
62
+ nn.ConvTranspose2d(32, 16, 2, stride=2),
63
+ nn.ReLU(),
64
+ nn.Conv2d(16, n_classes, 1),
65
+ )
66
+
67
+ def forward(self, x):
68
+ x = self.encoder(x)
69
+ x = self.decoder(x)
70
+ return x
71
+
72
+ # ============================================================================
73
+ # 2. DATASET CLASS
74
+ # ============================================================================
75
+
76
+ class SimpleDataset(Dataset):
77
+ def __init__(self, data_dir, image_size=224):
78
+ self.data_dir = data_dir
79
+ self.image_size = image_size
80
+
81
+ # Get image files
82
+ self.image_files = []
83
+ for file in os.listdir(data_dir):
84
+ if file.endswith('_image.png'):
85
+ mask_file = file.replace('_image.png', '_mask.png')
86
+ if os.path.exists(os.path.join(data_dir, mask_file)):
87
+ self.image_files.append(file)
88
+
89
+ print(f"πŸ“Š Found {len(self.image_files)} image-mask pairs in {data_dir}")
90
+
91
+ def __len__(self):
92
+ return len(self.image_files)
93
+
94
+ def __getitem__(self, idx):
95
+ # Load image
96
+ image_file = self.image_files[idx]
97
+ image_path = os.path.join(self.data_dir, image_file)
98
+ mask_path = os.path.join(self.data_dir, image_file.replace('_image.png', '_mask.png'))
99
+
100
+ # Load and preprocess
101
+ image = cv2.imread(image_path)
102
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
103
+ image = cv2.resize(image, (self.image_size, self.image_size))
104
+
105
+ mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
106
+ mask = cv2.resize(mask, (self.image_size, self.image_size))
107
+
108
+ # Convert to tensors
109
+ image = torch.from_numpy(image).float().permute(2, 0, 1) / 255.0
110
+ mask = torch.from_numpy(mask).long()
111
+
112
+ return image, mask
113
+
114
+ # ============================================================================
115
+ # 3. TRAINING SETUP
116
+ # ============================================================================
117
+
118
+ def setup_training():
119
+ """Setup training environment"""
120
+ print("πŸ”§ Setting up training environment...")
121
+
122
+ # Clear GPU memory
123
+ torch.cuda.empty_cache()
124
+ gc.collect()
125
+
126
+ # Check device
127
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
128
+ print(f"βœ… Using device: {device}")
129
+
130
+ if torch.cuda.is_available():
131
+ print(f"βœ… GPU: {torch.cuda.get_device_name(0)}")
132
+ print(f"βœ… GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
133
+
134
+ # Training parameters
135
+ BATCH_SIZE = 4
136
+ IMAGE_SIZE = 224
137
+ EPOCHS = 50
138
+ LEARNING_RATE = 1e-4
139
+
140
+ print(f"πŸ”„ Training Configuration:")
141
+ print(f" Batch size: {BATCH_SIZE}")
142
+ print(f" Image size: {IMAGE_SIZE}x{IMAGE_SIZE}")
143
+ print(f" Epochs: {EPOCHS}")
144
+ print(f" Learning rate: {LEARNING_RATE}")
145
+
146
+ return device, BATCH_SIZE, IMAGE_SIZE, EPOCHS, LEARNING_RATE
147
+
148
+ def create_data_loaders(BATCH_SIZE, IMAGE_SIZE):
149
+ """Create training and validation data loaders"""
150
+ print("πŸ“Š Creating data loaders...")
151
+
152
+ # Check if data exists
153
+ if not os.path.exists('processed_data'):
154
+ print("❌ processed_data directory not found!")
155
+ print("πŸ’‘ Please upload processed_data.zip to this repository")
156
+ return None, None
157
+
158
+ # Create datasets
159
+ train_dataset = SimpleDataset('processed_data/train', image_size=IMAGE_SIZE)
160
+ val_dataset = SimpleDataset('processed_data/val', image_size=IMAGE_SIZE)
161
+
162
+ # Create loaders
163
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
164
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
165
+
166
+ print(f"βœ… Data loaders created!")
167
+ print(f" Training batches: {len(train_loader)}")
168
+ print(f" Validation batches: {len(val_loader)}")
169
+
170
+ return train_loader, val_loader
171
+
172
+ # ============================================================================
173
+ # 4. TRAINING LOOP
174
+ # ============================================================================
175
+
176
+ def train_model(model, train_loader, val_loader, device, EPOCHS, LEARNING_RATE):
177
+ """Main training loop"""
178
+ print(f"\n🎯 Starting training for {EPOCHS} epochs...")
179
+
180
+ # Setup training components
181
+ criterion = nn.CrossEntropyLoss()
182
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
183
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)
184
+
185
+ # Training history
186
+ history = {
187
+ 'train_loss': [],
188
+ 'val_loss': [],
189
+ 'learning_rate': []
190
+ }
191
+
192
+ best_val_loss = float('inf')
193
+ start_time = time.time()
194
+
195
+ for epoch in range(EPOCHS):
196
+ epoch_start_time = time.time()
197
+ print(f"\nπŸ“… Epoch {epoch+1}/{EPOCHS}")
198
+
199
+ # Training phase
200
+ model.train()
201
+ train_loss = 0.0
202
+
203
+ train_pbar = tqdm(train_loader, desc="Training")
204
+ for batch_idx, (images, masks) in enumerate(train_pbar):
205
+ images = images.to(device)
206
+ masks = masks.to(device)
207
+
208
+ # Forward pass
209
+ optimizer.zero_grad()
210
+ outputs = model(images)
211
+ loss = criterion(outputs, masks)
212
+
213
+ # Backward pass
214
+ loss.backward()
215
+ optimizer.step()
216
+
217
+ # Update metrics
218
+ train_loss += loss.item()
219
+
220
+ # Update progress bar
221
+ train_pbar.set_postfix({
222
+ 'Loss': f'{loss.item():.4f}',
223
+ 'GPU': f'{torch.cuda.memory_allocated()/1e9:.1f}GB'
224
+ })
225
+
226
+ # Clear cache periodically
227
+ if batch_idx % 100 == 0:
228
+ torch.cuda.empty_cache()
229
+
230
+ avg_train_loss = train_loss / len(train_loader)
231
+
232
+ # Validation phase
233
+ model.eval()
234
+ val_loss = 0.0
235
+
236
+ with torch.no_grad():
237
+ val_pbar = tqdm(val_loader, desc="Validation")
238
+ for batch_idx, (images, masks) in enumerate(val_pbar):
239
+ images = images.to(device)
240
+ masks = masks.to(device)
241
+
242
+ outputs = model(images)
243
+ loss = criterion(outputs, masks)
244
+ val_loss += loss.item()
245
+
246
+ val_pbar.set_postfix({
247
+ 'Loss': f'{loss.item():.4f}'
248
+ })
249
+
250
+ avg_val_loss = val_loss / len(val_loader)
251
+
252
+ # Update learning rate
253
+ scheduler.step()
254
+ current_lr = optimizer.param_groups[0]['lr']
255
+
256
+ # Update history
257
+ history['train_loss'].append(avg_train_loss)
258
+ history['val_loss'].append(avg_val_loss)
259
+ history['learning_rate'].append(current_lr)
260
+
261
+ # Calculate epoch time
262
+ epoch_time = time.time() - epoch_start_time
263
+
264
+ # Print results
265
+ print(f"πŸ“Š Train Loss: {avg_train_loss:.4f}")
266
+ print(f" Val Loss: {avg_val_loss:.4f}")
267
+ print(f"πŸ“Š Learning Rate: {current_lr:.6f}")
268
+ print(f" GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
269
+ print(f"⏱️ Epoch time: {epoch_time:.1f}s")
270
+
271
+ # Save best model
272
+ if avg_val_loss < best_val_loss:
273
+ best_val_loss = avg_val_loss
274
+ torch.save({
275
+ 'epoch': epoch,
276
+ 'model_state_dict': model.state_dict(),
277
+ 'optimizer_state_dict': optimizer.state_dict(),
278
+ 'scheduler_state_dict': scheduler.state_dict(),
279
+ 'best_val_loss': best_val_loss,
280
+ 'history': history,
281
+ 'config': {
282
+ 'model_type': 'ultra_simple',
283
+ 'n_channels': 3,
284
+ 'n_classes': 5,
285
+ 'image_size': 224,
286
+ 'batch_size': 4
287
+ }
288
+ }, 'best_model.pth')
289
+ print(f"βœ… New best model saved! Loss: {best_val_loss:.4f}")
290
+
291
+ # Save checkpoint every 10 epochs
292
+ if (epoch + 1) % 10 == 0:
293
+ torch.save({
294
+ 'epoch': epoch,
295
+ 'model_state_dict': model.state_dict(),
296
+ 'optimizer_state_dict': optimizer.state_dict(),
297
+ 'scheduler_state_dict': scheduler.state_dict(),
298
+ 'best_val_loss': best_val_loss,
299
+ 'history': history
300
+ }, f'checkpoint_epoch_{epoch+1}.pth')
301
+ print(f"πŸ’Ύ Checkpoint saved: checkpoint_epoch_{epoch+1}.pth")
302
+
303
+ # Clear cache after each epoch
304
+ torch.cuda.empty_cache()
305
+
306
+ # Progress update
307
+ if (epoch + 1) % 5 == 0:
308
+ elapsed_time = time.time() - start_time
309
+ avg_epoch_time = elapsed_time / (epoch + 1)
310
+ remaining_epochs = EPOCHS - (epoch + 1)
311
+ estimated_time = remaining_epochs * avg_epoch_time
312
+
313
+ print(f"\nπŸ“ˆ Progress Update:")
314
+ print(f" Epochs completed: {epoch+1}/{EPOCHS}")
315
+ print(f" Best validation loss: {best_val_loss:.4f}")
316
+ print(f" Average epoch time: {avg_epoch_time:.1f}s")
317
+ print(f" Estimated time remaining: {estimated_time/60:.1f} minutes")
318
+
319
+ # Training complete
320
+ total_time = time.time() - start_time
321
+ print(f"\nπŸŽ‰ Training completed!")
322
+ print(f"⏱️ Total time: {total_time/3600:.1f} hours")
323
+ print(f" Best validation loss: {best_val_loss:.4f}")
324
+
325
+ return history
326
+
327
+ # ============================================================================
328
+ # 5. VISUALIZATION
329
+ # ============================================================================
330
+
331
+ def plot_training_history(history):
332
+ """Plot training history"""
333
+ if len(history['train_loss']) > 0:
334
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
335
+
336
+ # Plot losses
337
+ ax1.plot(history['train_loss'], label='Train Loss')
338
+ ax1.plot(history['val_loss'], label='Val Loss')
339
+ ax1.set_title('Training and Validation Loss')
340
+ ax1.set_xlabel('Epoch')
341
+ ax1.set_ylabel('Loss')
342
+ ax1.legend()
343
+ ax1.grid(True)
344
+
345
+ # Plot learning rate
346
+ ax2.plot(history['learning_rate'], label='Learning Rate')
347
+ ax2.set_title('Learning Rate Schedule')
348
+ ax2.set_xlabel('Epoch')
349
+ ax2.set_ylabel('Learning Rate')
350
+ ax2.legend()
351
+ ax2.grid(True)
352
+
353
+ plt.tight_layout()
354
+ plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
355
+ print("πŸ“Š Training history plotted and saved as 'training_history.png'")
356
+
357
+ # ============================================================================
358
+ # 6. MAIN FUNCTION
359
+ # ============================================================================
360
+
361
+ def main():
362
+ """Main training function"""
363
+ try:
364
+ # Setup
365
+ device, BATCH_SIZE, IMAGE_SIZE, EPOCHS, LEARNING_RATE = setup_training()
366
+
367
+ # Create data loaders
368
+ train_loader, val_loader = create_data_loaders(BATCH_SIZE, IMAGE_SIZE)
369
+ if train_loader is None:
370
+ return
371
+
372
+ # Create model
373
+ model = UltraSimpleModel(n_channels=3, n_classes=5).to(device)
374
+ print(f"βœ… Model created! Parameters: {sum(p.numel() for p in model.parameters()):,}")
375
+
376
+ # Train model
377
+ history = train_model(model, train_loader, val_loader, device, EPOCHS, LEARNING_RATE)
378
+
379
+ # Plot results
380
+ plot_training_history(history)
381
+
382
+ print("\nβœ… Training completed successfully!")
383
+ print("πŸ’Ύ Best model saved as 'best_model.pth'")
384
+ print("πŸ“Š Training history saved as 'training_history.png'")
385
+
386
+ except Exception as e:
387
+ print(f"❌ Training failed with error: {e}")
388
+ import traceback
389
+ traceback.print_exc()
390
+
391
+ if __name__ == "__main__":
392
+ main()
processed_data.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59f98c394089de9be227fd222444a1f36242c275947f597ec7f9f925eba4c42a
3
+ size 994235873