Spaces:
Sleeping
Sleeping
File size: 3,778 Bytes
2f560eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import sys
import timeit
# Setup paths
# Make the repository root (two directories above this file) importable so the
# project-local `data` and `model` packages resolve when the script is run directly.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Prepend rather than append: appended entries are searched after site-packages,
# so any installed top-level `data`/`model` package would shadow the local ones.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
from data.dataset import VFVDataset
from model.generator import QuantumGenerator
from model.discriminator import Discriminator
# --- FINAL CONFIG ---
# Training schedule.
EPOCHS = 60
BATCH_SIZE = 64

# Optimiser learning rates: generator (G) and critic (D).
LR_G = 2e-4
LR_D = 5e-5

# WGAN critic settings.
CLIP_VALUE = 1e-2  # symmetric bound for critic weight clipping
N_CRITIC = 1       # intended critic updates per generator update (only effective if the training loop reads it)

# Input market data, resolved relative to the repository root.
CSV_PATH = os.path.join(project_root, "data", "vfv_market_data.csv")
# --- SETUP ---
dataset = VFVDataset(CSV_PATH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# drop_last=True discards a short final batch, so a dataset with fewer than
# BATCH_SIZE samples yields zero batches: training would silently do nothing and
# the convergence plot would later fail on an empty history. Fail fast instead.
if len(dataloader) == 0:
    raise ValueError(
        f"Dataset {CSV_PATH!r} has {len(dataset)} samples, fewer than "
        f"BATCH_SIZE={BATCH_SIZE}; with drop_last=True no batch would be produced."
    )
gen = QuantumGenerator()
disc = Discriminator()
# NOTE:
# Our `Discriminator` already outputs a single scalar per sample (Linear -> 1).
# In WGAN you remove a Sigmoid if present; we don't have one, so keep the last Linear.
# NOTE(review): this describes model/discriminator.py, which is not visible here — confirm it still holds.
opt_G = optim.RMSprop(gen.parameters(), lr=LR_G)
opt_D = optim.RMSprop(disc.parameters(), lr=LR_D)
# StepLR multiplies each learning rate by gamma=0.5 every 15 scheduler steps.
# The training loop steps the schedulers once per epoch, so both LRs halve every 15 epochs.
scheduler_G = optim.lr_scheduler.StepLR(opt_G, step_size=15, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(opt_D, step_size=15, gamma=0.5)
# Per-iteration history of the negated critic loss ("Gap"), appended to by the training loop.
w_distances = []
print(f"Starting Long-Run Training ({EPOCHS} Epochs)...")
starttime = timeit.default_timer()
for epoch in range(EPOCHS):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for i, real_windows in enumerate(pbar):
        real_windows = real_windows.float()
        batch_sz = real_windows.size(0)

        # --- TRAIN CRITIC ---
        # Critic minimises mean(D(fake)) - mean(D(real)).
        opt_D.zero_grad()
        loss_real = -torch.mean(disc(real_windows))
        # detach(): the critic update must not backpropagate into the generator.
        # gen(batch_sz) is called with a sample count — assumes it returns a batch
        # the critic accepts; TODO confirm against model/generator.py.
        loss_fake = torch.mean(disc(gen(batch_sz).detach()))
        d_loss = loss_real + loss_fake
        d_loss.backward()
        opt_D.step()
        # Weight clipping bounds the critic's parameters (WGAN Lipschitz constraint).
        # Clamp in place under no_grad rather than via the discouraged `.data` attribute.
        with torch.no_grad():
            for p in disc.parameters():
                p.clamp_(-CLIP_VALUE, CLIP_VALUE)

        # --- TRAIN GENERATOR ---
        # Honour N_CRITIC: update the generator once every N_CRITIC critic steps.
        # With N_CRITIC == 1 this runs on every batch.
        if (i + 1) % N_CRITIC == 0:
            opt_G.zero_grad()
            g_loss = -torch.mean(disc(gen(batch_sz)))
            g_loss.backward()
            opt_G.step()

        # Metric: Wasserstein Distance
        # -d_loss = mean(D(real)) - mean(D(fake)), computed with the critic weights
        # from before this iteration's update and clipping.
        dist = -d_loss.item()  # approximate distance
        w_distances.append(dist)
        pbar.set_postfix(Gap=f"{dist:.4f}")

    # Step the schedulers at the end of the epoch
    scheduler_G.step()
    scheduler_D.step()
# --- PLOT THE CONVERGENCE ---
print(f"Training complete in {timeit.default_timer() - starttime}s. Generating plot...")

# Persist the generator before plotting so a plotting/backend error cannot
# throw away the finished training run.
torch.save(gen.state_dict(), "vfv_wgan_final.pt")


def _trailing_moving_average(values, window):
    """Return the mean of every full ``window``-length slice of ``values``.

    Element ``j`` is ``mean(values[j:j + window])``, so the result has
    ``len(values) - window + 1`` entries, or is empty when ``values`` is
    shorter than ``window`` (or ``window`` is not positive). Uses a running
    sum, so it costs O(len(values)) instead of O(len(values) * window).
    """
    if window <= 0 or len(values) < window:
        return []
    running = sum(values[:window])
    averages = [running / window]
    for j in range(window, len(values)):
        running += values[j] - values[j - window]
        averages.append(running / window)
    return averages


if not w_distances:
    # No optimisation steps ran, so there is nothing to plot (min/max would fail).
    print("No training steps were recorded; skipping convergence plot.")
else:
    plt.figure(figsize=(12, 6))
    # 1. Determine Cutoff (Skip first 10% of training to hide the startup spike)
    skip_fraction = 0.1
    cutoff = int(len(w_distances) * skip_fraction)
    zoomed_data = w_distances[cutoff:]
    zoomed_steps = range(cutoff, len(w_distances))
    # 2. Plot Raw Gap (Zoomed)
    plt.plot(zoomed_steps, zoomed_data, color='purple', alpha=0.4, label="Raw Gap")
    # 3. Plot Trend Line (trailing moving average over the ZOOMED data only)
    window = 50
    moving_avg = _trailing_moving_average(zoomed_data, window)
    if moving_avg:
        # Average j covers global steps cutoff+j .. cutoff+j+window-1; plot it at
        # the last step of its window so the trend lines up with the raw curve.
        plt.plot(range(cutoff + window - 1, len(w_distances)), moving_avg,
                 color='black', linewidth=2, label="Trend")
    # 4. Force Y-Axis Focus
    # We set the limits based on the ZOOMED data, ignoring the initial spike
    y_min = min(zoomed_data)
    y_max = max(zoomed_data)
    margin = (y_max - y_min) * 0.2
    if margin == 0:
        # Flat curve: identical limits would give matplotlib a singular range.
        margin = max(abs(y_max) * 0.1, 1e-3)
    plt.ylim(y_min - margin, y_max + margin)
    plt.title(f"WGAN Convergence (Zoomed In - Last {1 - skip_fraction:.0%})")
    plt.xlabel("Training Steps")
    plt.ylabel("Wasserstein Distance")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig("wgan_final_convergence.png")
    plt.close()
    print("\nTraining Complete. Check wgan_final_convergence.png")