File size: 1,724 Bytes
090e11e 1b82a45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
"""Label/tag processing helpers for multi-label classification."""
from __future__ import annotations
from typing import Dict, Iterable, List, Sequence, Tuple, Union
import torch
def _is_missing_tag(value: object) -> bool:
    """Return True for values that represent "no tag": None or a float NaN."""
    # NaN is the only float that compares unequal to itself, so this check
    # needs no extra import.
    return value is None or (isinstance(value, float) and value != value)


def process_tags(tags: Union[str, Sequence[str], None], sep: str = ",") -> List[str]:
    """Convert raw tags to a list of normalized tag strings.

    Args:
        tags: Either a single string of tags joined by ``sep``, a sequence of
            individual tags, or a missing value (``None`` or a float NaN).
        sep: Separator used to split ``tags`` when it is a string. It is not
            applied to the elements of a sequence.

    Returns:
        The tags with surrounding whitespace stripped, in input order.
        Empty tags and missing elements (``None`` / float NaN) are dropped.
        Duplicates are kept. A missing ``tags`` value yields ``[]``.

    Raises:
        ValueError: If ``tags`` is a string and ``sep`` is empty
            (raised by ``str.split``).
    """
    # A float NaN is not iterable; without this guard it would raise
    # TypeError in the sequence branch below.
    if _is_missing_tag(tags):
        return []

    if isinstance(tags, str):
        candidates = tags.split(sep)
    else:
        # Non-string elements are stringified; missing ones are skipped so a
        # NaN never turns into the literal tag "nan".
        candidates = [str(t) for t in tags if not _is_missing_tag(t)]

    stripped = (c.strip() for c in candidates)
    return [s for s in stripped if s]
def build_label_mapping(
    df,
    *,
    tags_col: str = "tags",
    sep: str = ",",
) -> Dict[str, int]:
    """Build a tag->index mapping from a dataframe-like object.

    Every entry of ``df[tags_col]`` (which must provide ``.tolist()``) is
    normalized with ``process_tags`` using ``sep``. The distinct tags are
    then numbered ``0..K-1`` in sorted order, so the mapping does not depend
    on row order.

    Args:
        df: Dataframe-like object indexable by column name.
        tags_col: Name of the column holding the raw tags.
        sep: Separator forwarded to ``process_tags`` for string entries.

    Returns:
        Dict mapping each distinct tag to its index.
    """
    raw_entries = df[tags_col].tolist()
    distinct = {
        tag
        for entry in raw_entries
        for tag in process_tags(entry, sep=sep)
    }
    ordered = sorted(distinct)
    return dict(zip(ordered, range(len(ordered))))
def create_target_encoding(
    tag_lists: Iterable[Union[str, Sequence[str], None]],
    label_to_idx: Dict[str, int],
    *,
    sep: str = ",",
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Create a multi-hot target tensor of shape [N, num_labels].

    Row ``i`` corresponds to the ``i``-th item of ``tag_lists``; column ``j``
    is set to 1 when that item contains the tag mapped to ``j`` in
    ``label_to_idx``. Tags missing from ``label_to_idx`` are ignored.

    Args:
        tag_lists: One raw tag entry per sample, in any form accepted by
            ``process_tags``.
        label_to_idx: Mapping from tag to column index.
        sep: Separator forwarded to ``process_tags`` for string entries.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor with ``len(tag_lists)`` rows and ``len(label_to_idx)`` columns.
    """
    # Materialize first: the input may be a one-shot iterable and the row
    # count is needed before filling.
    samples = list(tag_lists)
    targets = torch.zeros((len(samples), len(label_to_idx)), dtype=dtype)

    for row, raw_tags in enumerate(samples):
        for tag in process_tags(raw_tags, sep=sep):
            col = label_to_idx.get(tag)
            if col is None:
                # Unknown tag: leave the row untouched.
                continue
            targets[row, col] = 1.0

    return targets
|