"""Label/tag processing helpers for multi-label classification."""

from __future__ import annotations

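import math
from typing import Dict, Iterable, List, Sequence, Union

import torch


def process_tags(tags: Union[str, Sequence[str], None], sep: str = ",") -> List[str]:
    """Convert raw tags to a list of normalized tag strings.

    `tags` may be a single string of tags separated by `sep`, a sequence of
    tags, or `None`; a float NaN (how pandas reports a missing value) is
    treated like `None`. Each tag is stripped of surrounding whitespace and
    empty entries are dropped. Order and duplicates are preserved.
    """
    if tags is None:
        return []
    # Missing values read from a pandas column arrive as float NaN, which is
    # not iterable, so treat them as "no tags".
    if isinstance(tags, float) and math.isnan(tags):
        return []
    if isinstance(tags, str):
        parts = [t.strip() for t in tags.split(sep)]
        return [p for p in parts if p]
    # Otherwise `tags` is a sequence of tags; skip None entries and coerce
    # the rest to str.
    out: List[str] = []
    for t in tags:
        if t is None:
            continue
        s = str(t).strip()
        if s:
            out.append(s)
    return out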


def build_label_mapping(
    df,
    *,
    tags_col: str = "tags",
    sep: str = ",",
) -> Dict[str, int]:
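    """Build a tag->index mapping from a dataframe-like object.

    Expects `df[tags_col]` to support `.tolist()` and to contain strings of
    tags separated by `sep` (a comma by default), sequences of tags, or
    `None`. Indices are assigned in sorted tag order, so the mapping is
    deterministic for a given set of tags.
    """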
    tag_set = set()
    for raw in df[tags_col].tolist():
        tag_set.update(process_tags(raw, sep=sep))
    return {tag: i for i, tag in enumerate(sorted(tag_set))}


def create_target_encoding(
    tag_lists: Iterable[Union[str, Sequence[str], None]],
    label_to_idx: Dict[str, int],
    *,
    sep: str = ",",
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
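    """Create a multi-hot target tensor of shape [N, num_labels].

    Row `i` has a 1 in column `label_to_idx[tag]` for every tag of
    `tag_lists[i]`; tags absent from `label_to_idx` are silently ignored.
    Assumes the indices in `label_to_idx` are contiguous from 0 to
    `num_labels - 1`, as produced by `build_label_mapping`.
    """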
    tag_lists = list(tag_lists)
    num_labels = len(label_to_idx)
    y = torch.zeros((len(tag_lists), num_labels), dtype=dtype)
    for i, raw in enumerate(tag_lists):
        for tag in process_tags(raw, sep=sep):
            j = label_to_idx.get(tag)
            if j is not None:
                y[i, j] = 1.0
    return y
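

# --- Illustrative sketch (hypothetical helper, not part of the original API) ---
# A possible inverse of `create_target_encoding`: it maps each row of a
# multi-hot (or predicted-probability) tensor back to tag strings, which makes
# the column layout of the encoding concrete. The name `decode_targets` and
# the 0.5 default threshold are assumptions. Like `create_target_encoding`,
# it assumes `label_to_idx` holds contiguous indices 0..num_labels-1.
def decode_targets(
    y: torch.Tensor,
    label_to_idx: Dict[str, int],
    *,
    threshold: float = 0.5,
) -> List[List[str]]:
    """Convert a [N, num_labels] tensor back to per-row lists of tags."""
    if y.dim() != 2 or y.size(1) != len(label_to_idx):
        raise ValueError("expected a tensor of shape [N, len(label_to_idx)]")
    # Invert the mapping so that column j -> tag string.
    idx_to_label = {j: tag for tag, j in label_to_idx.items()}
    out: List[List[str]] = []
    # A value >= threshold marks the tag as present (assumed convention).
    for row in (y >= threshold).tolist():
        # Columns are visited in index order, i.e. sorted tag order for
        # mappings built by `build_label_mapping`.
        out.append([idx_to_label[j] for j, on in enumerate(row) if on])
    return out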