File size: 7,090 Bytes
b4b2877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""
Taxonomy for T10 Next-Action Triplet Prediction on DailyAct-5M.

Design decisions (fixed per user):
  * VERB_FINE:      17 primitives observed in annotations_v3 (Strategy: keep all)
  * VERB_COMPOSITE: 6 classes by manual rollup
  * NOUN:           keep nouns with >=50 segments (Strategy A: drop others entirely)
  * HAND:           3 classes {left, right, both}

The noun list is *frozen* in taxonomy_v3.json so class indices stay stable even
as more annotations are added. Regenerate with `build_taxonomy.py` when you are
ready to lock the final list.
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Dict, List, Optional

# ---------------------------------------------------------------------------
# Verb (fine, 17 classes)
# ---------------------------------------------------------------------------

# Fine-grained verb vocabulary. A verb's position in this list is its class
# index (see VERB_FINE_IDX below), so reordering entries changes the indices.
VERB_FINE: List[str] = [
    "grasp",
    "move",
    "place",
    "adjust",
    "pick_up",
    "hold",
    "pull",
    "put_down",
    "close",
    "release",
    "rotate",
    "open",
    "insert",
    "push",
    "align",
    "remove",
    "stabilize",
]
NUM_VERB_FINE = len(VERB_FINE)  # 17

# Reverse lookup: verb name -> class index.
VERB_FINE_IDX: Dict[str, int] = dict(zip(VERB_FINE, range(NUM_VERB_FINE)))


# ---------------------------------------------------------------------------
# Verb (composite, 6 classes) — manual rollup
# ---------------------------------------------------------------------------

# Manual rollup table: each composite class with the fine verbs it absorbs.
# Key order fixes the composite class indices; member order fixes the
# insertion order of the inverted lookup built below.
_COMPOSITE_MEMBERS: Dict[str, tuple] = {
    "grasp-family": ("grasp", "pick_up", "hold"),
    "place-family": ("place", "put_down"),
    "transport": ("move", "pull", "push"),
    "adjust": ("adjust", "align", "stabilize"),
    "state-change": ("open", "close", "rotate", "insert", "remove"),
    "release": ("release",),
}

VERB_COMPOSITE: List[str] = list(_COMPOSITE_MEMBERS)
NUM_VERB_COMPOSITE = len(VERB_COMPOSITE)  # 6

# Reverse lookup: composite class name -> class index.
VERB_COMPOSITE_IDX: Dict[str, int] = dict(
    zip(VERB_COMPOSITE, range(NUM_VERB_COMPOSITE))
)

# Inverted view of the rollup: fine verb name -> composite class name.
_FINE_TO_COMPOSITE: Dict[str, str] = {
    fine: family
    for family, members in _COMPOSITE_MEMBERS.items()
    for fine in members
}

# Import-time sanity check: the rollup and VERB_FINE must name the same verbs.
assert set(_FINE_TO_COMPOSITE) == set(VERB_FINE), (
    "Verb rollup must cover every fine verb"
)


def verb_fine_to_composite_idx(verb_fine: str) -> int:
    """Return the composite class index for a fine verb name.

    The verb is first rolled up to its composite class name, which is then
    converted to that class's position in ``VERB_COMPOSITE``.

    Raises:
        KeyError: if ``verb_fine`` is not a key of the rollup table.
    """
    return VERB_COMPOSITE_IDX[_FINE_TO_COMPOSITE[verb_fine]]


# ---------------------------------------------------------------------------
# Hand (3 classes)
# ---------------------------------------------------------------------------

# Hand labels; list position is the class index.
HAND: List[str] = [
    "left",
    "right",
    "both",
]
NUM_HAND = len(HAND)

# Reverse lookup: hand label -> class index.
HAND_IDX: Dict[str, int] = dict(zip(HAND, range(NUM_HAND)))


# ---------------------------------------------------------------------------
# Noun — canonical merge table (handles mild annotator inconsistency)
# ---------------------------------------------------------------------------

# Raw annotation noun -> canonical noun. Nouns not listed here are already
# considered canonical and pass through unchanged.
NOUN_CANONICAL: Dict[str, str] = {
    # Chinese-language label that leaked into the annotations.
    "折叠雨伞": "folding umbrella",
    # Short alias merged into the more specific class name.
    "mouse": "wired mouse",
}


def canonical_noun(n: str) -> str:
    """Return the canonical spelling of a raw noun.

    Nouns listed in ``NOUN_CANONICAL`` (aliases and the leaked CJK label) are
    replaced by their canonical name; any other noun is returned unchanged.
    """
    return NOUN_CANONICAL[n] if n in NOUN_CANONICAL else n


# ---------------------------------------------------------------------------
# Noun list — frozen per-release, loaded from JSON for reproducibility
# ---------------------------------------------------------------------------

# The frozen taxonomy JSON sits in the same directory as this module.
TAXONOMY_FROZEN_PATH = Path(__file__).with_name("taxonomy_v3.json")

# Segment-count cutoff quoted in summary(); the module docstring describes it
# as the minimum number of segments a noun needs in order to be kept.
NOUN_KEEP_THRESHOLD = 50


def _load_frozen() -> Optional[dict]:
    """Load the frozen taxonomy JSON, if it exists.

    Returns:
        The parsed JSON content of ``TAXONOMY_FROZEN_PATH``, or ``None`` when
        that file does not exist.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    try:
        # Decode explicitly as UTF-8 rather than the locale default: noun
        # names may be non-ASCII (NOUN_CANONICAL already has a CJK key), and a
        # locale-dependent decode would make the class list platform-specific.
        handle = open(TAXONOMY_FROZEN_PATH, encoding="utf-8")
    except FileNotFoundError:
        # No frozen file yet; the caller decides what to fall back to.
        return None
    with handle:
        return json.load(handle)


_frozen = _load_frozen()

if _frozen is None:
    # No frozen file on disk: fall back to the bootstrap list from the initial
    # 167-file scan (Apr 24). Overwritten when build_taxonomy.py is run
    # against the final 283-file set.
    NOUN: List[str] = [
        "towel",
        "sealed jar",
        "box",
        "tablecloth",
        "pot",
        "tape",
        "rice bowl",
        "pants",
        "spoon",
        "marker",
        "cloth",
        "plate",
        "laptop",
        "toothbrush case",
        "tea canister",
        "hanger",
        "wired keyboard",
        "wired mouse",
        "laptop power adapter",
        "seasoning bottle",
        "mug",
        "seasoning jar",
        "tray",
        "document",
        "coat",
        "tea bag",
        "water cup",
        "shirt",
    ]
    FROZEN_ANNOTATION_COUNT: int = 167
    FROZEN_SEGMENT_COUNT: int = 4140
else:
    # The frozen file is authoritative; the two counts are optional metadata
    # and default to -1 when the file does not record them.
    NOUN = list(_frozen["nouns"])
    FROZEN_ANNOTATION_COUNT = _frozen.get("annotation_file_count", -1)
    FROZEN_SEGMENT_COUNT = _frozen.get("total_segments", -1)

NUM_NOUN = len(NOUN)

# Reverse lookup: canonical noun -> class index (its position in NOUN).
NOUN_IDX: Dict[str, int] = dict(zip(NOUN, range(NUM_NOUN)))


def noun_to_idx(raw_noun: str) -> Optional[int]:
    """Return the class index for a raw noun, or None if it is not kept.

    The noun is canonicalised first, so aliases resolve to the same index as
    their canonical name. A ``None`` result means the noun is absent from the
    kept list and should be dropped (Strategy A).
    """
    return NOUN_IDX.get(canonical_noun(raw_noun))


# ---------------------------------------------------------------------------
# One-shot classify
# ---------------------------------------------------------------------------

def classify_segment(action_annotation: dict) -> Optional[dict]:
    """Translate one raw annotation into its triplet class indices.

    Reads the ``action_name``, ``object_name`` and ``hand_type`` keys of
    ``action_annotation``.

    Returns:
        A dict with integer entries ``verb_fine``, ``verb_composite``,
        ``noun`` and ``hand``, or ``None`` when the segment must be dropped:
        a field is missing or empty, the verb or hand label is unknown, or
        the noun is not in the kept list (Strategy A).
    """
    raw_verb = action_annotation.get("action_name")
    raw_noun = action_annotation.get("object_name")
    raw_hand = action_annotation.get("hand_type")

    # All three fields must be present and non-empty.
    if not raw_verb or not raw_noun or not raw_hand:
        return None

    # Verb and hand must belong to their closed vocabularies.
    if raw_verb not in VERB_FINE_IDX or raw_hand not in HAND_IDX:
        return None

    noun_class = noun_to_idx(raw_noun)
    if noun_class is None:
        return None

    labels = {"verb_fine": VERB_FINE_IDX[raw_verb]}
    labels["verb_composite"] = verb_fine_to_composite_idx(raw_verb)
    labels["noun"] = noun_class
    labels["hand"] = HAND_IDX[raw_hand]
    return labels


# ---------------------------------------------------------------------------
# Summary for logging / sanity
# ---------------------------------------------------------------------------

def summary() -> str:
    """Build a short multi-line report of the taxonomy sizes.

    Returns:
        Five newline-joined lines: the number of fine verbs, composite verbs,
        nouns (with the keep threshold) and hands, followed by the file and
        segment counts the noun list was frozen from.
    """
    provenance = (
        f"Frozen from    : {FROZEN_ANNOTATION_COUNT} files, "
        f"{FROZEN_SEGMENT_COUNT} segments"
    )
    report = (
        f"Verb fine      : {NUM_VERB_FINE}",
        f"Verb composite : {NUM_VERB_COMPOSITE}",
        f"Noun           : {NUM_NOUN}  (kept at >= {NOUN_KEEP_THRESHOLD} segments)",
        f"Hand           : {NUM_HAND}",
        provenance,
    )
    return "\n".join(report)


if __name__ == "__main__":
    # Print the size summary, a blank separator line, then every vocabulary.
    print(summary(), end="\n\n")
    for label, vocabulary in (
        ("Verb fine list:", VERB_FINE),
        ("Composite:    ", VERB_COMPOSITE),
        ("Noun list:    ", NOUN),
        ("Hand list:    ", HAND),
    ):
        print(label, vocabulary)