Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import pandas as pd | |
| from model_training.model_torch import EncoderDecoder | |
| from tqdm import tqdm | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchmetrics.functional.image import structural_similarity_index_measure | |
| import numpy as np | |
| from config import normalize_input, get_input_fields | |
| from utils.path_utils import resolve_path | |
| import os | |
| class PhysicsTrajectoryDataset(Dataset): | |
| def __init__(self, filepath, physics_type): | |
| self.df = pd.read_pickle(filepath) | |
| self.physics_type = physics_type | |
| # Detect the number of time steps from the first sample | |
| sample_trajectory = self.df.iloc[0]['trajectory'] | |
| self.time_steps = len(sample_trajectory) | |
| # Sanity check | |
| if not all(len(traj) == self.time_steps for traj in self.df['trajectory']): | |
| raise ValueError("Not all trajectories have the same number of time steps!") | |
| def __len__(self): | |
| return len(self.df) | |
| def __getitem__(self, idx): | |
| row = self.df.iloc[idx] | |
| fields = get_input_fields(self.physics_type) | |
| inputs = torch.tensor(normalize_input(self.physics_type, *[row[f] for f in fields]), dtype=torch.float32) | |
| trajectory_np = np.array(row['trajectory']) | |
| trajectory = torch.from_numpy(trajectory_np).float() | |
| return inputs, trajectory | |
| def compute_ssim(pred, target, device): | |
| """ | |
| Computes SSIM between predicted and target tensors. | |
| If the input is 1D (H=1), attempts to squarify. | |
| Falls back to MSE-only if squarify fails. | |
| """ | |
| B, C, H, W = pred.shape | |
| try: | |
| if H == 1: | |
| # 1D sequence β squarify to [B, C, S, S] | |
| pred_sq = squarify_1d_sequence(pred) | |
| target_sq = squarify_1d_sequence(target) | |
| return structural_similarity_index_measure(pred_sq, target_sq, data_range=1.0) | |
| else: | |
| # Already 2D image-like | |
| return structural_similarity_index_measure(pred, target, data_range=1.0) | |
| except ValueError: | |
| # Fallback: SSIM not computable | |
| return torch.tensor(0.0, device=device) | |
| def squarify_1d_sequence(x): | |
| """Reshape 1D sequence into square 2D frames for SSIM.""" | |
| B, C, _, T = x.shape | |
| S = int(T ** 0.5) | |
| if S * S != T: | |
| raise ValueError(f"T={T} must be a perfect square for squarify, got {T}.") | |
| return x.view(B, C, S, S) | |
| def bce_dot_loss(pred, target, device): | |
| """ | |
| Compute binary cross entropy loss between predicted and target dot-maps. | |
| Assumes pred and target have shape [B, T, H, W] with values in [0, 1]. | |
| """ | |
| if pred.shape != target.shape: | |
| raise ValueError(f"Shape mismatch: pred {pred.shape} vs target {target.shape}") | |
| # Compute per-pixel weights: 6.0 where target=1, 1.0 where target=0 | |
| weight = (target * 10.0 + 1.0) # weight = 610 where dot is, 1 elsewhere | |
| loss = nn.functional.binary_cross_entropy(pred, target, weight=weight, reduction='mean') | |
| return loss | |
| def train_model( | |
| physics_type, | |
| hidden_size=128, | |
| lr=0.002, | |
| epochs=20, | |
| early_stopping=False, | |
| patience=5, | |
| batch_size=256, | |
| clip_grad=1.0 | |
| ): | |
| dataset_path = resolve_path(f"{physics_type}_data.pkl") | |
| model_path = resolve_path(f"{physics_type}_model.pth", write_mode=True) | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| # Dataset and DataLoader | |
| dataset = PhysicsTrajectoryDataset(dataset_path, physics_type) | |
| dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0) | |
| # Inspect shape of sample | |
| sample_input, sample_target = dataset[0] | |
| timesteps, coord_dims = sample_target.shape | |
| assert coord_dims == 2, "Expected target shape [T, 2] for (x, y) coordinates" | |
| model = EncoderDecoder( | |
| input_dim=sample_input.shape[0], | |
| hidden_size=hidden_size, | |
| output_seq_len=timesteps, | |
| output_shape=None # Not used in coordinate mode | |
| ).to(device) | |
| optimizer = torch.optim.Adam(model.parameters(), lr=lr) | |
| scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3) | |
| loss_fn = nn.MSELoss() | |
| best_loss = float('inf') | |
| wait = 0 | |
| losses = [] | |
| for epoch in range(epochs): | |
| model.train() | |
| total_loss = 0 | |
| progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False) | |
| for batch_inputs, batch_targets in progress_bar: | |
| batch_inputs = batch_inputs.to(device) # [B, F] | |
| batch_targets = batch_targets.to(device) # [B, T, 2] | |
| optimizer.zero_grad() | |
| outputs = model(batch_inputs) # [B, T, 2] | |
| loss = loss_fn(outputs, batch_targets) | |
| loss.backward() | |
| if clip_grad: | |
| torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad) | |
| optimizer.step() | |
| total_loss += loss.item() | |
| avg_loss = total_loss / len(dataloader) | |
| losses.append(avg_loss) | |
| print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}") | |
| scheduler.step(avg_loss) | |
| # Debug: check one sample trajectory | |
| # if epoch % 5 == 0: | |
| # with torch.no_grad(): | |
| # coords = outputs[0].detach().cpu().numpy() | |
| # print(f"π Predicted trajectory sample (epoch {epoch+1}):") | |
| # for t, (x, y) in enumerate(coords[:5]): | |
| # print(f"t={t}: (x={x:.3f}, y={y:.3f})") | |
| # Early Stopping Logic | |
| if early_stopping: | |
| if avg_loss < best_loss - 1e-6: | |
| best_loss = avg_loss | |
| wait = 0 | |
| else: | |
| wait += 1 | |
| if wait >= patience: | |
| print(f"βΉοΈ Early stopping at epoch {epoch+1}") | |
| break | |
| # Save model | |
| torch.save({ | |
| 'model_state': model.state_dict(), | |
| 'input_dim': sample_input.shape[0], | |
| 'output_seq_len': timesteps, | |
| }, model_path) | |
| print(f"β Model saved to {model_path}") | |
| return f"β Trained and saved {physics_type} model to {model_path}", losses | |
| if __name__ == "__main__": | |
| train_model("ball_motion") | |