File size: 5,953 Bytes
a35137b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import dataclasses
import os
from collections import defaultdict
from typing import Dict, List, Optional, Union

import pandas as pd
import torch

from barista.data.dataframe_wrapper import DataframeWrapper
from barista.data.metadata_spatial_groups import (
    MetadataSpatialGroupRow,
    MetadataSpatialGroups,
)


@dataclasses.dataclass()
class MetadataRow:
    """
    Typed record describing one segment's metadata fields.

    The field order below is the positional order of the generated ``__init__``;
    presumably it mirrors the columns of the ``Metadata`` dataframe (verify
    against the code that builds rows).
    """

    # Identification of where the segment comes from.
    dataset: str
    subject: str
    session: str
    subject_session: str
    experiment: str

    # Size information; ``d_data`` is stored as a ``torch.Size``.
    d_input: int
    d_data: torch.Size

    # Split assignment and storage location.
    split: str
    path: str
    filename: str
    processing_str: str
    seq_len: int

    # Label may be None (declared Optional).
    label: Optional[float]


class Metadata(DataframeWrapper):
    """
    Metadata class to keep track of all segment meta information.

    Segment rows are handled by ``DataframeWrapper``. Optionally, a companion
    ``MetadataSpatialGroups`` table is kept in ``self._spatial_groups`` and
    persisted in a sibling CSV file (see ``_get_spatial_group_path``). The
    spatial groups may be absent (``None``); every method below tolerates that.
    """

    def __init__(self, df=None, load_path=None, spatial_group_df=None):
        """
        Args:
            df: segment dataframe, forwarded to ``DataframeWrapper.__init__``.
            load_path: path forwarded to ``DataframeWrapper.__init__``. When
                given, spatial groups are also loaded from the sibling file
                returned by ``_get_spatial_group_path``; a missing sibling file
                leaves the spatial groups as None.
            spatial_group_df: dataframe used to build the spatial groups. Only
                used when ``load_path`` is None, and only allowed with ``df``.

        Raises:
            ValueError: if ``spatial_group_df`` is given while ``df`` is None.
        """
        # Raise instead of assert so the check survives ``python -O``.
        if df is None and spatial_group_df is not None:
            raise ValueError("spatial_group_df can only be given together with df")

        super().__init__(df, load_path)

        self._spatial_groups = None
        if load_path is not None:
            try:
                self._spatial_groups = MetadataSpatialGroups(
                    load_path=self._get_spatial_group_path(load_path)
                )
            except FileNotFoundError:
                # Deliberate best effort: metadata saved without spatial groups
                # has no sibling file, so the groups simply stay None.
                pass
        elif spatial_group_df is not None:
            self._spatial_groups = MetadataSpatialGroups(df=spatial_group_df)

    def _get_spatial_group_path(self, path: str) -> str:
        """
        Return the spatial-groups CSV path that accompanies ``path``.

        The extension of ``path`` is replaced by ``_spatial_groups.csv``,
        e.g. ``"dir/meta.csv" -> "dir/meta_spatial_groups.csv"``.
        """
        suffix = ".csv"
        # splitext only strips a real extension; blindly dropping the last four
        # characters mangled paths such as "dir/meta" into "dir/".
        stem, _ = os.path.splitext(path)
        return f"{stem}_spatial_groups{suffix}"

    def save(self, path: str) -> None:
        """
        Save the segment metadata to ``path`` and, when present, the spatial
        groups to the sibling path from ``_get_spatial_group_path``.
        """
        super().save(path)
        # Nothing to write without spatial groups; loading this file later hits
        # the FileNotFoundError branch in __init__ and keeps them as None.
        # NOTE(review): a stale sibling file from an earlier save is not removed.
        if self._spatial_groups is not None:
            self._spatial_groups.save(self._get_spatial_group_path(path))

    @classmethod
    def merge(
        cls,
        metadatas: List["Metadata"],
        drop_duplicate: bool = False,
        merge_columns: Union[str, List[str], None] = None,
        keep="first",
    ) -> "Metadata":
        """
        Merge several Metadata objects into a new one.

        Segment rows are merged by ``DataframeWrapper.merge`` with the given
        options. Spatial groups of all inputs that have them are merged too,
        dropping duplicates on (dataset, subject_session, name). If no input
        has spatial groups, the result has none.
        """
        new_metadata = super().merge(metadatas, drop_duplicate, merge_columns, keep)

        # Skip inputs without spatial groups rather than passing None along.
        spatial_groups = [
            m._spatial_groups for m in metadatas if m._spatial_groups is not None
        ]
        if spatial_groups:
            new_metadata._spatial_groups = MetadataSpatialGroups.merge(
                spatial_groups,
                drop_duplicate=True,
                merge_columns=[
                    "dataset",
                    "subject_session",
                    "name",
                ],
            )
        else:
            new_metadata._spatial_groups = None
        return new_metadata

    def get_subject_session_d_input(self) -> dict:
        """Return a mapping from ``subject_session`` to its ``d_input`` value."""
        return self._get_column_mapping_dict_from_dataframe(
            key_col="subject_session",
            value_col="d_input",
        )

    def get_subjects(self) -> dict:
        """Return the unique values of the ``subject`` column."""
        # NOTE(review): the ``dict`` annotation is kept for compatibility; the
        # actual type is whatever ``get_unique_values_in_col`` returns.
        return self.get_unique_values_in_col("subject")

    def _shape_str_to_list(self, value):
        """
        Parse a comma-separated shape string (e.g. ``"3,10"``) into ints.

        Non-string values are returned unchanged. Empty fields (e.g. from a
        trailing comma) are skipped.
        """
        if not isinstance(value, str):
            return value
        return [int(part) for part in value.split(",") if part.strip()]

    def get_subject_session_full_d_data(self) -> Dict[str, List[int]]:
        """
        Returns a dict containing subject_session to data shape.

        ``d_data`` values stored as strings are parsed with ``_shape_str_to_list``.
        """
        my_dict = self._get_column_mapping_dict_from_dataframe(
            key_col="subject_session",
            value_col="d_data",
        )
        return {k: self._shape_str_to_list(v) for k, v in my_dict.items()}

    def get_labels_count_summary(self) -> dict:
        """
        Count segments per (split, label) pair.

        Returns:
            A ``defaultdict`` mapping split -> {label: number of matching rows}.
        """
        splits = self.get_unique_values_in_col("split")
        labels = self.get_unique_values_in_col("label")

        # NOTE(review): if labels contain NaN, whether they are counted depends
        # on how get_indices_matching_cols_values compares values — confirm.
        labels_count = defaultdict(dict)
        for split in splits:
            for label in labels:
                count = len(
                    self.get_indices_matching_cols_values(
                        ["split", "label"],
                        [split, label],
                    )
                )
                labels_count[split][label] = count
        return labels_count

    def get_summary_str(self) -> str:
        """
        Return a one-line, human-readable summary of subjects and per-split
        label counts.
        """
        subjects = self.get_unique_values_in_col("subject")
        labels_count = self.get_labels_count_summary()

        parts = [f"Metadata for {len(subjects)} subjects ({subjects})"]
        parts.extend(
            f"{count} {split} segments with label {label}"
            for split, labels in labels_count.items()
            for label, count in labels.items()
        )
        return ", ".join(parts)

    ########################### spatial group related ###########################

    def add_spatial_group(self, spatial_group_row: MetadataSpatialGroupRow):
        """
        Add (or overwrite) the spatial group.

        An existing group with the same ``subject_session`` and ``name`` is
        removed before the new row is appended. If there are no spatial groups
        yet, a new table is created from this single row.
        """
        new_row_df = pd.DataFrame([spatial_group_row])
        if self._spatial_groups is None:
            self._spatial_groups = MetadataSpatialGroups(df=new_row_df)
            return

        self._spatial_groups.remove_spatial_group(
            spatial_group_row.subject_session, spatial_group_row.name
        )
        self._spatial_groups.concat(new_row_df)

    def get_spatial_grouping(
        self, subject_session: str, name: str
    ) -> Optional[MetadataSpatialGroupRow]:
        """
        Return the spatial grouping ``name`` for ``subject_session``, or None
        when no spatial groups are available.

        A spatial grouping is a MetadataSpatialGroupRow. Its most important
        properties are ``group_components``, a list of tuples with group info
        for each channel of the data, and ``group_ids``, a list of integers
        specifying which group each channel belongs to.
        """
        if self._spatial_groups is None:
            return None
        return self._spatial_groups.get_spatial_grouping(subject_session, name)

    def get_spatial_grouping_id_hashmap(self, name: str) -> Dict[str, List[int]]:
        """
        Return a dict mapping each subject_session to its ``group_ids`` for the
        spatial grouping ``name`` (empty when no spatial groups are available).

        ``group_ids`` is a list with one entry per channel, specifying which
        group that channel belongs to.

        # NOTE Don't use during forward because of the copy
        """
        if self._spatial_groups is None:
            return {}
        # Filter a copy so the stored spatial groups are left untouched.
        temp_copy = self._spatial_groups.copy()
        temp_copy.reduce_based_on_col_value(col_name="name", value=name, keep=True)
        return temp_copy._get_column_mapping_dict_from_dataframe(
            "subject_session", "group_ids"
        )