Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, Dataset | |
| from torchvision import transforms | |
| import numpy as np | |
| import gzip | |
| import os | |
| from pathlib import Path | |
| from datetime import datetime | |
| import urllib.request | |
| import shutil | |
| from tqdm import tqdm | |
| import asyncio | |
| from fastapi import WebSocket | |
| import json | |
| from scripts.model import Net | |
class TrainingConfig:
    """Hyper-parameters for a single training run, copied from a plain dict.

    Each required key of ``params_dict`` becomes an attribute of the same
    name: ``block1``, ``block2``, ``block3``, ``optimizer``, ``batch_size``
    and ``epochs``.
    """

    # Keys that must be present in ``params_dict``.
    _FIELDS = ('block1', 'block2', 'block3', 'optimizer', 'batch_size', 'epochs')

    def __init__(self, params_dict):
        """Populate the instance from ``params_dict``.

        Args:
            params_dict: mapping that contains every required key.

        Raises:
            KeyError: if a required key is missing.
        """
        for field_name in self._FIELDS:
            setattr(self, field_name, params_dict[field_name])
def generate_model_filename(config, model_type="single"):
    """Build a timestamped checkpoint filename that encodes the configuration.

    Args:
        config: object exposing ``block1``, ``block2``, ``block3``,
            ``optimizer`` (a string) and ``batch_size``.
        model_type: filename prefix, e.g. "single", "model_1" or "model_2".

    Returns:
        A string of the form
        ``{model_type}_arch_{b1}_{b2}_{b3}_opt_{optimizer}_batch_{n}_{stamp}.pth``
        where the optimizer name is lower-cased and ``stamp`` is the current
        local time formatted as ``YYYYmmdd_HHMMSS``.
    """
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    pieces = [
        str(model_type),
        "arch", str(config.block1), str(config.block2), str(config.block3),
        "opt", config.optimizer.lower(),
        "batch", str(config.batch_size),
        stamp,
    ]
    return "_".join(pieces) + ".pth"
def download_and_extract_mnist_data(
    data_dir="data/MNIST/raw",
    base_url="https://storage.googleapis.com/cvdf-datasets/mnist/",
):
    """Download the four MNIST ``.gz`` archives and gunzip them into ``data_dir``.

    A file whose extracted form already exists is skipped. Extraction writes
    to a temporary ``.part`` file that is renamed only after the whole
    archive has been decompressed, so an interrupted or failed extraction
    never leaves a truncated file that a later run would mistake for a
    complete one.

    Args:
        data_dir: directory that receives the archives and extracted files;
            created (with parents) if missing.
        base_url: URL prefix the archive names are appended to.

    Raises:
        Exception: if a download or an extraction fails; the underlying
            error is attached as ``__cause__``.
    """
    archive_names = (
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    )
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    for file_name in archive_names:
        gz_file_path = data_dir / file_name
        # Drop only the trailing ".gz" suffix.
        extracted_file_path = gz_file_path.with_suffix('')

        # If the extracted file exists, skip downloading
        if extracted_file_path.exists():
            print(f"{extracted_file_path} already exists, skipping download.")
            continue

        # Download the archive
        print(f"Downloading {file_name}...")
        try:
            urllib.request.urlretrieve(base_url + file_name, gz_file_path)
        except Exception as e:
            print(f"Failed to download {file_name}: {e}")
            raise Exception(f"Could not download {file_name}") from e
        print(f"Successfully downloaded {file_name}")

        # Extract into a scratch file first; publish it only on success.
        partial_path = extracted_file_path.with_name(extracted_file_path.name + ".part")
        print(f"Extracting {file_name}...")
        try:
            with gzip.open(gz_file_path, 'rb') as f_in, open(partial_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
            partial_path.replace(extracted_file_path)
        except Exception as e:
            # Remove the incomplete output so the next run retries cleanly.
            if partial_path.exists():
                partial_path.unlink()
            print(f"Failed to extract {file_name}: {e}")
            raise Exception(f"Could not extract {file_name}") from e
        print(f"Successfully extracted {file_name}")
def load_mnist_images(filename):
    """Read an uncompressed MNIST image file into a float32 array.

    The first 16 bytes of the file are skipped as header; every remaining
    byte is treated as one pixel value.

    Args:
        filename: path of the extracted (non-gzipped) image file.

    Returns:
        ``numpy.ndarray`` of dtype float32 and shape ``(N, 1, 28, 28)`` with
        the byte values divided by 255.
    """
    header_bytes = 16
    with open(filename, 'rb') as stream:
        raw = stream.read()
    pixels = np.frombuffer(raw, dtype=np.uint8, offset=header_bytes)
    images = pixels.reshape(-1, 1, 28, 28)
    return images.astype(np.float32) / 255.0
def load_mnist_labels(filename):
    """Read an uncompressed MNIST label file into a uint8 array.

    The first 8 bytes of the file are skipped as header; every remaining
    byte is returned as one label.

    Args:
        filename: path of the extracted (non-gzipped) label file.

    Returns:
        One-dimensional ``numpy.ndarray`` of dtype uint8.
    """
    header_bytes = 8
    with open(filename, 'rb') as stream:
        raw = stream.read()
    return np.frombuffer(raw, dtype=np.uint8, offset=header_bytes)
class CustomMNISTDataset(Dataset):
    """Dataset backed by extracted MNIST image and label files held in memory.

    Both files are read completely at construction time via
    ``load_mnist_images`` and ``load_mnist_labels``.
    """

    def __init__(self, images_path, labels_path, transform=None):
        """Load the images and labels and remember the optional transform.

        Args:
            images_path: path of the extracted image file.
            labels_path: path of the extracted label file.
            transform: optional callable applied to each image tensor.
        """
        self.images = load_mnist_images(images_path)
        self.labels = load_mnist_labels(labels_path)
        self.transform = transform

    def __len__(self):
        """Return the number of labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is a float tensor built from the stored array (passed
        through ``transform`` when one is set); the label is a Python int.
        """
        sample = torch.FloatTensor(self.images[idx])
        target = int(self.labels[idx])
        if self.transform:
            sample = self.transform(sample)
        return sample, target
def validate(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: module to evaluate; switched to eval mode by this call.
        test_loader: iterable yielding ``(data, target)`` tensor batches.
        criterion: loss callable invoked as ``criterion(output, target)``.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(val_loss, val_acc)``: the per-batch loss averaged over the number
        of batches, and the accuracy in percent over all samples. An empty
        loader yields ``(0.0, 0.0)`` instead of dividing by zero.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():  # no gradient computation during validation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Each batch contributes its loss unweighted by batch size.
            loss_sum += criterion(output, target).item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            num_batches += 1

    # Nothing was evaluated: report zeros rather than raise ZeroDivisionError.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    val_loss = loss_sum / num_batches
    val_acc = 100. * correct / total
    return val_loss, val_acc
async def _send_json_safely(websocket, payload):
    """Send ``payload`` over ``websocket``; log and swallow any failure."""
    try:
        await websocket.send_json(payload)
    except Exception as e:
        print(f"Error sending websocket update: {e}")


async def train(model, config, websocket=None, model_type="single"):
    """Train ``model`` on MNIST and save a checkpoint on each new best epoch.

    Args:
        model: module to train; moved to CUDA when available, else CPU.
        config: object exposing ``block1``, ``block2``, ``block3``,
            ``optimizer``, ``batch_size`` and ``epochs``.
        websocket: optional object with an async ``send_json`` method that
            receives ``training_update`` (per batch), ``validation_update``
            (per epoch), ``training_error`` and ``training_complete``
            messages. Progress and error messages are best-effort: a failed
            send is printed and training continues.
        model_type: prefix handed to ``generate_model_filename``.

    Returns:
        None.

    Raises:
        Exception: any error raised while training is re-raised after the
            ``training_error`` notification has been attempted.

    NOTE(review): the forward/backward passes run synchronously inside this
    coroutine, so control returns to the event loop only at the awaited
    websocket sends.
    """
    print("\nStarting training...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = model.to(device)

    # Shared helper: prints "Preparing dataset...", downloads/extracts the
    # files if needed and builds normalised datasets plus their loaders.
    train_dataset, test_dataset, train_loader, test_loader = initialize_datasets(config.batch_size)
    print(f"Dataset loaded. Training samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

    print("\nTraining Configuration:")
    print(f"Epochs: {config.epochs}")
    print(f"Optimizer: {config.optimizer}")
    print(f"Batch Size: {config.batch_size}")
    print(f"Network Architecture: {config.block1}-{config.block2}-{config.block3}")

    # Print model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nModel Parameters:")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    print("\nStarting training loop...")
    best_val_acc = 0
    criterion = nn.CrossEntropyLoss()

    # Anything other than "adam" (case-insensitive) falls back to SGD.
    if config.optimizer.lower() == 'adam':
        optimizer = optim.Adam(model.parameters())
    else:
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # Create models directory if it doesn't exist
    models_dir = Path("scripts/training/models")
    models_dir.mkdir(parents=True, exist_ok=True)

    steps_per_epoch = len(train_loader)

    try:
        for epoch in range(config.epochs):
            model.train()
            total_loss = 0
            correct = 0
            total = 0
            # Defined up front so the epoch summary cannot hit an unbound
            # name when the loader yields no batches.
            current_loss = 0.0
            current_acc = 0.0

            progress_bar = tqdm(
                train_loader,
                desc=f"Epoch {epoch+1}/{config.epochs}",
                unit='batch',
                leave=True
            )

            for batch_idx, (data, target) in enumerate(progress_bar):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                # Running totals for the epoch so far
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)
                total_loss += loss.item()

                current_loss = total_loss / (batch_idx + 1)
                current_acc = 100. * correct / total

                # Send training update through websocket
                if websocket:
                    await _send_json_safely(websocket, {
                        'type': 'training_update',
                        'data': {
                            'step': batch_idx + epoch * steps_per_epoch,
                            'train_loss': current_loss,
                            'train_acc': current_acc,
                            'epoch': epoch
                        }
                    })

            # Validation phase (eval mode + no_grad handled by validate)
            print("\nRunning validation...")
            val_loss, val_acc = validate(model, test_loader, criterion, device)

            # Print epoch results
            print(f"\nEpoch {epoch+1}/{config.epochs} Results:")
            print(f"Training Loss: {current_loss:.4f} | Training Accuracy: {current_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")

            # Send validation update through websocket
            if websocket:
                await _send_json_safely(websocket, {
                    'type': 'validation_update',
                    'data': {
                        'step': (epoch + 1) * steps_per_epoch,
                        'val_loss': val_loss,
                        'val_acc': val_acc
                    }
                })

            # Save best model with configuration in filename
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                print(f"\nNew best validation accuracy: {val_acc:.2f}%")
                model_filename = generate_model_filename(config, model_type)
                model_path = models_dir / model_filename
                print(f"Saving model as: {model_filename}")
                torch.save(model.state_dict(), model_path)

    except Exception as e:
        print(f"\nError during training: {e}")
        if websocket:
            # Best effort: a broken socket must not mask the training error.
            await _send_json_safely(websocket, {
                'type': 'training_error',
                'data': {
                    'message': str(e)
                }
            })
        # Bare raise keeps the original traceback intact.
        raise

    print("\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")

    if websocket:
        await websocket.send_json({
            'type': 'training_complete',
            'data': {
                'message': 'Training completed successfully!',
                'best_val_acc': best_val_acc
            }
        })

    return None
def initialize_datasets(batch_size):
    """Prepare MNIST and return its datasets together with their loaders.

    Makes sure the raw files are present (downloading and extracting them
    when needed), wraps the train and test files in ``CustomMNISTDataset``
    with a ``Normalize((0.1307,), (0.3081,))`` transform, and builds one
    ``DataLoader`` per split.

    Args:
        batch_size: batch size used by both loaders.

    Returns:
        ``(train_dataset, test_dataset, train_loader, test_loader)``; only
        the training loader shuffles.
    """
    print("Preparing dataset...")
    download_and_extract_mnist_data()

    normalize = transforms.Compose([
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = CustomMNISTDataset(
        "data/MNIST/raw/train-images-idx3-ubyte",
        "data/MNIST/raw/train-labels-idx1-ubyte",
        transform=normalize,
    )
    test_dataset = CustomMNISTDataset(
        "data/MNIST/raw/t10k-images-idx3-ubyte",
        "data/MNIST/raw/t10k-labels-idx1-ubyte",
        transform=normalize,
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_dataset, test_dataset, train_loader, test_loader
async def start_comparison_training(websocket: WebSocket, parameters: dict):
    """Train model A and then model B on MNIST, streaming progress over ``websocket``.

    Args:
        websocket: connection that receives one ``'training'`` message per
            batch, a final ``'complete'`` message, or an ``'error'`` message.
        parameters: must contain non-empty ``'model_params'`` and
            ``'dataset_params'`` entries. ``parameters['model_params']`` is
            indexed by ``'model_a'`` and ``'model_b'``; each entry provides
            ``block1``, ``block2``, ``block3``, ``optimizer``, ``batch_size``
            and ``epochs``.

    Returns:
        None. Errors are not propagated: they are printed and reported to
        the client as an ``'error'`` message (best effort).
    """
    print("\n=== Starting Comparison Training ===")
    print(f"Received parameters: {json.dumps(parameters, indent=2)}")

    try:
        # Create models directory if it doesn't exist
        models_dir = Path("scripts/training/models")
        models_dir.mkdir(parents=True, exist_ok=True)

        # Validate parameters
        if not parameters.get('model_params'):
            print("Error: Missing model parameters")
            raise ValueError("Missing model parameters")
        if not parameters.get('dataset_params'):
            print("Error: Missing dataset parameters")
            raise ValueError("Missing dataset parameters")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        criterion = nn.CrossEntropyLoss()

        # The raw files must exist before the datasets can read them; on a
        # fresh checkout nothing else has downloaded them yet.
        download_and_extract_mnist_data()

        # Both splits are read from disk once and shared by the two models.
        transform = transforms.Compose([transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = CustomMNISTDataset(
            "data/MNIST/raw/train-images-idx3-ubyte",
            "data/MNIST/raw/train-labels-idx1-ubyte",
            transform=transform
        )
        test_dataset = CustomMNISTDataset(
            "data/MNIST/raw/t10k-images-idx3-ubyte",
            "data/MNIST/raw/t10k-labels-idx1-ubyte",
            transform=transform
        )

        # Dictionary to store best accuracies
        best_accuracies = {}

        # Start training models
        for model_key, model_letter in [('model_a', 'A'), ('model_b', 'B')]:
            print(f"\n{'='*50}")
            print(f"Training Model {model_letter}")
            print(f"{'='*50}")

            model_params = parameters['model_params'][model_key]
            batch_size = model_params['batch_size']

            # Loaders use the model-specific batch size.
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

            # Take the step count from the loader itself: it also yields the
            # final partial batch, which floor division would not count.
            iterations_per_epoch = len(train_loader)
            total_iterations = iterations_per_epoch * model_params['epochs']

            # Print configuration details
            print("\nModel Configuration:")
            print(f"Architecture: {model_params['block1']}-{model_params['block2']}-{model_params['block3']}")
            print(f"Optimizer: {model_params['optimizer']}")
            print(f"Batch Size: {model_params['batch_size']}")
            print(f"Epochs: {model_params['epochs']}")
            print(f"Iterations per epoch: {iterations_per_epoch:,}")
            print(f"Total iterations: {total_iterations:,}")

            try:
                print("\nDataset Information:")
                print(f"Training samples: {len(train_dataset):,}")
                print(f"Test samples: {len(test_dataset):,}")
                print(f"Steps per epoch: {len(train_loader):,}")

                # Initialize model and move to device
                model = Net(kernels=[
                    model_params['block1'],
                    model_params['block2'],
                    model_params['block3']
                ]).to(device)

                # Print model parameters
                total_params = sum(p.numel() for p in model.parameters())
                trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
                print("\nModel Parameters:")
                print(f"Total parameters: {total_params:,}")
                print(f"Trainable parameters: {trainable_params:,}")

                # Anything other than "adam" (case-insensitive) uses SGD.
                if model_params['optimizer'].lower() == 'adam':
                    optimizer = optim.Adam(model.parameters())
                else:
                    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

                # Used only to name checkpoints via generate_model_filename.
                filename_config = TrainingConfig(model_params)

                current_iteration = 0
                best_acc = 0  # Track best accuracy for model saving

                for epoch in range(model_params['epochs']):
                    model.train()
                    total_loss = 0
                    correct = 0
                    total = 0
                    # Defined up front so the summary below is always bound.
                    current_loss = 0.0
                    current_acc = 0.0

                    # Create progress bar for each epoch
                    progress_bar = tqdm(
                        train_loader,
                        desc=f"Epoch {epoch+1}/{model_params['epochs']}",
                        unit='batch',
                        leave=True,
                        ncols=100
                    )

                    for batch_idx, (data, target) in enumerate(progress_bar):
                        data, target = data.to(device), target.to(device)
                        optimizer.zero_grad()
                        output = model(data)
                        loss = criterion(output, target)
                        loss.backward()
                        optimizer.step()

                        # Running totals for the epoch so far
                        pred = output.argmax(dim=1, keepdim=True)
                        correct += pred.eq(target.view_as(pred)).sum().item()
                        total += target.size(0)
                        total_loss += loss.item()

                        current_loss = total_loss / (batch_idx + 1)
                        current_acc = 100. * correct / total

                        # Update progress bar description
                        progress_bar.set_postfix({
                            'loss': f'{current_loss:.4f}',
                            'acc': f'{current_acc:.2f}%'
                        })

                        # Send comparison-specific training update
                        current_iteration += 1
                        await websocket.send_json({
                            'status': 'training',
                            'model': model_letter,
                            'metrics': {
                                'iteration': current_iteration,
                                'total_iterations': total_iterations,
                                'loss': current_loss,
                                'accuracy': current_acc
                            },
                            'epoch': epoch,
                            'batch_size': batch_size,
                            'iterations_per_epoch': iterations_per_epoch
                        })

                    # Print epoch summary
                    print(f"\nEpoch {epoch+1} Summary:")
                    print(f"Average Loss: {current_loss:.4f}")
                    print(f"Accuracy: {current_acc:.2f}%")

                    # Validation at the end of each epoch; only the accuracy
                    # is used to decide whether to checkpoint.
                    print("\nRunning validation...")
                    _val_loss, val_acc = validate(model, test_loader, criterion, device)

                    # Save model if it's the best so far
                    if val_acc > best_acc:
                        best_acc = val_acc
                        model_filename = generate_model_filename(filename_config, model_key)
                        model_path = models_dir / model_filename
                        print(f"\nSaving Model {model_letter} with accuracy {val_acc:.2f}% as: {model_filename}")
                        torch.save(model.state_dict(), model_path)

                print(f"\nModel {model_letter} training completed")
                print(f"Best validation accuracy: {best_acc:.2f}%")

                # Save best accuracy for this model
                best_accuracies[model_key] = best_acc

            except Exception as e:
                print(f"Error training Model {model_letter}: {str(e)}")
                raise

        print("\nBoth models trained successfully")
        await websocket.send_json({
            'status': 'complete',
            'message': 'Training completed for both models',
            'model_a_acc': best_accuracies.get('model_a'),
            'model_b_acc': best_accuracies.get('model_b')
        })

    except Exception as e:
        error_msg = f"Error in comparison training: {str(e)}"
        print(error_msg)
        # The socket may be the thing that failed; reporting is best effort
        # so this handler never raises out of the function.
        try:
            await websocket.send_json({
                'status': 'error',
                'message': error_msg
            })
        except Exception as send_error:
            print(f"Error sending websocket update: {send_error}")
    finally:
        print("=== Comparison Training Ended ===\n")