|
|
"""Label/tag processing helpers for multi-label classification.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Dict, Iterable, List, Sequence, Tuple, Union |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
def _is_missing(value: object) -> bool:
    """Return True when *value* is ``None`` or a float NaN."""
    # NaN is the only float that compares unequal to itself, so this needs no
    # extra import and never raises on non-numeric input.
    return value is None or (isinstance(value, float) and value != value)


def process_tags(tags: Union[str, Sequence[str], None], sep: str = ",") -> List[str]:
    """Convert raw tags to a list of normalized tag strings.

    Args:
        tags: Either one ``sep``-delimited string, a sequence of individual
            tags, or a missing value (``None`` or a float NaN, which is how
            dataframe columns commonly represent an empty cell).
        sep: Delimiter used to split ``tags`` when it is a single string.

    Returns:
        The tags with surrounding whitespace stripped, in input order.
        Empty tags and missing entries are dropped; duplicates are kept.
    """
    if _is_missing(tags):
        return []

    if isinstance(tags, str):
        pieces = (piece.strip() for piece in tags.split(sep))
        return [piece for piece in pieces if piece]

    # Sequence input: each entry is stringified as-is (it is NOT split on
    # ``sep``). Missing entries are skipped so a NaN never becomes a "nan" tag.
    cleaned = (str(entry).strip() for entry in tags if not _is_missing(entry))
    return [tag for tag in cleaned if tag]
|
|
|
|
|
|
|
|
def build_label_mapping(
    df,
    *,
    tags_col: str = "tags",
    sep: str = ",",
) -> Dict[str, int]:
    """Build a tag->index mapping from a dataframe-like object.

    Every entry of ``df[tags_col]`` is normalized with ``process_tags``, so an
    entry may be a ``sep``-delimited string or a list of tags.

    Args:
        df: Object supporting ``df[tags_col]``; the selected column must
            provide a ``tolist()`` method.
        tags_col: Name of the column holding the raw tags.
        sep: Delimiter forwarded to ``process_tags`` for string entries.

    Returns:
        A dict assigning each distinct tag its position (starting at 0) in
        the sorted order of all distinct tags.
    """
    unique_tags = {
        tag
        for entry in df[tags_col].tolist()
        for tag in process_tags(entry, sep=sep)
    }
    # Sorting makes the index assignment independent of row order.
    ordered = sorted(unique_tags)
    return dict(zip(ordered, range(len(ordered))))
|
|
|
|
|
|
|
|
def create_target_encoding(
    tag_lists: Iterable[Union[str, Sequence[str], None]],
    label_to_idx: Dict[str, int],
    *,
    sep: str = ",",
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Create a multi-hot target tensor of shape [N, num_labels].

    Args:
        tag_lists: One raw tag entry per sample; each entry is normalized
            with ``process_tags``. The iterable is consumed once.
        label_to_idx: Mapping from tag to its column index in the output.
        sep: Delimiter forwarded to ``process_tags`` for string entries.
        dtype: Element type of the returned tensor.

    Returns:
        A tensor with ``N = len(list(tag_lists))`` rows and
        ``len(label_to_idx)`` columns, zero everywhere except for a 1 at
        ``[i, label_to_idx[tag]]`` for every known tag of sample ``i``.
        Tags that are not keys of ``label_to_idx`` are ignored.
    """
    # Materialize first: the row count is needed before iterating.
    samples = list(tag_lists)
    targets = torch.zeros((len(samples), len(label_to_idx)), dtype=dtype)

    for row, raw_tags in enumerate(samples):
        for tag in process_tags(raw_tags, sep=sep):
            col = label_to_idx.get(tag)
            if col is None:
                # Unknown tag: leave the row untouched for it.
                continue
            targets[row, col] = 1.0

    return targets
|
|
|
|
|
|
|
|
|
|
|
|