Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| import os | |
| import sys | |
| import timeit | |
| # Setup paths | |
| project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| sys.path.append(project_root) | |
| from data.dataset import VFVDataset | |
| from model.generator import QuantumGenerator | |
| from model.discriminator import Discriminator | |
# --- FINAL CONFIG ---
# Length of the run and mini-batch size.
EPOCHS = 60
BATCH_SIZE = 64

# RMSprop learning rates; the critic's rate is a quarter of the generator's.
LR_G = 2e-4
LR_D = 5e-5

# WGAN settings: symmetric bound used when clipping critic weights, and the
# critic-to-generator update ratio (the conventional meaning of n_critic).
CLIP_VALUE = 1e-2
N_CRITIC = 1

# Input CSV, located under <project_root>/data/.
CSV_PATH = os.path.join(
    project_root,
    "data",
    "vfv_market_data.csv",
)
# --- SETUP ---
# Data pipeline: shuffled mini-batches of BATCH_SIZE items; the trailing
# partial batch is dropped so every step sees a full batch.
dataset = VFVDataset(CSV_PATH)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# Networks. The generator is built before the critic; keep this order, since
# swapping it could change the random initial weights each one receives.
gen = QuantumGenerator()
disc = Discriminator()

# NOTE (carried over from the original author; `Discriminator` lives in
# model/discriminator.py and is not visible here -- verify there):
# the critic is said to end in a Linear layer producing one scalar per sample,
# with no Sigmoid, which is the form a WGAN critic needs.

# One RMSprop optimizer per network, each with its own learning rate.
opt_G = optim.RMSprop(
    gen.parameters(),
    lr=LR_G,
)
opt_D = optim.RMSprop(
    disc.parameters(),
    lr=LR_D,
)

# Step decay: each scheduler halves its optimizer's learning rate after every
# 15 calls to `.step()`, so late updates are smaller and the loss settles.
scheduler_G, scheduler_D = (
    optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)
    for optimizer in (opt_G, opt_D)
)
# --- TRAINING LOOP ---
# One entry per batch: the negated critic loss, used as a running estimate of
# the Wasserstein gap between real and generated windows.
w_distances = []

# Number of critic updates per generator update. Clamped to at least 1 so a
# zero or negative setting cannot cause a modulo-by-zero or starve the generator.
critic_steps = max(1, int(N_CRITIC))

print(f"Starting Long-Run Training ({EPOCHS} Epochs)...")
starttime = timeit.default_timer()

for epoch in range(EPOCHS):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for i, real_windows in enumerate(pbar):
        real_windows = real_windows.float()
        batch_sz = real_windows.size(0)

        # --- TRAIN CRITIC ---
        # The critic minimises  mean(D(fake)) - mean(D(real)).
        opt_D.zero_grad()
        loss_real = -torch.mean(disc(real_windows))
        # detach() keeps the critic's backward pass out of the generator.
        loss_fake = torch.mean(disc(gen(batch_sz).detach()))
        d_loss = loss_real + loss_fake
        d_loss.backward()
        opt_D.step()

        # Weight clipping: keep every critic parameter in
        # [-CLIP_VALUE, CLIP_VALUE]. Done under no_grad() rather than via
        # `.data`, so autograd is bypassed explicitly.
        with torch.no_grad():
            for p in disc.parameters():
                p.clamp_(-CLIP_VALUE, CLIP_VALUE)

        # --- TRAIN GENERATOR ---
        # Honour N_CRITIC: update the generator only on every
        # `critic_steps`-th batch. With N_CRITIC = 1 this is every batch,
        # which is exactly what the loop did before.
        if (i + 1) % critic_steps == 0:
            opt_G.zero_grad()
            g_loss = -torch.mean(disc(gen(batch_sz)))
            g_loss.backward()
            opt_G.step()

        # Metric: Wasserstein Distance
        dist = -d_loss.item()  # approximate distance
        w_distances.append(dist)
        pbar.set_postfix(Gap=f"{dist:.4f}")

    # Step the schedulers at the end of the epoch
    scheduler_G.step()
    scheduler_D.step()
# --- SAVE THE GENERATOR, THEN PLOT THE CONVERGENCE ---
print(f"Training complete in {timeit.default_timer() - starttime}s. Generating plot...")

# Persist the trained generator first, so a plotting failure (empty history,
# backend error, unwritable image path) cannot discard the training result.
torch.save(gen.state_dict(), "vfv_wgan_final.pt")

# 1. Determine Cutoff (skip the first 10% of steps to hide the startup spike)
cutoff = int(len(w_distances) * 0.1)
zoomed_data = w_distances[cutoff:]
zoomed_steps = range(cutoff, len(w_distances))

if not zoomed_data:
    # No step was recorded (e.g. the loader yielded no batches), so min()/max()
    # below would raise on an empty list. Skip the figure.
    print("\nTraining Complete. No training steps were recorded, so no plot was written.")
else:
    plt.figure(figsize=(12, 6))

    # 2. Plot Raw Gap (Zoomed)
    plt.plot(zoomed_steps, zoomed_data, color='purple', alpha=0.4, label="Raw Gap")

    # 3. Plot Trend Line (trailing moving average over the ZOOMED data only)
    window = 50
    if len(zoomed_data) >= window:
        # There are len - window + 1 full windows; include the last one.
        moving_avg = [
            sum(zoomed_data[start:start + window]) / window
            for start in range(len(zoomed_data) - window + 1)
        ]
        # Each average is drawn at the LAST step of its window, so the first
        # point sits at step cutoff + window - 1.
        plt.plot(
            range(cutoff + window - 1, len(w_distances)),
            moving_avg,
            color='black',
            linewidth=2,
            label="Trend",
        )

    # 4. Force Y-Axis Focus
    # Limits come from the ZOOMED data, ignoring the initial spike. If the
    # range is not strictly positive (e.g. all values equal), leave matplotlib's
    # automatic limits instead of requesting identical low and high bounds.
    y_min = min(zoomed_data)
    y_max = max(zoomed_data)
    if y_max > y_min:
        margin = (y_max - y_min) * 0.2
        plt.ylim(y_min - margin, y_max + margin)

    plt.title("WGAN Convergence (Zoomed In - Last 90%)")
    plt.xlabel("Training Steps")
    plt.ylabel("Wasserstein Distance")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig("wgan_final_convergence.png")
    print("\nTraining Complete. Check wgan_final_convergence.png")