File size: 13,950 Bytes
b4b2877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
#!/usr/bin/env python3
"""
Re-annotate action segments using LLM (GPT-4o-mini).
1. Re-classify existing segments with better accuracy
2. Infer actions in unlabeled gaps based on context (scene, surrounding actions)
3. Output improved annotations with higher coverage
"""

import os
import sys
import json
import re
import socket
import time
import copy
import glob
import urllib.error
import urllib.request
from collections import Counter

# Data locations rooted at $PULSE_ROOT. Python does not expand shell-style
# "${VAR}" placeholders inside string literals, so expand them explicitly from
# the environment. When PULSE_ROOT is unset, os.path.expandvars leaves the
# placeholder unchanged (identical to the literal strings used previously).
ANN_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_by_scene")
OUTPUT_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_v2")
DATASET_DIR = os.path.expandvars("${PULSE_ROOT}/dataset")

# OpenAI-compatible chat-completions endpoint. Override with the
# CHATANYWHERE_API_URL environment variable.
API_URL = os.environ.get(
    "CHATANYWHERE_API_URL",
    "https://api.chatanywhere.tech/v1/chat/completions",
)

# SECURITY NOTE(review): API secrets must not live in source control. The keys
# below have been committed and should be treated as leaked (revoke/rotate
# them). Supply keys through CHATANYWHERE_API_KEYS (comma-separated) instead;
# the literal list is kept only as a fallback so existing runs keep working.
_FALLBACK_API_KEYS = [
    "sk-MN5n1uEETyaky96fLJdHqZobXF1f7KmOrZHzwD3lt585asFQ",
    "sk-YnYrtPdAXwlE12hRpi6dYqlE1RRVR3LDVBka6wKaefU4iQRY",
    "sk-jOZtodDv6OxUOMu3NuJ8lzffjwBlshn9OHY5KSmqmPTtc9qs",
    "sk-qAaKTKYIRF24btu1oQWgubWG4UdA92bILNtzOkHNEPAcCxdB",
    "sk-MgCBBonblMrCFnSXd6fJZaBLTCfCJ5FjYZfSe2e46bgmyktk",
    "sk-79e30kYRgduuf2fSU0Lsc814YjNkClXXzQqIbx0iLS40IOEH",
    "sk-h9Tej4tW6AQC6fT0njfzrPKXEk6fBwpiSvvQd0aJAhw4UwLz",
    "sk-k2QNHt5wAH26Fw8hZuPWuVXw8Psd1jX09qusiA6PdBj5Vzuu",
    "sk-w7EkTblciNI44cwosHXi0PGZNUf1hnJmpzOQ85va9VPdAKbz",
    "sk-Dexs5ZF7OjFCq7CZW45wJ8EKoGtIswv6rsLUMzUXXkWBDBBJ",
]

# Keys rotated through by call_llm(); environment-provided keys win.
API_KEYS = [
    key.strip()
    for key in os.environ.get("CHATANYWHERE_API_KEYS", "").split(",")
    if key.strip()
] or _FALLBACK_API_KEYS

# Chinese description of each recording scene, keyed by scene id. The text is
# interpolated verbatim into the LLM prompts as scene context, so it must stay
# in Chinese and byte-for-byte stable. English glosses are given per entry.
SCENE_DESCRIPTIONS = {
    # s1: tidying an office desk and preparing to work (documents, power cords, mouse, laptop)
    "s1": "办公桌面整理与工作准备(整理文件、电源线、鼠标、笔记本电脑等)",
    # s2: packing and shipping a parcel (fold box, insert items, seal, attach label)
    "s2": "快递打包发送(折叠纸箱、放入物品、封箱、贴标签等)",
    # s3: organizing kitchen seasonings (pick up bottles, pour, twist caps, wipe)
    "s3": "厨房调料整理(拿取调料瓶、倒调料、拧瓶盖、擦拭等)",
    # s4: clearing the table after a meal (collect dishes, wipe table, sort utensils, dump scraps)
    "s4": "清理餐后桌面(收碗碟、擦桌子、整理餐具、倒残渣等)",
    # s5: setting the table before a meal (tablecloth, utensils, dishes, cups)
    "s5": "餐前桌面布置(铺桌布、摆放餐具碗碟、放杯子等)",
    # s6: packing a suitcase for a business trip (fold clothes, pack, arrange items)
    "s6": "商务旅行行李箱打包(折叠衣物、放入行李箱、整理物品等)",
    # s7: brewing coffee / a drink (cup, coffee powder or tea bag, hot water, stirring)
    "s7": "冲泡咖啡/饮品(取杯子、放咖啡粉/茶包、倒热水、搅拌等)",
    # s8: organizing a drying rack and putting clothes away (hangers, hanging, folding)
    "s8": "晾衣架整理与衣物收纳(取衣架、挂衣服、折叠衣物等)",
}

# Prompt fragment (in Chinese) defining the 11 action categories the LLM must
# choose from -- Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut,
# Arrange, Transport, Idle -- each with a short definition and examples.
# The trailing notes give tie-breaking rules:
#   - label Idle only when there is truly no hand manipulation;
#   - "adjusting pose" / "inspecting an object" count as Arrange;
#   - "inserting" / "loading into" count as Place;
#   - "lift and move" is Grasp over a short distance, Transport over a long one.
# The English category names must stay in sync with the allowed-label lists
# used by the prompts and validation code in this file.
ACTION_CATEGORIES = """动作类别定义(共11类):

1. Grasp - 抓取/拿起物体(手从无接触到接触并握住物体)
2. Place - 放置/放下物体(将物体放到某个位置并释放)
3. Pour - 倾倒/注入液体或颗粒(倒水、倒调料、倒咖啡粉等)
4. Wipe - 擦拭/清洁表面(用抹布或手擦桌面、瓶身等)
5. Fold - 折叠/卷起(折衣服、折桌布、折纸箱等)
6. OpenClose - 打开/关闭/旋开/旋紧(开盒子、拧瓶盖、拉拉链、合箱盖等)
7. Stir - 搅拌(搅拌咖啡、搅拌饮品等)
8. TearCut - 撕/剪/粘贴(撕胶带、剪快递单、贴标签等)
9. Arrange - 整理/摆放/调整位置(摆餐具、整理文件、调整物品位置、理线等)
10. Transport - 搬运/移动物体到较远位置(把包裹搬到架子、把碗端到水槽等)
11. Idle - 空闲/过渡/无明确操作(双手无目的性动作、等待、观察等)

注意:
- 只有真正没有任何手部操作时才标Idle
- "调整姿态"、"检查物体"等属于Arrange
- "插入"、"装入"等属于Place
- "提起并移动"如果距离短属于Grasp,距离远属于Transport
"""

# Mutable module state owned by call_llm(): the index into API_KEYS of the key
# currently in use, and the running count of successful API calls (reported
# in main()'s final summary).
current_key_idx, call_count = 0, 0


def _classify_api_error(exc):
    """Classify an exception raised while calling the chat-completions API.

    Returns:
        "exhausted" -- key is rate-limited or out of quota (HTTP 402/429, or a
                       message mentioning quota/limit): rotate to the next key.
        "timeout"   -- the request timed out: retry with the same key.
        "other"     -- anything else (other HTTP errors, bad JSON, ...).
    """
    if isinstance(exc, urllib.error.HTTPError) and exc.code in (402, 429):
        return "exhausted"
    # Read timeouts raise socket.timeout directly; connect timeouts arrive
    # wrapped in URLError. Their message is "timed out", which never contains
    # the substring "timeout", so a string test alone misses them.
    if isinstance(exc, socket.timeout):
        return "timeout"
    if isinstance(exc, urllib.error.URLError) and isinstance(exc.reason, socket.timeout):
        return "timeout"
    err = str(exc)
    if "429" in err or "quota" in err or "limit" in err or "402" in err:
        return "exhausted"
    lowered = err.lower()
    if "timeout" in lowered or "timed out" in lowered:
        return "timeout"
    return "other"


def call_llm(prompt, max_tokens=1000, retries=3):
    """Send a single-turn prompt to gpt-4o-mini and return the reply text.

    API keys are used round-robin via the module-level ``current_key_idx``:
    a rate-limited/exhausted key, or any non-timeout failure, advances to the
    next key; a timeout retries the same key after a short pause. At most
    ``retries * len(API_KEYS)`` attempts are made.

    Args:
        prompt: user message content.
        max_tokens: completion token budget passed to the API.
        retries: number of passes over the key list before giving up.

    Returns:
        The assistant message content, or None if every attempt failed.
    """
    global current_key_idx, call_count

    # The request body does not depend on the key, so build it once.
    payload = json.dumps({
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.1,
    }).encode()

    for _attempt in range(retries * len(API_KEYS)):
        key = API_KEYS[current_key_idx]
        req = urllib.request.Request(
            API_URL, data=payload,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            }
        )
        try:
            # `with` guarantees the HTTP response (and its socket) is closed.
            with urllib.request.urlopen(req, timeout=30) as resp:
                result = json.loads(resp.read())
            content = result["choices"][0]["message"]["content"]
        except Exception as e:  # classified below; every failure is retried
            kind = _classify_api_error(e)
            if kind == "exhausted":
                print(f"  Key {current_key_idx+1} exhausted, rotating...")
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
                if current_key_idx == 0:
                    # Wrapped around the whole key list: pause so transient
                    # per-minute rate limits can reset instead of burning all
                    # remaining attempts in a tight loop.
                    time.sleep(2)
            elif kind == "timeout":
                time.sleep(1)
            else:
                print(f"  API error: {str(e)[:100]}")
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
                time.sleep(0.5)
            continue

        call_count += 1
        return content

    print("  WARNING: All API keys failed!")
    return None


def _parse_classification_response(response):
    """Parse the model's JSON array into {int id: action}; None on failure."""
    match = re.search(r'\[.*\]', response, re.DOTALL)
    if not match:
        print(f"  Parse error: no JSON array, response: {response[:200]}")
        return None
    try:
        results = json.loads(match.group())
        # Models sometimes emit ids as strings ("1"); normalize to int so the
        # caller's 1-based integer lookup matches.
        return {int(r["id"]): r["action"] for r in results}
    except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
        print(f"  Parse error: {e}, response: {response[:200]}")
        return None


def reclassify_segments(segments, scene_id, batch_size=200):
    """Ask the LLM to assign one of the 11 action categories to each segment.

    Segments are sent in batches of ``batch_size`` so the requested completion
    size (~40 tokens per segment) stays well below the model's output limit
    for long recordings. Within each prompt, segments are numbered from 1.

    Args:
        segments: list of segment dicts, each with "timestamp" and "task".
        scene_id: scene key (e.g. "s1") looked up in SCENE_DESCRIPTIONS.
        batch_size: maximum segments per LLM request.

    Returns:
        dict mapping the 1-based index of a segment in ``segments`` to the
        action name returned by the model (not validated here), or None if no
        batch produced a usable answer.
    """
    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")
    step = batch_size if batch_size and batch_size > 0 else max(1, len(segments))

    labels = {}
    any_success = False
    for offset in range(0, len(segments), step):
        batch = segments[offset:offset + step]
        seg_text = "\n".join(
            f"{i}. [{seg['timestamp']}] {seg['task']}"
            for i, seg in enumerate(batch, start=1)
        )

        prompt = f"""你是一个人体动作标注专家。请为以下每个动作片段分配一个动作类别。

场景:{scene_desc}

{ACTION_CATEGORIES}

动作片段列表:
{seg_text}

请严格按以下JSON格式返回,不要添加任何额外文字:
[{{"id": 1, "action": "类别名"}}, {{"id": 2, "action": "类别名"}}, ...]

每个action必须是以下之一:Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut, Arrange, Transport, Idle"""

        # Floor keeps very short lists from being truncated by JSON overhead.
        response = call_llm(prompt, max_tokens=max(200, len(batch) * 40))
        if response is None:
            # call_llm already exhausted every key; further batches would fail too.
            break

        parsed = _parse_classification_response(response)
        if parsed is None:
            continue
        any_success = True
        for local_id, action in parsed.items():
            # Ignore ids the model invented outside this batch.
            if 1 <= local_id <= len(batch):
                labels[offset + local_id] = action

    return labels if any_success else None


def infer_gap_actions(scene_id, before_seg, after_seg, gap_start, gap_end):
    """Ask the LLM which actions most likely happened in an unlabeled gap.

    Args:
        scene_id: scene key (e.g. "s1") looked up in SCENE_DESCRIPTIONS.
        before_seg: segment dict preceding the gap, or None at recording start.
        after_seg: segment dict following the gap, or None at recording end.
        gap_start, gap_end: gap bounds in whole seconds (ints; they are
            formatted with ``:02d`` in the prompt).

    Returns:
        List of dicts with "timestamp" ("MM:SS-MM:SS"), "task" and "action"
        keys. Only items whose span lies inside [gap_start, gap_end], whose
        action is one of the 11 categories (matched case-insensitively and
        normalized to the canonical spelling) and which do not overlap an
        earlier kept item are returned. Empty list on any failure.
    """
    canonical_actions = {
        name.lower(): name
        for name in ("Grasp", "Place", "Pour", "Wipe", "Fold", "OpenClose",
                     "Stir", "TearCut", "Arrange", "Transport", "Idle")
    }

    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")
    gap_duration = gap_end - gap_start

    before_text = f"[{before_seg['timestamp']}] {before_seg['task']}" if before_seg else "(录制开始)"
    after_text = f"[{after_seg['timestamp']}] {after_seg['task']}" if after_seg else "(录制结束)"

    prompt = f"""你是一个人体动作标注专家。在一段日常活动录制中,有一段时间没有被标注。请根据场景和前后动作推断这段时间内最可能发生的动作。

场景:{scene_desc}
未标注时间段:{gap_start//60:02d}:{gap_start%60:02d} - {gap_end//60:02d}:{gap_end%60:02d}(共{gap_duration}秒)
前一个标注动作:{before_text}
后一个标注动作:{after_text}

{ACTION_CATEGORIES}

请推断这段时间内可能发生的动作序列。每个动作段落2-4秒,时间用MM:SS格式。
如果确实是空闲等待,标注为Idle。

严格按以下JSON格式返回,不要添加任何额外文字:
[{{"timestamp": "MM:SS-MM:SS", "task": "动作描述", "action": "类别名"}}]

每个action必须是以下之一:Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut, Arrange, Transport, Idle"""

    # The model is asked for 2-4 s items; a fixed 500-token budget truncates
    # the JSON for long gaps, and a truncated array is discarded entirely.
    max_tokens = min(4000, max(500, int(gap_duration) * 20))
    response = call_llm(prompt, max_tokens=max_tokens)
    if response is None:
        return []

    match = re.search(r'\[.*\]', response, re.DOTALL)
    if not match:
        return []
    try:
        results = json.loads(match.group())
    except json.JSONDecodeError as e:
        print(f"  Parse error: {e}")
        return []
    if not isinstance(results, list):
        return []

    candidates = []
    for r in results:
        if not isinstance(r, dict):
            continue
        ts, task, action = r.get("timestamp"), r.get("task"), r.get("action")
        if not (isinstance(ts, str) and isinstance(task, str) and isinstance(action, str)):
            continue
        canonical = canonical_actions.get(action.strip().lower())
        if canonical is None:
            continue  # label outside the 11 categories
        ts_match = re.match(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)', ts)
        if not ts_match:
            continue
        s = int(ts_match.group(1)) * 60 + int(ts_match.group(2))
        e = int(ts_match.group(3)) * 60 + int(ts_match.group(4))
        if gap_start <= s < e <= gap_end:
            candidates.append((s, e, dict(r, action=canonical)))

    # Keep a non-overlapping sequence: earliest start first, drop any item
    # that begins before the previously kept one ends.
    candidates.sort(key=lambda c: (c[0], c[1]))
    valid = []
    last_end = gap_start
    for s, e, item in candidates:
        if s >= last_end:
            valid.append(item)
            last_end = e
    return valid


def get_recording_duration(vol, scenario, frame_rate=100.0):
    """Return the aligned recording length in seconds, or None if unknown.

    Reads ``DATASET_DIR/<vol>/<scenario>/alignment_metadata.json``. The
    "aligned_length_sec" field is preferred; otherwise "aligned_length_frames"
    is divided by ``frame_rate``. The default of 100.0 is the value previously
    hard-coded here; presumably it is the aligned sampling rate (confirm
    against the dataset docs).

    Returns None when the file is missing, unreadable or malformed, or has
    neither field. A warning is printed for unreadable or malformed files
    instead of aborting the whole run.
    """
    meta_path = os.path.join(DATASET_DIR, vol, scenario, "alignment_metadata.json")
    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale.
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        print(f"  WARNING: cannot read {meta_path}: {e}")
        return None

    if not isinstance(meta, dict):
        return None
    if "aligned_length_sec" in meta:
        return meta["aligned_length_sec"]
    if "aligned_length_frames" in meta:
        return meta["aligned_length_frames"] / frame_rate
    return None


# The 11 action categories accepted as output labels.
_VALID_ACTIONS = frozenset({
    "Grasp", "Place", "Pour", "Wipe", "Fold", "OpenClose",
    "Stir", "TearCut", "Arrange", "Transport", "Idle",
})

_SPAN_RE = re.compile(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)')


def _parse_span(timestamp):
    """Return (start_s, end_s) for a "MM:SS-MM:SS" string, or None if unparseable."""
    if not isinstance(timestamp, str):
        return None
    m = _SPAN_RE.match(timestamp)
    if not m:
        return None
    start_min, start_sec, end_min, end_sec = (int(g) for g in m.groups())
    return start_min * 60 + start_sec, end_min * 60 + end_sec


def _segment_start(seg):
    """Sort key: start second of the segment's timestamp, 0 if missing/unparseable."""
    ts = seg.get("timestamp")
    m = re.match(r'(\d+):(\d+)', ts) if isinstance(ts, str) else None
    return int(m.group(1)) * 60 + int(m.group(2)) if m else 0


def _make_inferred_segment(inf):
    """Wrap one LLM-inferred gap action as an output segment dict."""
    action = inf.get("action")
    if not (isinstance(action, str) and action in _VALID_ACTIONS):
        # Same fallback as for existing segments with an unknown label.
        action = "Idle"
    return {
        "timestamp": inf["timestamp"],
        "task": inf["task"],
        "action_label": action,
        "source": "llm_inferred",
        "left_hand": "",
        "right_hand": "",
        "bimanual_interaction": "",
        "objects": [],
    }


def _infer_gap_segments(scenario, before_seg, after_seg, gap_start, gap_end):
    """Ask the LLM for actions inside [gap_start, gap_end] and wrap them as segments."""
    inferred = infer_gap_actions(scenario, before_seg, after_seg, gap_start, gap_end)
    return [_make_inferred_segment(inf) for inf in inferred]


def process_one_file(ann_path, vol, scenario):
    """Reclassify one annotation file's segments and fill unlabeled gaps.

    1. Every existing segment gets an "action_label" from the LLM. Segments
       whose label is missing or outside the 11 categories (including when the
       LLM call fails) are labeled "Idle".
    2. Gaps of >= 3 s are filled with LLM-inferred segments. A gap is any
       interval not covered by a parsed span: between spans, before the first
       span, and after the last span when the recording length is known.

    Args:
        ann_path: path to the annotation JSON, which must contain "segments".
        vol: volume directory name, used to locate the recording metadata.
        scenario: scene id such as "s1".

    Returns:
        (result, stats). ``result`` is a copy of the file's JSON whose
        "segments" holds the existing and inferred segments sorted by start
        time. Segment dicts are shared with the loaded data, not copied.
        ``stats`` has "reclassified", the number of segments given a valid
        LLM label, and "gaps_filled", the number of inferred segments added.
    """
    with open(ann_path, encoding="utf-8") as f:
        data = json.load(f)
    segments = data["segments"]

    if not segments:
        return data, {"reclassified": 0, "gaps_filled": 0}

    # Step 1: reclassify existing segments.
    print(f"  Reclassifying {len(segments)} segments...")
    classifications = reclassify_segments(segments, scenario) or {}
    reclassified = 0
    for index, seg in enumerate(segments, start=1):
        action = classifications.get(index)
        if isinstance(action, str) and action in _VALID_ACTIONS:
            seg["action_label"] = action
            reclassified += 1
        else:
            # LLM failed or returned an unknown label: default to Idle.
            seg["action_label"] = "Idle"

    # Step 2: find and fill gaps >= 3 seconds.
    spans = []
    unparsed = []
    for seg in segments:
        span = _parse_span(seg.get("timestamp"))
        if span is None:
            # Keep segments with unparseable timestamps in the output rather
            # than silently dropping them.
            unparsed.append(seg)
        else:
            spans.append((span[0], span[1], seg))
    # Sort on the numeric span only. Plain tuple sorting would compare the
    # segment dicts on equal (start, end) and raise TypeError.
    spans.sort(key=lambda item: (item[0], item[1]))

    total_dur = get_recording_duration(vol, scenario)
    inferred_segments = []

    # covered_end is the furthest end time seen so far, so a long segment that
    # overlaps later ones is not mistaken for a gap.
    covered_end, covering_seg = None, None
    for idx, (_start, end, seg) in enumerate(spans):
        if covered_end is None or end >= covered_end:
            covered_end, covering_seg = end, seg

        if idx + 1 < len(spans):
            gap_end = spans[idx + 1][0]
            after_seg = spans[idx + 1][2]
        elif total_dur:
            gap_end = int(total_dur)
            after_seg = None
        else:
            continue

        gap_start = covered_end
        gap_duration = gap_end - gap_start
        if gap_duration >= 3:
            print(f"    Filling gap {gap_start}s-{gap_end}s ({gap_duration}s)...")
            inferred_segments += _infer_gap_segments(
                scenario, covering_seg, after_seg, gap_start, gap_end)

    # Gap at the beginning of the recording.
    if spans and spans[0][0] >= 3:
        print(f"    Filling start gap 0s-{spans[0][0]}s...")
        inferred_segments += _infer_gap_segments(
            scenario, None, spans[0][2], 0, spans[0][0])

    new_segments = [seg for _, _, seg in spans] + unparsed + inferred_segments
    new_segments.sort(key=_segment_start)

    # Deep-copy everything except the old segment list, which is replaced.
    result = copy.deepcopy({**data, "segments": []})
    result["segments"] = new_segments

    return result, {"reclassified": reclassified, "gaps_filled": len(inferred_segments)}


def main():
    """Re-annotate every ``<ANN_DIR>/v*/s*.json`` file into OUTPUT_DIR.

    The output mirrors the input layout as ``<OUTPUT_DIR>/<vol>/<scenario>.json``,
    written as UTF-8 with non-ASCII characters kept. Per-file progress and a
    final summary (files, reclassified segments, gap-filled segments, API
    calls) are printed.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    totals = {"reclassified": 0, "gaps_filled": 0}
    files_done = 0

    volume_dirs = sorted(glob.glob(f"{ANN_DIR}/v*"))
    for volume_dir in volume_dirs:
        vol = os.path.basename(volume_dir)
        dest_dir = os.path.join(OUTPUT_DIR, vol)
        os.makedirs(dest_dir, exist_ok=True)

        for src_path in sorted(glob.glob(f"{volume_dir}/s*.json")):
            scenario = os.path.basename(src_path).replace(".json", "")
            print(f"\n[{vol}/{scenario}]", flush=True)

            annotated, stats = process_one_file(src_path, vol, scenario)

            dest_path = os.path.join(dest_dir, f"{scenario}.json")
            with open(dest_path, "w", encoding="utf-8") as fh:
                json.dump(annotated, fh, ensure_ascii=False, indent=2)

            for stat_name in totals:
                totals[stat_name] += stats[stat_name]
            files_done += 1

            print(f"  Done: {stats['reclassified']} reclassified, {stats['gaps_filled']} gaps filled",
                  flush=True)

    summary_lines = [
        f"\n{'='*60}",
        f"Total: {files_done} files processed",
        f"  Reclassified: {totals['reclassified']} segments",
        f"  Gap-filled:   {totals['gaps_filled']} new segments",
        f"  API calls:    {call_count}",
        f"  Output:       {OUTPUT_DIR}",
    ]
    print("\n".join(summary_lines))


# Script entry point: run the re-annotation pass only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()