# Scraped page header (Hugging Face Space status): Spaces: Sleeping
import os
import time

import torch

from transformer import GPT, GPTConfig, DataLoaderLite  # Import your model and data loader
# Initialize the model and data loader.
config = GPTConfig()
# The training loop sends batches to `device`; train on the GPU when one is
# available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT(config).to(device)
train_loader = DataLoaderLite(B=4, T=1024)
# Define the optimizer. It is created after the move, so it is built over the
# parameters in their final on-device location.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Function to load the most recent checkpoint
def load_latest_checkpoint(model, checkpoint_file='checkpoint.pt'):
    """Restore ``model``'s weights from ``checkpoint_file`` if that file exists.

    Args:
        model: Module whose state dict is overwritten in place with the
            checkpoint's ``'model_state_dict'`` entry.
        checkpoint_file: Path to the checkpoint. The default matches the file
            the training loop writes.

    Returns:
        The ``'epoch'`` value stored in the checkpoint (the epoch to resume
        from), or 0 when no checkpoint file exists.
    """
    if not os.path.exists(checkpoint_file):
        return 0  # No checkpoint found, start from epoch 0
    print(f'Loading checkpoint from {checkpoint_file}')
    # Map tensors to the CPU so that a checkpoint saved on a GPU can still be
    # loaded on a CPU-only machine. load_state_dict then copies the values into
    # the model's existing parameters, on whatever device they already live.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['epoch']
# Load the latest checkpoint if available.
start_epoch = load_latest_checkpoint(model)

# Training loop
num_epochs = 91
# Stop training as soon as a single batch's loss drops below this value.
LOSS_THRESHOLD = 0.099999
checkpoint_path = 'checkpoint.pt'  # Save to a single checkpoint file

try:
    from tqdm import tqdm
except ImportError:  # tqdm is optional: without it, train with no progress bar.
    class _NullProgressBar:
        """No-op stand-in for the small part of tqdm's API used below."""

        def __init__(self, total=None, desc=None):
            self.total = total
            self.desc = desc

        def __enter__(self):
            return self

        def __exit__(self, *exc_info):
            return False

        def update(self, n=1):
            pass

    tqdm = _NullProgressBar

# Send batches to whatever device the model's parameters are on.
device = next(model.parameters()).device
model.train()

# Number of full (B, T) batches per epoch. This does not change between epochs,
# so compute it once. Assumes train_loader.tokens is the whole token stream
# (TODO confirm).
total_steps = len(train_loader.tokens) // (train_loader.B * train_loader.T)
if start_epoch < num_epochs and total_steps == 0:
    # Without this check, no batch would run and last_loss would stay None.
    raise ValueError(
        'Training data is smaller than a single batch of '
        f'B*T = {train_loader.B * train_loader.T} tokens.'
    )

# Start time tracking
start_time = time.time()
for epoch in range(start_epoch, num_epochs):  # Start from the loaded epoch
    epoch_loss = 0.0  # Sum of batch losses this epoch
    num_steps = 0  # Batches processed this epoch
    last_loss = None  # Loss of the most recent batch
    stop_early = False
    with tqdm(total=total_steps, desc=f'Epoch {epoch + 1}/{num_epochs}') as pbar:
        for _ in range(total_steps):
            x, y = train_loader.next_batch()
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits, loss = model(x, y)
            loss.backward()
            optimizer.step()
            last_loss = loss.item()  # A single .item() call: one host sync per step
            epoch_loss += last_loss
            num_steps += 1
            pbar.update(1)
            if last_loss < LOSS_THRESHOLD:
                print(f'Loss below threshold: {last_loss:.6f}')  # Print loss before breaking
                stop_early = True
                break

    # Report the last batch's loss (as before) plus the epoch average.
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {last_loss:.6f}, '
          f'Avg loss: {epoch_loss / num_steps:.6f}')

    # Checkpoint after every epoch, including one that ended early, so the
    # final weights are never lost. Write to a temporary file first and then
    # swap it in atomically, so an interrupted save cannot corrupt the only
    # checkpoint.
    tmp_checkpoint_path = checkpoint_path + '.tmp'
    torch.save({
        'epoch': epoch + 1,  # Epoch to resume from
        'model_state_dict': model.state_dict(),
    }, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)
    print(f'Checkpoint saved to {checkpoint_path}')

    if stop_early:
        print(f'Early stopping at epoch {epoch + 1} due to loss condition met.')
        break
# Report how long training took, measured from start_time to now, as whole
# minutes and seconds.
end_time = time.time()
training_duration = end_time - start_time
# divmod(d, 60) is exactly (d // 60, d % 60); truncate both parts to ints.
minutes, seconds = map(int, divmod(training_duration, 60))
print(f'Total training time: {minutes} minutes and {seconds} seconds')
# After training your model, apply quantization and save it
def save_model_with_quantization(model, file_path):
    """Dynamically quantize ``model``'s Linear layers to int8 and save the result.

    The caller's ``model`` is put into eval mode. Its layers and device are
    otherwise left untouched, because quantization always runs on a copy.

    Args:
        model: Trained module to quantize.
        file_path: Destination for the quantized model's state dict.
    """
    import copy

    # Switch model to evaluation mode
    model.eval()
    # Dynamic int8 kernels are CPU-only. If the model lives elsewhere (e.g. on
    # a GPU), quantize a CPU copy in place instead of moving the caller's model.
    # On the CPU path, quantize_dynamic(inplace=False) makes its own copy.
    if any(p.device.type != 'cpu' for p in model.parameters()):
        source, inplace = copy.deepcopy(model).cpu(), True
    else:
        source, inplace = model, False
    quantized_model = torch.quantization.quantize_dynamic(
        source,
        {torch.nn.Linear},  # layers to quantize
        dtype=torch.qint8,  # quantization type
        inplace=inplace,
    )
    # NOTE(review): the zipfile format has been torch.save's default since
    # PyTorch 1.6. The size reduction comes from the int8 weights, not from
    # this flag.
    torch.save(quantized_model.state_dict(), file_path, _use_new_zipfile_serialization=True)
    print(f'Model saved to {file_path} with quantization and compression.')
# Now that training has finished, write an int8-quantized copy of the model
# to trained_model_quantized.pt.
save_model_with_quantization(model, file_path='trained_model_quantized.pt')