#!/usr/bin/env python3
"""
Generate coarse-grained annotations by merging consecutive fine-grained segments
into composite actions (~5-15 s each, matching the merge prompt) using an LLM.

Input:  annotations_v2/     (fine-grained, ~2-3 s segments, 11 classes)
Output: annotations_coarse/ (coarse-grained, ~5-15 s segments, 6 classes)

Does NOT modify annotations_v2/.
"""

import os
import json
import re
import time
import glob
import urllib.request
from collections import Counter

# The dataset root is resolved from the PULSE_ROOT environment variable at runtime.
INPUT_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_v2")
OUTPUT_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_coarse")
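# Usage sketch (illustrative; the script file name here is hypothetical):
#   export PULSE_ROOT=/path/to/pulse_dataset
#   python generate_coarse_annotations.py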

API_URL = "https://api.chatanywhere.tech/v1/chat/completions"
API_KEYS = [
    "sk-MN5n1uEETyaky96fLJdHqZobXF1f7KmOrZHzwD3lt585asFQ",
    "sk-YnYrtPdAXwlE12hRpi6dYqlE1RRVR3LDVBka6wKaefU4iQRY",
    "sk-jOZtodDv6OxUOMu3NuJ8lzffjwBlshn9OHY5KSmqmPTtc9qs",
    "sk-qAaKTKYIRF24btu1oQWgubWG4UdA92bILNtzOkHNEPAcCxdB",
    "sk-MgCBBonblMrCFnSXd6fJZaBLTCfCJ5FjYZfSe2e46bgmyktk",
    "sk-79e30kYRgduuf2fSU0Lsc814YjNkClXXzQqIbx0iLS40IOEH",
    "sk-h9Tej4tW6AQC6fT0njfzrPKXEk6fBwpiSvvQd0aJAhw4UwLz",
    "sk-k2QNHt5wAH26Fw8hZuPWuVXw8Psd1jX09qusiA6PdBj5Vzuu",
    "sk-w7EkTblciNI44cwosHXi0PGZNUf1hnJmpzOQ85va9VPdAKbz",
    "sk-Dexs5ZF7OjFCq7CZW45wJ8EKoGtIswv6rsLUMzUXXkWBDBBJ",
]

# Scenario descriptions are kept in Chinese because the LLM prompt below is in Chinese;
# English glosses are given as comments.
SCENE_DESCRIPTIONS = {
    "s1": "办公桌面整理与工作准备",    # tidying an office desk and preparing for work
    "s2": "快递打包发送",              # packing and sending a parcel
    "s3": "厨房调料整理",              # organizing kitchen condiments
    "s4": "清理餐后桌面",              # clearing the table after a meal
    "s5": "餐前桌面布置",              # setting the table before a meal
    "s6": "商务旅行行李箱打包",        # packing a suitcase for a business trip
    "s7": "冲泡咖啡/饮品",             # brewing coffee / drinks
    "s8": "晾衣架整理与衣物收纳",      # organizing a drying rack and putting away clothes
}

# English gloss of the six coarse categories defined (in Chinese) below:
#   Manipulate    - complete pick-up→operate→put-down interaction with one object
#   CleanOrganize - sustained tidying/cleaning (wiping, cable management, folding clothes)
#   Transfer      - carrying an object from one location to another
#   Assemble      - assembling/connecting/packaging (sealing, taping, capping, plugging in)
#   FoodPrep      - food/drink preparation (pouring, adding seasoning, stirring, brewing)
#   Idle          - idle/transition intervals with no clear manipulation
COARSE_CATEGORIES = """粗粒度动作类别(共6类):

1. Manipulate - 操作物体(抓取、调整、放置某个物体的完整过程,包含拿起→操作→放下的组合)
2. CleanOrganize - 清洁/整理(擦桌子、理线、整理桌面、叠衣服等持续性整理活动)
3. Transfer - 搬运/传递(将物体从一个位置搬到另一个位置的过程)
4. Assemble - 组装/连接/包装(封箱、贴胶带、盖盖子、插电源、拧瓶盖等需要精细对准的操作)
5. FoodPrep - 食物/饮品准备(倒水、倒调料、搅拌、冲泡等与食物饮品相关的操作)
6. Idle - 空闲/过渡(无明确操作的间隔)
"""

current_key_idx = 0
call_count = 0


def call_llm(prompt, max_tokens=1500, retries=3):
    global current_key_idx, call_count
    for attempt in range(retries * len(API_KEYS)):
        key = API_KEYS[current_key_idx]
        try:
            data = json.dumps({
                "model": "gpt-4o-mini",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
                "temperature": 0.1,
            }).encode()
            req = urllib.request.Request(
                API_URL, data=data,
                headers={"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
            )
            with urllib.request.urlopen(req, timeout=30) as resp:
                result = json.loads(resp.read())
            call_count += 1
            return result["choices"][0]["message"]["content"]
        except Exception as e:
            err = str(e)
            if any(k in err for k in ["429", "quota", "limit", "402", "403"]):
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
            else:
                time.sleep(0.5)
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
    return None


def parse_ts(ts_str):
    """Parse 'MM:SS' to seconds."""
    m = re.match(r'(\d+):(\d+)', ts_str.strip())
    if m:
        return int(m.group(1)) * 60 + int(m.group(2))
    return 0


def format_ts(sec):
    """Format seconds to 'MM:SS'."""
    return f"{sec//60:02d}:{sec%60:02d}"


def merge_segments_with_llm(segments, scene_id):
    """Use LLM to merge fine-grained segments into coarse composite actions."""
    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")  # default: "daily activities"

    # Build segment list
    seg_lines = []
    for i, seg in enumerate(segments):
        label = seg.get("action_label", "Idle")
        seg_lines.append(f"{i+1}. [{seg['timestamp']}] {label}: {seg['task']}")
    seg_text = "\n".join(seg_lines)

    # Prompt (in Chinese), roughly: "You are an action-annotation expert. Below is a
    # fine-grained action sequence (2-3 s each) from a '<scene>' recording. Merge
    # related consecutive actions into coarse composite actions lasting 5-15 s:
    # merge consecutive operations on the same object and consecutive tidying
    # actions, absorb short Idle gaps (≤3 s), choose exactly one of the 6 categories
    # per composite action, and return strict JSON with no extra text."
    prompt = f"""你是一个动作标注专家。以下是一段"{scene_desc}"录制中的细粒度动作序列(每个2-3秒)。
请将相关的连续动作合并为粗粒度复合动作,每个复合动作持续5-15秒。

合并规则:
- 围绕同一个物体的连续操作合并为一个(如"抓取杯子→调整→放下"合并为一个Manipulate)
- 连续的整理/清洁动作合并
- 合并后的时间范围 = 第一个子动作的开始时间 到 最后一个子动作的结束时间
- 如果中间有短暂Idle(≤3秒),可以包含进去
- 每个复合动作必须从6个类别中选一个

{COARSE_CATEGORIES}

细粒度动作序列:
{seg_text}

请严格按以下JSON格式返回,不要添加任何额外文字:
[{{"timestamp": "MM:SS-MM:SS", "coarse_action": "类别名", "description": "简要描述这段复合动作", "fine_segments": [子动作编号列表]}}]"""

    response = call_llm(prompt, max_tokens=2000)
    if response is None:
        return None

    try:
        match = re.search(r'\[.*\]', response, re.DOTALL)
        if match:
            results = json.loads(match.group())
            valid = []
            for r in results:
                if all(k in r for k in ["timestamp", "coarse_action", "description"]):
                    # Validate category
                    if r["coarse_action"] in {"Manipulate", "CleanOrganize", "Transfer",
                                               "Assemble", "FoodPrep", "Idle"}:
                        valid.append(r)
            return valid
    except (json.JSONDecodeError, KeyError) as e:
        print(f"  Parse error: {e}")
    return None


def process_file(input_path, vol, scenario):
    """Process one annotation file."""
    with open(input_path, encoding="utf-8") as f:
        data = json.load(f)
    segments = data["segments"]

    if not segments:
        return {"fine_segments": segments, "coarse_segments": []}, 0

    print(f"  Merging {len(segments)} fine segments...")
    coarse = merge_segments_with_llm(segments, scenario)

    if coarse is None:
        # Fallback: simple time-based merging without LLM
        print(f"  LLM failed, using fallback merge")
        coarse = fallback_merge(segments)

    result = {
        "fine_segments": segments,
        "coarse_segments": coarse,
    }
    return result, len(coarse)


def fallback_merge(segments):
    """Simple rule-based merging as fallback."""
    if not segments:
        return []

    coarse = []
    group = [segments[0]]
    group_start_idx = 1  # 1-based index (into segments) of the current group's first element

    for i, seg in enumerate(segments[1:], start=2):  # i = 1-based index of seg in segments
        # Parse timestamps
        prev_ts = group[-1]["timestamp"]
        curr_ts = seg["timestamp"]
        m1 = re.match(r'(\d+:\d+)\s*-\s*(\d+:\d+)', prev_ts)
        m2 = re.match(r'(\d+:\d+)\s*-\s*(\d+:\d+)', curr_ts)
        if not m1 or not m2:
            group.append(seg)
            continue

        prev_end = parse_ts(m1.group(2))
        curr_start = parse_ts(m2.group(1))
        gap = curr_start - prev_end

        # Merge if gap ≤ 3s and accumulated group duration ≤ 15s
        group_start = parse_ts(re.match(r'(\d+:\d+)', group[0]["timestamp"]).group(1))
        curr_end = parse_ts(m2.group(2))
        group_duration = curr_end - group_start

        if gap <= 3 and group_duration <= 15:
            group.append(seg)
        else:
            # Emit the current group and start a new one at this segment
            coarse.append(_emit_group(group, group_start_idx))
            group = [seg]
            group_start_idx = i

    if group:
        coarse.append(_emit_group(group, group_start_idx))

    return coarse


def _emit_group(group, start_idx=1):
    """Create a coarse segment from a group of consecutive fine segments.

    start_idx is the 1-based position of the group's first segment in the original
    sequence, so "fine_segments" uses the same numbering as the LLM-merged output.
    """
    m_start = re.match(r'(\d+:\d+)', group[0]["timestamp"])
    m_end = re.match(r'\d+:\d+\s*-\s*(\d+:\d+)', group[-1]["timestamp"])
    start = m_start.group(1) if m_start else "00:00"
    end = m_end.group(1) if m_end else "00:00"

    labels = [seg.get("action_label", "Idle") for seg in group]
    label_counts = Counter(labels)
    dominant = label_counts.most_common(1)[0][0]

    # Map fine label to coarse
    label_map = {
        "Grasp": "Manipulate", "Place": "Manipulate", "Arrange": "CleanOrganize",
        "Wipe": "CleanOrganize", "Fold": "CleanOrganize", "Transport": "Transfer",
        "OpenClose": "Assemble", "TearCut": "Assemble",
        "Pour": "FoodPrep", "Stir": "FoodPrep", "Idle": "Idle",
    }
    coarse_label = label_map.get(dominant, "Manipulate")

    tasks = [seg["task"] for seg in group]
    desc = tasks[0] if len(tasks) == 1 else f"{tasks[0]}...{tasks[-1]}"

    return {
        "timestamp": f"{start}-{end}",
        "coarse_action": coarse_label,
        "description": desc[:80],
        # Reference the original 1-based fine-segment indices, matching the LLM path.
        "fine_segments": list(range(start_idx, start_idx + len(group))),
    }


def main():
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    total_fine = 0
    total_coarse = 0
    total_files = 0
    coarse_labels = Counter()

    for vol_dir in sorted(glob.glob(f"{INPUT_DIR}/v*")):
        vol = os.path.basename(vol_dir)
        out_dir = os.path.join(OUTPUT_DIR, vol)
        os.makedirs(out_dir, exist_ok=True)

        for ann_file in sorted(glob.glob(f"{vol_dir}/s*.json")):
            scenario = os.path.basename(ann_file).replace(".json", "")
            print(f"[{vol}/{scenario}]", flush=True)

            result, n_coarse = process_file(ann_file, vol, scenario)

            out_path = os.path.join(out_dir, f"{scenario}.json")
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(result, f, ensure_ascii=False, indent=2)

            n_fine = len(result["fine_segments"])
            total_fine += n_fine
            total_coarse += n_coarse
            total_files += 1

            for seg in result["coarse_segments"]:
                coarse_labels[seg["coarse_action"]] += 1

            print(f"  {n_fine} fine → {n_coarse} coarse segments", flush=True)

    print(f"\n{'='*60}")
    print(f"Total: {total_files} files")
    print(f"  Fine segments:   {total_fine}")
    print(f"  Coarse segments: {total_coarse}")
    print(f"  Compression:     {total_fine/max(total_coarse,1):.1f}x")
    print(f"  API calls:       {call_count}")

    print(f"\n  Coarse label distribution:")
    for label, count in coarse_labels.most_common():
        print(f"    {label:<20} {count:>5} ({count/max(total_coarse,1)*100:.1f}%)")

    print(f"\n  Output: {OUTPUT_DIR}")


if __name__ == "__main__":
    main()