File size: 3,778 Bytes
2f560eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import sys
import timeit

# Setup paths
# Make the project root importable so the `data.*` and `model.*` packages
# resolve when this script is run directly from its subdirectory.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

from data.dataset import VFVDataset
from model.generator import QuantumGenerator
from model.discriminator import Discriminator

# --- FINAL CONFIG ---
EPOCHS = 60         
BATCH_SIZE = 64
LR_G = 0.0002       # generator learning rate (RMSprop)
LR_D = 0.00005      # critic learning rate — deliberately lower than LR_G
CLIP_VALUE = 0.01   # WGAN weight-clipping bound for the critic's parameters
N_CRITIC = 1 

CSV_PATH = os.path.join(project_root, "data", "vfv_market_data.csv")

# --- SETUP ---
# drop_last=True keeps every batch at exactly BATCH_SIZE, so the critic and
# generator always see same-sized batches.
dataset = VFVDataset(CSV_PATH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

gen = QuantumGenerator()
disc = Discriminator()
# NOTE:
# Our `Discriminator` already outputs a single scalar per sample (Linear -> 1).
# In WGAN you remove a Sigmoid if present; we don't have one, so keep the last Linear.

# RMSprop (not Adam) per the original WGAN recipe.
opt_G = optim.RMSprop(gen.parameters(), lr=LR_G)
opt_D = optim.RMSprop(disc.parameters(), lr=LR_D)

# This acts like a "parachute" for the sine wave, forcing it to land smoothly.
# Both learning rates are halved every 15 epochs.
scheduler_G = optim.lr_scheduler.StepLR(opt_G, step_size=15, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(opt_D, step_size=15, gamma=0.5)

# Per-step approximate Wasserstein distance, appended inside the training loop
# and plotted after training.
w_distances = [] 

print(f"Starting Long-Run Training ({EPOCHS} Epochs)...")
starttime = timeit.default_timer()

for epoch in range(EPOCHS):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for i, real_windows in enumerate(pbar):
        real_windows = real_windows.float()
        batch_sz = real_windows.size(0)

        # --- TRAIN CRITIC ---
        # WGAN critic objective: maximize E[D(real)] - E[D(fake)],
        # implemented as minimizing its negation. Fake samples are detached
        # so the generator receives no gradient from the critic step.
        opt_D.zero_grad()
        loss_real = -torch.mean(disc(real_windows))
        loss_fake = torch.mean(disc(gen(batch_sz).detach()))
        d_loss = loss_real + loss_fake
        d_loss.backward()
        opt_D.step()

        # Weight clipping enforces the Lipschitz constraint (original WGAN).
        for p in disc.parameters():
            p.data.clamp_(-CLIP_VALUE, CLIP_VALUE)

        # --- TRAIN GENERATOR ---
        # BUGFIX: N_CRITIC was configured but never used. Step the generator
        # only once every N_CRITIC critic updates, per the WGAN algorithm.
        # With the current N_CRITIC == 1 this is behaviorally identical.
        if i % N_CRITIC == 0:
            opt_G.zero_grad()
            g_loss = -torch.mean(disc(gen(batch_sz)))
            g_loss.backward()
            opt_G.step()

        # Metric: approximate Wasserstein distance
        # = E[D(real)] - E[D(fake)] = -d_loss.
        dist = -d_loss.item()
        w_distances.append(dist)
        pbar.set_postfix(Gap=f"{dist:.4f}")

    # Step the schedulers once per epoch (halve both LRs every 15 epochs).
    scheduler_G.step()
    scheduler_D.step()

# --- PLOT THE CONVERGENCE ---
print(f"Training complete in {timeit.default_timer() - starttime}s. Generating plot...")
plt.figure(figsize=(12, 6))

# 1. Determine cutoff: skip the first 10% of steps to hide the startup spike.
cutoff = int(len(w_distances) * 0.1)
zoomed_data = w_distances[cutoff:]
zoomed_steps = range(cutoff, len(w_distances))

# 2. Plot the raw gap over the zoomed region.
plt.plot(zoomed_steps, zoomed_data, color='purple', alpha=0.4, label="Raw Gap")

# 3. Plot a trend line: simple moving average over the ZOOMED data only.
window = 50
if len(zoomed_data) > window:
    moving_avg = [sum(zoomed_data[i:i+window])/window for i in range(len(zoomed_data)-window)]
    # Shift the x-axis by cutoff+window so the trend aligns with the raw curve.
    plt.plot(range(cutoff+window, len(w_distances)), moving_avg, color='black', linewidth=2, label="Trend")

# 4. Force y-axis focus on the zoomed data, ignoring the initial spike.
# BUGFIX: guard against an empty run (e.g. a dataloader that yields no
# batches) — min()/max() on an empty list raise ValueError. When there is
# no data, fall back to matplotlib's default limits.
if zoomed_data:
    y_min = min(zoomed_data)
    y_max = max(zoomed_data)
    margin = (y_max - y_min) * 0.2
    plt.ylim(y_min - margin, y_max + margin)

plt.title("WGAN Convergence (Zoomed In - Last 90%)")  # was an f-string with no placeholders
plt.xlabel("Training Steps")
plt.ylabel("Wasserstein Distance")
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig("wgan_final_convergence.png")
print("\nTraining Complete. Check wgan_final_convergence.png")
torch.save(gen.state_dict(), "vfv_wgan_final.pt")