KTAparna commited on
Commit
479905a
·
verified ·
1 Parent(s): f29e167

Upload gnn_it_sector_timeseries.py

Browse files
Files changed (1) hide show
  1. gnn_it_sector_timeseries.py +710 -0
gnn_it_sector_timeseries.py ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Tuple, List
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from torch import nn
8
+ from torch_geometric.data import Data, DataLoader
9
+ from torch_geometric.nn import GCNConv
10
+ from sklearn.preprocessing import StandardScaler
11
+ from sklearn.metrics import mean_squared_error, mean_absolute_error
12
+ import matplotlib.pyplot as plt
13
+ import datetime as dt
14
+ import time
15
+
16
+
17
+ # -----------------------------------------------------
18
+ # 1. Data loading and preprocessing
19
+ # -----------------------------------------------------
20
+
21
+ def load_it_sector_data_from_csvs(
22
+ infy_csv: str,
23
+ tcs_csv: str,
24
+ nifty_it_csv: str,
25
+ ) -> Tuple[np.ndarray, np.ndarray, List[pd.Timestamp], List[str]]:
26
+ """Load IT sector data from separate CSV files and build cleaned feature + target tensors.
27
+
28
+ Methodology alignment
29
+ ---------------------
30
+ - Data Collection: uses OHLCV-style fields from the NSE IT sector file.
31
+ - Preprocessing / Cleaning:
32
+ * Parse dates and sort.
33
+ * Filter to equity series (EQ).
34
+ * Remove duplicates and rows with missing / invalid key values.
35
+ * Filter out non-trading days (zero / negative volume).
36
+ * Forward-fill remaining gaps.
37
+ - Derived Indicators:
38
+ * Daily returns.
39
+ * 5-day moving average of close.
40
+ * 20-day rolling volatility of returns.
41
+
42
+ Returns
43
+ -------
44
+ features : np.ndarray
45
+ Shape [num_dates, num_companies, num_features].
46
+ Features per company per day include normalized price/volume and indicators.
47
+ targets : np.ndarray
48
+ Shape [num_dates, num_companies]. Daily returns per company (prediction target).
49
+ dates : list of pd.Timestamp
50
+ Trading dates.
51
+ companies : list of str
52
+ List of company tickers (node names in the graph).
53
+ """
54
+ # ---------------------------------
55
+ # Load individual CSVs
56
+ # ---------------------------------
57
+ infy_df = pd.read_csv(infy_csv)
58
+ tcs_df = pd.read_csv(tcs_csv)
59
+ index_df = pd.read_csv(nifty_it_csv)
60
+
61
+ # Add a Company identifier manually
62
+ infy_df["Company"] = "INFY"
63
+ tcs_df["Company"] = "TCS"
64
+ index_df["Company"] = "NIFTY_IT"
65
+
66
+ # Harmonize columns where needed for the index
67
+ # Ensure required OHLCV columns exist (use Close/Volume, ignore others if missing)
68
+ for df in [infy_df, tcs_df, index_df]:
69
+ df["Date"] = pd.to_datetime(df["Date"])
70
+
71
+ # For the index, mimic equity-style columns for compatibility
72
+ if "Series" not in index_df.columns:
73
+ index_df["Series"] = "EQ"
74
+ if "Close" not in index_df.columns and "Close" in index_df.columns:
75
+ # Already present; this branch is just a safety net
76
+ pass
77
+ if "Volume" not in index_df.columns and "Volume" in index_df.columns:
78
+ # Already present; just a safety net
79
+ pass
80
+
81
+ # Unify columns to a common subset
82
+ common_cols = [
83
+ "Date",
84
+ "Company",
85
+ "Series",
86
+ "Open",
87
+ "High",
88
+ "Low",
89
+ "Close",
90
+ "Volume",
91
+ ]
92
+
93
+ # For stock CSVs, ensure the above columns are present
94
+ for stock_df in [infy_df, tcs_df]:
95
+ # They already have Symbol, Series, Prev Close, Open, High, Low, Last, Close, VWAP, Volume, ...
96
+ # We just keep the columns we need and drop the rest later.
97
+ pass
98
+
99
+ # For index, keep only the needed OHLCV columns and Series/Company
100
+ index_df = index_df[["Date", "Open", "High", "Low", "Close", "Volume", "Company", "Series"]]
101
+
102
+ # Make sure column order matches common_cols
103
+ index_df = index_df[["Date", "Company", "Series", "Open", "High", "Low", "Close", "Volume"]]
104
+
105
+ # Align stock DataFrames to the same schema
106
+ infy_df = infy_df[["Date", "Company", "Series", "Open", "High", "Low", "Close", "Volume"]]
107
+ tcs_df = tcs_df[["Date", "Company", "Series", "Open", "High", "Low", "Close", "Volume"]]
108
+
109
+ # Concatenate all into one panel-like table
110
+ df = pd.concat([infy_df, tcs_df, index_df], ignore_index=True)
111
+
112
+ # -------------------------
113
+ # Basic cleaning steps
114
+ # -------------------------
115
+ # Ensure proper dtypes and ordering
116
+ df["Date"] = pd.to_datetime(df["Date"])
117
+
118
+ # Keep only equity series
119
+ if "Series" in df.columns:
120
+ df = df[df["Series"] == "EQ"]
121
+
122
+ # Drop rows with critical missing values
123
+ df = df.dropna(subset=["Company", "Close", "Volume", "Open", "High", "Low"])
124
+
125
+ # Remove zero / negative volume (non-trading or bad records)
126
+ df = df[df["Volume"] > 0]
127
+
128
+ # Drop exact duplicates on (Date, Company)
129
+ df = df.drop_duplicates(subset=["Date", "Company"])
130
+
131
+ # Sort by date then company
132
+ df = df.sort_values(["Date", "Company"])
133
+
134
+ # Use the "Company" column as canonical ticker (INFY, TCS, HCLTECH, TECHM, WIPRO, ...)
135
+ companies = sorted(df["Company"].unique().tolist())
136
+
137
+ # Pivot to Date x Company for OHLCV-like data
138
+ close = df.pivot_table(index="Date", columns="Company", values="Close")
139
+ volume = df.pivot_table(index="Date", columns="Company", values="Volume")
140
+
141
+ # Ensure consistent column order
142
+ close = close[companies]
143
+ volume = volume[companies]
144
+
145
+ # Forward-fill missing values along time for each company
146
+ close = close.ffill()
147
+ volume = volume.ffill()
148
+
149
+ # -------------------------
150
+ # Derived indicators
151
+ # -------------------------
152
+ # 1-day simple returns (percentage change)
153
+ returns = close.pct_change().replace([np.inf, -np.inf], np.nan).fillna(0.0)
154
+
155
+ # 5-day moving average of closing price (trend)
156
+ ma5 = close.rolling(window=5, min_periods=1).mean().ffill()
157
+
158
+ # 20-day rolling volatility of returns (risk)
159
+ vol20 = (
160
+ returns.rolling(window=20, min_periods=1)
161
+ .std()
162
+ .replace([np.inf, -np.inf], np.nan)
163
+ .fillna(0.0)
164
+ .ffill()
165
+ )
166
+
167
+ # -------------------------
168
+ # Normalization per company
169
+ # -------------------------
170
+ scaler_close = StandardScaler()
171
+ scaler_vol = StandardScaler()
172
+ scaler_ma5 = StandardScaler()
173
+ scaler_vol20 = StandardScaler()
174
+
175
+ close_scaled = pd.DataFrame(
176
+ scaler_close.fit_transform(close.values),
177
+ index=close.index,
178
+ columns=close.columns,
179
+ )
180
+ volume_scaled = pd.DataFrame(
181
+ scaler_vol.fit_transform(volume.values),
182
+ index=volume.index,
183
+ columns=volume.columns,
184
+ )
185
+ ma5_scaled = pd.DataFrame(
186
+ scaler_ma5.fit_transform(ma5.values),
187
+ index=ma5.index,
188
+ columns=ma5.columns,
189
+ )
190
+ vol20_scaled = pd.DataFrame(
191
+ scaler_vol20.fit_transform(vol20.values),
192
+ index=vol20.index,
193
+ columns=vol20.columns,
194
+ )
195
+
196
+ dates = close.index.to_list()
197
+ num_dates = len(dates)
198
+ num_companies = len(companies)
199
+
200
+ # Features per node per day:
201
+ # [normalized close, normalized volume, raw return, normalized MA5, normalized VOL20]
202
+ num_features = 5
203
+ features = np.zeros((num_dates, num_companies, num_features), dtype=np.float32)
204
+
205
+ for j, c in enumerate(companies):
206
+ features[:, j, 0] = close_scaled[c].values
207
+ features[:, j, 1] = volume_scaled[c].values
208
+ features[:, j, 2] = returns[c].values
209
+ features[:, j, 3] = ma5_scaled[c].values
210
+ features[:, j, 4] = vol20_scaled[c].values
211
+
212
+ targets = returns.values.astype(np.float32) # predict daily returns
213
+
214
+ return features, targets, dates, companies
215
+
216
+
217
+ # -----------------------------------------------------
218
+ # 2. Graph construction (correlation-based)
219
+ # -----------------------------------------------------
220
+
221
+ def build_correlation_graph(returns: np.ndarray, threshold: float = 0.2) -> torch.Tensor:
222
+ """Build an undirected graph of companies based on return correlations.
223
+
224
+ Parameters
225
+ ----------
226
+ returns : np.ndarray
227
+ Array of shape [num_dates, num_companies] with daily returns.
228
+ threshold : float
229
+ Minimum absolute correlation to create an edge.
230
+
231
+ Returns
232
+ -------
233
+ edge_index : torch.Tensor
234
+ Tensor of shape [2, num_edges] in COO format for PyTorch Geometric.
235
+ """
236
+ # Correlation across companies
237
+ corr = np.corrcoef(returns.T) # [num_companies, num_companies]
238
+ num_nodes = corr.shape[0]
239
+
240
+ edge_index_list = []
241
+ for i in range(num_nodes):
242
+ for j in range(num_nodes):
243
+ if i == j:
244
+ continue
245
+ if np.abs(corr[i, j]) >= threshold:
246
+ edge_index_list.append([i, j])
247
+
248
+ # Fallback: fully-connected graph (without self-loops) if threshold is too high
249
+ if len(edge_index_list) == 0:
250
+ for i in range(num_nodes):
251
+ for j in range(num_nodes):
252
+ if i != j:
253
+ edge_index_list.append([i, j])
254
+
255
+ edge_index = torch.tensor(edge_index_list, dtype=torch.long).t().contiguous()
256
+ return edge_index
257
+
258
+
259
+ # -----------------------------------------------------
260
+ # 3. Dataset for time-windowed graph snapshots
261
+ # -----------------------------------------------------
262
+
263
+
264
+ class TimeSeriesGraphDataset(torch.utils.data.Dataset):
265
+ """Dataset that converts time series into windowed graph snapshots for GNNs.
266
+
267
+ Each item is a Data object with:
268
+ - x: node features [num_nodes, window_size * num_features]
269
+ - edge_index: static company correlation graph
270
+ - y: target returns [num_nodes]
271
+ """
272
+
273
+ def __init__(
274
+ self,
275
+ features: np.ndarray,
276
+ targets: np.ndarray,
277
+ edge_index: torch.Tensor,
278
+ window_size: int,
279
+ start_t: int,
280
+ end_t: int,
281
+ ) -> None:
282
+ super().__init__()
283
+ self.features = features
284
+ self.targets = targets
285
+ self.edge_index = edge_index
286
+ self.window_size = window_size
287
+ self.start_t = start_t
288
+ self.end_t = end_t
289
+
290
+ def __len__(self) -> int:
291
+ return self.end_t - self.start_t
292
+
293
+ def __getitem__(self, idx: int) -> Data:
294
+ t = self.start_t + idx
295
+ # Use previous `window_size` days to predict returns at day t
296
+ window_feats = self.features[t - self.window_size : t] # [W, N, F]
297
+ window, num_nodes, num_feat = window_feats.shape
298
+
299
+ # Keep the temporal dimension for LSTM-based encoding.
300
+ # Shape: [num_nodes, window, num_feat]
301
+ x_seq = window_feats.transpose(1, 0, 2)
302
+ y = self.targets[t] # [num_nodes]
303
+
304
+ data = Data(
305
+ x=torch.from_numpy(x_seq), # [num_nodes, window, num_feat]
306
+ edge_index=self.edge_index,
307
+ y=torch.from_numpy(y),
308
+ )
309
+ return data
310
+
311
+
312
+ # -----------------------------------------------------
313
+ # 4. GNN model definition (GCN for regression)
314
+ # -----------------------------------------------------
315
+
316
+
317
+ class GNNTimeSeriesModel(nn.Module):
318
+ """LSTM + GCN hybrid for multi-node time-series regression.
319
+
320
+ Methodology alignment
321
+ ---------------------
322
+ - Temporal Feature Extraction: shared LSTM encodes each stock's past W days.
323
+ - GNN Application: GCN layers propagate information over the inter-stock graph.
324
+ - Prediction: per-node regression head outputs next-day return.
325
+ """
326
+
327
+ def __init__(
328
+ self,
329
+ window_size: int,
330
+ num_features: int,
331
+ hidden_lstm: int = 64,
332
+ hidden_gnn: int = 64,
333
+ dropout: float = 0.2,
334
+ ) -> None:
335
+ super().__init__()
336
+ self.window_size = window_size
337
+ self.num_features = num_features
338
+
339
+ # Temporal encoder: LSTM over W x F for each stock
340
+ self.lstm = nn.LSTM(
341
+ input_size=num_features,
342
+ hidden_size=hidden_lstm,
343
+ num_layers=1,
344
+ batch_first=False, # we will feed [W, N, F]
345
+ )
346
+
347
+ # Graph convolution layers operating on LSTM embeddings
348
+ self.conv1 = GCNConv(hidden_lstm, hidden_gnn)
349
+ self.conv2 = GCNConv(hidden_gnn, hidden_gnn)
350
+ self.lin = nn.Linear(hidden_gnn, 1)
351
+ self.dropout = nn.Dropout(dropout)
352
+
353
+ def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
354
+ """Forward pass.
355
+
356
+ Parameters
357
+ ----------
358
+ x : torch.Tensor
359
+ Shape [num_nodes_total_in_batch, window, num_features].
360
+ edge_index : torch.Tensor
361
+ Graph edges for the batched graph.
362
+ """
363
+ # -----------------------------
364
+ # Temporal feature extraction
365
+ # -----------------------------
366
+ # x_seq: [num_nodes_total, window, num_features]
367
+ num_nodes_total, window, num_feat = x.shape
368
+ assert (
369
+ window == self.window_size and num_feat == self.num_features
370
+ ), "Input window/feature dims do not match model configuration."
371
+
372
+ # LSTM expects [seq_len, batch, input_size]
373
+ x_seq = x.permute(1, 0, 2) # [window, num_nodes_total, num_features]
374
+ _, (h_n, _) = self.lstm(x_seq)
375
+
376
+ # Last layer hidden state: [num_nodes_total, hidden_lstm]
377
+ h_last = h_n[-1]
378
+
379
+ # -----------------------------
380
+ # Graph convolution over stocks
381
+ # -----------------------------
382
+ x_g = self.conv1(h_last, edge_index)
383
+ x_g = torch.relu(x_g)
384
+ x_g = self.dropout(x_g)
385
+
386
+ x_g = self.conv2(x_g, edge_index)
387
+ x_g = torch.relu(x_g)
388
+ x_g = self.dropout(x_g)
389
+
390
+ out = self.lin(x_g).squeeze(-1) # [num_nodes_total]
391
+ return out
392
+
393
+
394
+ # -----------------------------------------------------
395
+ # 5. Training and evaluation utilities
396
+ # -----------------------------------------------------
397
+
398
+
399
+ def train_one_epoch(
400
+ model: nn.Module,
401
+ loader: DataLoader,
402
+ optimizer: torch.optim.Optimizer,
403
+ device: torch.device,
404
+ ) -> float:
405
+ model.train()
406
+ criterion = nn.MSELoss()
407
+ total_loss = 0.0
408
+
409
+ for batch in loader:
410
+ batch = batch.to(device)
411
+ optimizer.zero_grad()
412
+ out = model(batch.x, batch.edge_index)
413
+ loss = criterion(out, batch.y)
414
+ loss.backward()
415
+ optimizer.step()
416
+ total_loss += loss.item() * batch.num_graphs
417
+
418
+ avg_loss = total_loss / len(loader.dataset)
419
+ return avg_loss
420
+
421
+
422
+ def evaluate(
423
+ model: nn.Module,
424
+ loader: DataLoader,
425
+ device: torch.device,
426
+ ):
427
+ model.eval()
428
+ criterion = nn.MSELoss()
429
+ total_loss = 0.0
430
+ all_y_true = []
431
+ all_y_pred = []
432
+
433
+ with torch.no_grad():
434
+ for batch in loader:
435
+ batch = batch.to(device)
436
+ out = model(batch.x, batch.edge_index)
437
+ loss = criterion(out, batch.y)
438
+ total_loss += loss.item() * batch.num_graphs
439
+ all_y_true.append(batch.y.cpu().numpy())
440
+ all_y_pred.append(out.cpu().numpy())
441
+
442
+ y_true = np.concatenate(all_y_true)
443
+ y_pred = np.concatenate(all_y_pred)
444
+
445
+ # -------------------------------------------------
446
+ # Guard against NaN/Inf in predictions or targets
447
+ # -------------------------------------------------
448
+ mask = np.isfinite(y_true) & np.isfinite(y_pred)
449
+ if mask.sum() == 0:
450
+ # Fallback: avoid crashing; metrics will be NaN but training can continue
451
+ mse = float("nan")
452
+ mae = float("nan")
453
+ directional_accuracy = float("nan")
454
+ avg_loss = total_loss / max(len(loader.dataset), 1)
455
+ return avg_loss, mse, mae, directional_accuracy, y_true, y_pred
456
+
457
+ y_true_clean = y_true[mask]
458
+ y_pred_clean = y_pred[mask]
459
+
460
+ mse = mean_squared_error(y_true_clean, y_pred_clean)
461
+ mae = mean_absolute_error(y_true_clean, y_pred_clean)
462
+ # Directional accuracy: how often the sign of return is predicted correctly
463
+ directional_accuracy = float((np.sign(y_true_clean) == np.sign(y_pred_clean)).mean())
464
+
465
+ avg_loss = total_loss / len(loader.dataset)
466
+ return avg_loss, mse, mae, directional_accuracy, y_true_clean, y_pred_clean
467
+
468
+
469
+ # -----------------------------------------------------
470
+ # 6. Baseline (before GNN) and real-time helpers
471
+ # -----------------------------------------------------
472
+
473
+
474
+ def compute_naive_baseline_metrics(targets: np.ndarray, train_start: int, train_end: int, val_start: int, val_end: int, test_start: int, test_end: int):
475
+ """Compute a simple baseline: predict zero return (no change) and plot vs actual.
476
+
477
+ This represents a "before GNN" naive model where we assume next-day return = 0.
478
+ """
479
+ # Flatten across all nodes
480
+ y_train = targets[train_start:train_end].reshape(-1)
481
+ y_val = targets[val_start:val_end].reshape(-1)
482
+ y_test = targets[test_start:test_end].reshape(-1)
483
+
484
+ # Baseline predictions are all zeros
485
+ y_train_pred = np.zeros_like(y_train)
486
+ y_val_pred = np.zeros_like(y_val)
487
+ y_test_pred = np.zeros_like(y_test)
488
+
489
+ train_mse = mean_squared_error(y_train, y_train_pred)
490
+ val_mse = mean_squared_error(y_val, y_val_pred)
491
+ test_mse = mean_squared_error(y_test, y_test_pred)
492
+
493
+ # Plot for test set
494
+ plt.figure(figsize=(6, 6))
495
+ plt.scatter(y_test, y_test_pred, alpha=0.3, s=10)
496
+ plt.xlabel("Actual returns")
497
+ plt.ylabel("Predicted returns (baseline: 0)")
498
+ plt.title("Baseline (No GNN) Predicted vs Actual Returns")
499
+ lims = [min(y_test.min(), y_test_pred.min()), max(y_test.max(), y_test_pred.max())]
500
+ plt.plot(lims, lims, "r--", linewidth=1)
501
+ plt.tight_layout()
502
+ plt.savefig("baseline_pred_vs_actual.png", dpi=200)
503
+ plt.close()
504
+
505
+ print(f"Baseline Train MSE: {train_mse:.6f}, Val MSE: {val_mse:.6f}, Test MSE: {test_mse:.6f}")
506
+ print("Saved baseline scatter plot to baseline_pred_vs_actual.png")
507
+
508
+
509
+ def realtime_predict_last_window(
510
+ model: nn.Module,
511
+ features: np.ndarray,
512
+ edge_index: torch.Tensor,
513
+ window_size: int,
514
+ device: torch.device,
515
+ ):
516
+ """Generate a real-time style prediction for the latest available day.
517
+
518
+ This uses the most recent `window_size` days in `features` as if it were "live" data.
519
+ """
520
+ model.eval()
521
+ num_dates, num_nodes, num_feat = features.shape
522
+ if num_dates < window_size:
523
+ raise ValueError("Not enough data points for real-time window prediction.")
524
+
525
+ # Last window
526
+ window_feats = features[num_dates - window_size : num_dates] # [W, N, F]
527
+ window, N, F = window_feats.shape
528
+ x_seq = window_feats.transpose(1, 0, 2) # [N, W, F]
529
+
530
+ data = Data(
531
+ x=torch.from_numpy(x_seq).to(device),
532
+ edge_index=edge_index.to(device),
533
+ )
534
+
535
+ with torch.no_grad():
536
+ out = model(data.x, data.edge_index).cpu().numpy()
537
+
538
+ return out # [num_nodes]
539
+
540
+
541
+ # -----------------------------------------------------
542
+ # 7. Main experiment pipeline
543
+ # -----------------------------------------------------
544
+
545
+
546
+ def main():
547
+ infy_csv = "infy_stock.csv"
548
+ tcs_csv = "tcs_stock.csv"
549
+ nifty_it_csv = "nifty_it_index.csv"
550
+ for p in [infy_csv, tcs_csv, nifty_it_csv]:
551
+ if not os.path.exists(p):
552
+ raise FileNotFoundError(f"Could not find required CSV file: {p}")
553
+
554
+ print("Loading and preprocessing data from CSVs...")
555
+ features, targets, dates, companies = load_it_sector_data_from_csvs(
556
+ infy_csv=infy_csv,
557
+ tcs_csv=tcs_csv,
558
+ nifty_it_csv=nifty_it_csv,
559
+ )
560
+ num_dates, num_companies, num_features = features.shape
561
+ print(f"Num dates: {num_dates}, Num companies (nodes): {num_companies}, Num features: {num_features}")
562
+
563
+ # Build graph from training-period correlations only (to avoid look-ahead bias)
564
+ window_size = 20
565
+ if num_dates <= window_size + 1:
566
+ raise ValueError("Not enough dates to create time windows. Reduce window_size or use more data.")
567
+
568
+ first_t = window_size
569
+ last_t = num_dates - 1
570
+ total_samples = last_t - first_t + 1
571
+
572
+ train_samples = int(total_samples * 0.7)
573
+ val_samples = int(total_samples * 0.15)
574
+ test_samples = total_samples - train_samples - val_samples
575
+
576
+ train_start_t = first_t
577
+ train_end_t = train_start_t + train_samples
578
+ val_start_t = train_end_t
579
+ val_end_t = val_start_t + val_samples
580
+ test_start_t = val_end_t
581
+ test_end_t = last_t + 1
582
+
583
+ print(f"Total usable samples: {total_samples}")
584
+ print(f"Train: {train_samples}, Val: {val_samples}, Test: {test_samples}")
585
+
586
+ # -----------------------------
587
+ # Baseline (before GNN)
588
+ # -----------------------------
589
+ compute_naive_baseline_metrics(
590
+ targets,
591
+ train_start=train_start_t,
592
+ train_end=train_end_t,
593
+ val_start=val_start_t,
594
+ val_end=val_end_t,
595
+ test_start=test_start_t,
596
+ test_end=test_end_t,
597
+ )
598
+
599
+ # Use only training period to compute correlations
600
+ train_returns = targets[train_start_t:train_end_t]
601
+ edge_index = build_correlation_graph(train_returns, threshold=0.2)
602
+ print("Edge index shape:", edge_index.shape)
603
+
604
+ # Create datasets
605
+ train_dataset = TimeSeriesGraphDataset(
606
+ features=features,
607
+ targets=targets,
608
+ edge_index=edge_index,
609
+ window_size=window_size,
610
+ start_t=train_start_t,
611
+ end_t=train_end_t,
612
+ )
613
+
614
+ val_dataset = TimeSeriesGraphDataset(
615
+ features=features,
616
+ targets=targets,
617
+ edge_index=edge_index,
618
+ window_size=window_size,
619
+ start_t=val_start_t,
620
+ end_t=val_end_t,
621
+ )
622
+
623
+ test_dataset = TimeSeriesGraphDataset(
624
+ features=features,
625
+ targets=targets,
626
+ edge_index=edge_index,
627
+ window_size=window_size,
628
+ start_t=test_start_t,
629
+ end_t=test_end_t,
630
+ )
631
+
632
+ train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
633
+ val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
634
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
635
+
636
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
637
+ print("Using device:", device)
638
+
639
+ model = GNNTimeSeriesModel(
640
+ window_size=window_size,
641
+ num_features=num_features,
642
+ hidden_lstm=64,
643
+ hidden_gnn=64,
644
+ dropout=0.2,
645
+ ).to(device)
646
+
647
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
648
+
649
+ num_epochs = 30
650
+ best_val_loss = float("inf")
651
+ best_state_dict = None
652
+
653
+ print("Starting training...")
654
+ for epoch in range(1, num_epochs + 1):
655
+ train_loss = train_one_epoch(model, train_loader, optimizer, device)
656
+ val_loss, val_mse, val_mae, val_dir_acc, _, _ = evaluate(model, val_loader, device)
657
+
658
+ if val_loss < best_val_loss:
659
+ best_val_loss = val_loss
660
+ best_state_dict = {k: v.cpu().clone() for k, v in model.state_dict().items()}
661
+
662
+ print(
663
+ f"Epoch {epoch:03d} | "
664
+ f"Train Loss: {train_loss:.6f} | "
665
+ f"Val Loss: {val_loss:.6f}, Val MSE: {val_mse:.6f}, Val MAE: {val_mae:.6f}, "
666
+ f"Val DirAcc: {val_dir_acc:.4f}"
667
+ )
668
+
669
+ if best_state_dict is not None:
670
+ model.load_state_dict(best_state_dict)
671
+
672
+ print("Evaluating on test set...")
673
+ test_loss, test_mse, test_mae, test_dir_acc, y_true, y_pred = evaluate(model, test_loader, device)
674
+ print(
675
+ f"Test Loss: {test_loss:.6f}, Test MSE: {test_mse:.6f}, "
676
+ f"Test MAE: {test_mae:.6f}, Test DirAcc: {test_dir_acc:.4f}"
677
+ )
678
+
679
+ # -------------------------------------------------
680
+ # Simple visualization: predicted vs actual returns
681
+ # -------------------------------------------------
682
+ plt.figure(figsize=(6, 6))
683
+ plt.scatter(y_true, y_pred, alpha=0.3, s=10)
684
+ plt.xlabel("Actual returns")
685
+ plt.ylabel("Predicted returns")
686
+ plt.title("GNN Predicted vs Actual Daily Returns (All IT Stocks)")
687
+ lims = [min(y_true.min(), y_pred.min()), max(y_true.max(), y_pred.max())]
688
+ plt.plot(lims, lims, "r--", linewidth=1)
689
+ plt.tight_layout()
690
+ plt.savefig("gnn_it_sector_pred_vs_actual.png", dpi=200)
691
+ plt.close()
692
+ print("Saved scatter plot to gnn_it_sector_pred_vs_actual.png")
693
+
694
+ # -------------------------------------------------
695
+ # Real-time style prediction using latest window
696
+ # -------------------------------------------------
697
+ latest_pred = realtime_predict_last_window(
698
+ model=model,
699
+ features=features,
700
+ edge_index=edge_index,
701
+ window_size=window_size,
702
+ device=device,
703
+ )
704
+ print("Real-time style next-day return prediction per node (order of companies):")
705
+ for comp, val in zip(companies, latest_pred):
706
+ print(f" {comp}: {val:.6f}")
707
+
708
+
709
+ if __name__ == "__main__":
710
+ main()