# flow-matching/test/debug_training.py
# Uploaded by sabertoaster using huggingface_hub (commit 4edc9aa, verified); 2.38 kB.
import argparse
import sys
import unittest.mock
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
# Add the Matcha-TTS checkout and the directory above this script's folder to
# the Python path so that their modules (including `src`) can be imported.
ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT / "Matcha-TTS"))
sys.path.append(str(ROOT))
import src.training
from src.stage1.medarc_architecture import MultiSubjectConvLinearEncoder
from src.stage2.CFM import CFM
from torch.utils.data import DataLoader, Dataset
class MockDataset(Dataset):
    """Map-style dataset that fabricates random feature/fMRI tensors for debug runs.

    Every item is sampled afresh with ``torch.randn``; nothing is read from disk,
    and the requested index only determines whether the access is in range.

    Args:
        num_samples: Number of items reported by ``len()``.
        num_subjects: Size of the leading axis of the ``fmri`` tensor.
        time_steps: Size of the time axis shared by every feature tensor and
            by the ``fmri`` tensor.
        voxels: Size of the last axis of the ``fmri`` tensor.
        feat_dims: Width of each feature tensor; one tensor is produced per entry.
    """

    def __init__(
        self, num_samples, num_subjects=4, time_steps=10, voxels=100, feat_dims=(32, 64)
    ):
        super().__init__()
        self.num_samples = num_samples
        self.num_subjects = num_subjects
        self.time_steps = time_steps
        self.voxels = voxels
        self.feat_dims = feat_dims

    def __len__(self):
        """Return the configured number of mock samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return one random sample as ``{"features": [...], "fmri": tensor}``.

        Raises:
            IndexError: If ``idx`` is an integer outside
                ``[-num_samples, num_samples)``.
        """
        # Without this check, plain iteration (``for item in ds`` / ``list(ds)``)
        # never stops: Python falls back to calling __getitem__ with 0, 1, 2, ...
        # until IndexError is raised. Non-integer indices are passed through
        # unchanged, as before.
        if isinstance(idx, int) and not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for MockDataset of size {self.num_samples}"
            )
        # One (time_steps, dim) tensor per configured feature width.
        features = [torch.randn(self.time_steps, dim) for dim in self.feat_dims]
        # fmri: (S, T, V) = (num_subjects, time_steps, voxels)
        fmri = torch.randn(self.num_subjects, self.time_steps, self.voxels)
        return {"features": features, "fmri": fmri}
def mock_make_data_loaders(cfg):
    """Build mock data loaders that serve random tensors instead of real data.

    Only ``cfg.batch_size`` is read from ``cfg``. The returned dict maps both
    ``"train"`` and ``"val_debug"`` to the same ``DataLoader`` instance, which
    wraps a 4-sample ``MockDataset`` with 1000 voxels and feature widths
    ``(32, 64)``.
    """
    print("MOCKING DATA LOADERS FOR DEBUG")

    loader_kwargs = {"batch_size": cfg.batch_size}

    # Fixed mock dimensions for the debug run.
    mock_ds = MockDataset(num_samples=4, voxels=1000, feat_dims=(32, 64))
    shared_loader = DataLoader(mock_ds, **loader_kwargs)

    # Validation reuses the training loader object.
    return dict.fromkeys(("train", "val_debug"), shared_loader)
def main():
    """Run ``src.training.main`` on mock data with the debug configuration.

    While the run is active, ``src.training.make_data_loaders`` is patched so
    that calls are routed to ``mock_make_data_loaders``. Any ``Exception``
    raised by the training entry point is reported (message plus traceback)
    rather than propagated.
    """
    loader_patch = unittest.mock.patch(
        "src.training.make_data_loaders", side_effect=mock_make_data_loaders
    )
    with loader_patch:
        # The training entry point is expected to parse the command line
        # itself, so simulate the CLI invocation by overriding sys.argv.
        debug_argv = ["training.py", "--cfg-path", "test/debug_config.yml"]
        sys.argv = debug_argv

        try:
            src.training.main()
        except Exception as err:
            import traceback

            print(f"Caught exception during debug run: {err}")
            traceback.print_exc()
# Start the mocked debug training run only when this file is executed as a
# script; importing the module does not trigger it.
if __name__ == "__main__":
    main()