| import os |
| import csv |
| from fine_tuning.wrappers import NeuroRVQWrapper |
|
|
class CSVLogger():
    """Log scalar values to CSV files instead of an experiment tracker.

    Exposes a ``report_scalar()`` method with the same call shape as
    ClearML's logger so it can be used in its place. All files are written
    to ``<output_dir>/<ex_id>_log``.
    """

    def __init__(self, output_dir, ex_id):
        """Create the log directory if it does not exist yet.

        Args:
            output_dir: Parent directory for the log folder.
            ex_id: Experiment identifier; the folder is named ``{ex_id}_log``.
        """
        self.log_dir = os.path.join(output_dir, f"{ex_id}_log")
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.log_dir, exist_ok=True)
        # Paths this instance already knows to contain a header row.
        self._files = set()

    def report_scalar(self, title, series, value, iteration):
        """Append one scalar to ``{title}_train.csv`` or ``{title}_val.csv``.

        Mimics the ClearML ``report_scalar()`` function.

        Args:
            title: Metric name; used as the file-name prefix. Titles
                containing ``'MEAN'`` use a ``Series`` column holding the
                full series string; all other titles use a ``Fold`` column
                holding the last space-separated token of ``series``.
            series: Series label (string). If it contains ``'train'`` the
                row goes to the train file, otherwise to the val file.
            value: Value to record.
            iteration: Iteration index to record alongside the value.
        """
        split = "train" if "train" in series else "val"
        filepath = os.path.join(self.log_dir, f"{title}_{split}.csv")

        if "MEAN" in title:
            header = ["Series", "Iteration", "Value"]
            label = series
        else:
            header = ["Fold", "Iteration", "Value"]
            label = series.split(" ")[-1]

        # Write the header only when the file is genuinely new or empty.
        # Relying on the in-memory set alone would insert a second header in
        # the middle of a file left behind by an earlier logger instance.
        needs_header = filepath not in self._files and (
            not os.path.exists(filepath) or os.path.getsize(filepath) == 0
        )

        with open(filepath, mode="a", newline="") as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(header)
            writer.writerow([label, iteration, value])

        self._files.add(filepath)
|
|
def get_logger():
    """Return a ``CSVLogger`` constructed with output dir ``"results"`` and experiment id ``0``."""
    return CSVLogger("results", 0)
|
|
def get_model(ch_names, n_times, n_outputs, args, foundation_model, train_head_only=False):
    """Build the fine-tuning wrapper around the NeuroRVQ foundation model.

    Every argument is forwarded unchanged to ``NeuroRVQWrapper``; note that
    ``n_times`` is passed under the wrapper's ``n_time`` keyword.

    Returns:
        The constructed ``NeuroRVQWrapper`` instance.
    """
    wrapper_kwargs = {
        "n_time": n_times,
        "ch_names": ch_names,
        "n_outputs": n_outputs,
        "train_head_only": train_head_only,
        "args": args,
        "foundation_model": foundation_model,
    }
    return NeuroRVQWrapper(**wrapper_kwargs)
|
|