File size: 1,742 Bytes
27f9443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import csv
from fine_tuning.wrappers import NeuroRVQWrapper

class CSVLogger():
    """Scalar logger that appends reported values to CSV files.

    All files live in ``<output_dir>/<ex_id>_log``. Each metric title is
    written to ``<title>_train.csv`` or ``<title>_val.csv`` depending on the
    series name passed to :meth:`report_scalar`.
    """

    def __init__(self, output_dir, ex_id):
        """Create the log directory if it does not exist yet.

        Args:
            output_dir: Directory under which the log folder is created.
            ex_id: Identifier used as the prefix of the log folder name.
        """
        self.log_dir = os.path.join(output_dir, f"{ex_id}_log")
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.log_dir, exist_ok=True)
        # Paths this instance already knows to contain a header row.
        self._files = set()

    def _needs_header(self, filepath):
        """Return True when ``filepath`` does not hold a header row yet."""
        if filepath in self._files:
            return False
        try:
            # A non-empty file left by an earlier logger instance or run
            # already starts with a header; do not repeat it mid-file.
            return os.path.getsize(filepath) == 0
        except OSError:
            # The file does not exist yet (or cannot be stat'ed).
            return True

    def report_scalar(self, title, series, value, iteration):
        """Mimic clearml's report_scalar() by appending the value to a CSV file.

        Args:
            title: Metric name; used as the file-name stem. Titles containing
                ``'MEAN'`` are logged with the full series name, all others
                with only the last space-separated token of ``series``.
            series: Series name. If it contains ``'train'`` the row goes to
                the ``_train`` file, otherwise to the ``_val`` file.
            value: Value to record.
            iteration: Iteration number to record alongside the value.
        """
        split = "train" if 'train' in series else "val"
        filepath = os.path.join(self.log_dir, f"{title}_{split}.csv")

        if 'MEAN' in title:
            header = ["Series", "Iteration", "Value"]
            label = series
        else:
            header = ["Fold", "Iteration", "Value"]
            label = series.split(' ')[-1]

        write_header = self._needs_header(filepath)

        with open(filepath, mode="a", newline="") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(header)
            writer.writerow([label, iteration, value])

        # Only remember the file once a write has actually succeeded.
        self._files.add(filepath)

def get_logger(output_dir="results", ex_id=0):
    """Create a CSVLogger.

    Args:
        output_dir: Directory passed to CSVLogger as its output directory.
            Defaults to ``"results"``.
        ex_id: Identifier passed to CSVLogger. Defaults to ``0``.

    Returns:
        A new CSVLogger instance.
    """
    return CSVLogger(output_dir, ex_id)

def get_model(ch_names, n_times, n_outputs, args, foundation_model, train_head_only=False):
    """Build a NeuroRVQWrapper from the given settings.

    Every argument is forwarded to the wrapper's constructor by keyword;
    ``n_times`` is passed under the wrapper's ``n_time`` keyword.

    Returns:
        The constructed NeuroRVQWrapper instance.
    """
    wrapper_kwargs = {
        "n_time": n_times,
        "ch_names": ch_names,
        "n_outputs": n_outputs,
        "train_head_only": train_head_only,
        "args": args,
        "foundation_model": foundation_model,
    }
    return NeuroRVQWrapper(**wrapper_kwargs)