# Training script for the rfcx/frugalai audio classification task
# (CNN + KAN model trained with BCE-with-logits on streamed HF data).
from torch.utils.data import DataLoader
from .utils.data import FFTDataset, SplitDataset
from datasets import load_dataset
from .utils.train import Trainer
from .utils.models import CNNKan, KanEncoder
from .utils.data_utils import *
from huggingface_hub import login
import os
import yaml
import datetime
import json
import numpy as np
import torch

# Run directory is keyed by date so repeated runs on the same day share logs.
current_date = datetime.date.today().strftime("%Y-%m-%d")
datetime_dir = f"frugal_{current_date}"

# Parse the YAML config ONCE (the original re-opened and re-parsed the file
# for every section, leaking a file handle each time).
args_dir = 'tasks/utils/config.yaml'
with open(args_dir, 'r') as f:
    config = yaml.safe_load(f)
data_args = Container(**config['Data'])
exp_num = data_args.exp_num
model_name = data_args.model_name
model_args = Container(**config['CNNEncoder'])
model_args_f = Container(**config['CNNEncoder_f'])
conformer_args = Container(**config['Conformer'])
kan_args = Container(**config['KAN'])

# exist_ok avoids the check-then-create race of exists()+makedirs().
os.makedirs(f"{data_args.log_dir}/{datetime_dir}", exist_ok=True)

# Strip whitespace so a trailing newline in the token file does not break login.
with open("../logs/token.txt", "r") as f:
    api_key = f.read().strip()

local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
login(api_key)

# Streaming dataset: no local download; SplitDataset carves train/val out of
# the "train" split (is_train flag selects the partition — see SplitDataset).
dataset = load_dataset("rfcx/frugalai", streaming=True)
train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
train_dl = DataLoader(train_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)
val_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=False)
val_dl = DataLoader(val_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)
test_ds = FFTDataset(dataset["test"])
test_dl = DataLoader(test_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)

model = CNNKan(model_args, conformer_args, kan_args.get_dict())
model = model.to(local_rank)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {num_params}")

# Binary classification head -> BCE on raw logits.
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
total_steps = int(data_args.num_epochs) * 1000
# NOTE(review): this scheduler is built but Trainer below receives
# scheduler=None, so the LR never anneals. Pass scheduler=scheduler if the
# cosine schedule was intended — confirm Trainer's per-step/per-epoch stepping.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=total_steps,
                                                       eta_min=float(5e-4) / 10)

trainer = Trainer(model=model, optimizer=optimizer,
                  criterion=loss_fn, output_dim=model_args.output_dim, scaler=None,
                  scheduler=None, train_dataloader=train_dl,
                  val_dataloader=val_dl, device=local_rank,
                  exp_num=datetime_dir, log_path=data_args.log_dir,
                  range_update=None,
                  accumulation_step=1, max_iter=np.inf,
                  exp_name=f"frugal_kan_{exp_num}")
# Epoch count now comes from the config (was hard-coded to 100, which
# contradicted total_steps above being derived from data_args.num_epochs).
fit_res = trainer.fit(num_epochs=int(data_args.num_epochs), device=local_rank,
                      early_stopping=10, only_p=False, best='loss', conf=True)

# Persist the fit history next to the run's logs.
output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
with open(output_filename, "w") as f:
    json.dump(fit_res, f, indent=2)

preds, acc = trainer.predict(test_dl, local_rank)
print(f"Accuracy: {acc}")