#!/usr/bin/env python3
"""
Re-annotate action segments using LLM (GPT-4o-mini).

1. Re-classify existing segments with better accuracy
2. Infer actions in unlabeled gaps based on context (scene, surrounding actions)
3. Output improved annotations with higher coverage

Paths below may contain ``${PULSE_ROOT}``; it is expanded from the environment
at import time (and left untouched when the variable is unset).
"""
import os
import sys
import json
import re
import time
import copy
import glob
import urllib.error
import urllib.request
from collections import Counter

ANN_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_by_scene")
OUTPUT_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_v2")
DATASET_DIR = os.path.expandvars("${PULSE_ROOT}/dataset")

API_URL = "https://api.chatanywhere.tech/v1/chat/completions"

# NOTE(review): these keys are committed to source control and should be treated
# as leaked -- rotate them and supply replacements via the LLM_API_KEYS
# environment variable (comma-separated). The inline list is only used when that
# variable is unset, so existing invocations keep working.
_INLINE_API_KEYS = [
    "sk-MN5n1uEETyaky96fLJdHqZobXF1f7KmOrZHzwD3lt585asFQ",
    "sk-YnYrtPdAXwlE12hRpi6dYqlE1RRVR3LDVBka6wKaefU4iQRY",
    "sk-jOZtodDv6OxUOMu3NuJ8lzffjwBlshn9OHY5KSmqmPTtc9qs",
    "sk-qAaKTKYIRF24btu1oQWgubWG4UdA92bILNtzOkHNEPAcCxdB",
    "sk-MgCBBonblMrCFnSXd6fJZaBLTCfCJ5FjYZfSe2e46bgmyktk",
    "sk-79e30kYRgduuf2fSU0Lsc814YjNkClXXzQqIbx0iLS40IOEH",
    "sk-h9Tej4tW6AQC6fT0njfzrPKXEk6fBwpiSvvQd0aJAhw4UwLz",
    "sk-k2QNHt5wAH26Fw8hZuPWuVXw8Psd1jX09qusiA6PdBj5Vzuu",
    "sk-w7EkTblciNI44cwosHXi0PGZNUf1hnJmpzOQ85va9VPdAKbz",
    "sk-Dexs5ZF7OjFCq7CZW45wJ8EKoGtIswv6rsLUMzUXXkWBDBBJ",
]
API_KEYS = [
    k.strip() for k in os.environ.get("LLM_API_KEYS", "").split(",") if k.strip()
] or _INLINE_API_KEYS

# Cap on completion tokens requested per call. OpenAI documents a 16,384-token
# output limit for gpt-4o-mini; asking for more is rejected outright.
MAX_COMPLETION_TOKENS = 16384

# Uncovered stretches shorter than this (seconds) are not sent for gap filling.
MIN_GAP_SEC = 3

SCENE_DESCRIPTIONS = {
    "s1": "办公桌面整理与工作准备(整理文件、电源线、鼠标、笔记本电脑等)",
    "s2": "快递打包发送(折叠纸箱、放入物品、封箱、贴标签等)",
    "s3": "厨房调料整理(拿取调料瓶、倒调料、拧瓶盖、擦拭等)",
    "s4": "清理餐后桌面(收碗碟、擦桌子、整理餐具、倒残渣等)",
    "s5": "餐前桌面布置(铺桌布、摆放餐具碗碟、放杯子等)",
    "s6": "商务旅行行李箱打包(折叠衣物、放入行李箱、整理物品等)",
    "s7": "冲泡咖啡/饮品(取杯子、放咖啡粉/茶包、倒热水、搅拌等)",
    "s8": "晾衣架整理与衣物收纳(取衣架、挂衣服、折叠衣物等)",
}

ACTION_CATEGORIES = """动作类别定义(共11类):
1. Grasp - 抓取/拿起物体(手从无接触到接触并握住物体)
2. Place - 放置/放下物体(将物体放到某个位置并释放)
3. Pour - 倾倒/注入液体或颗粒(倒水、倒调料、倒咖啡粉等)
4. Wipe - 擦拭/清洁表面(用抹布或手擦桌面、瓶身等)
5. Fold - 折叠/卷起(折衣服、折桌布、折纸箱等)
6. OpenClose - 打开/关闭/旋开/旋紧(开盒子、拧瓶盖、拉拉链、合箱盖等)
7. Stir - 搅拌(搅拌咖啡、搅拌饮品等)
8. TearCut - 撕/剪/粘贴(撕胶带、剪快递单、贴标签等)
9. Arrange - 整理/摆放/调整位置(摆餐具、整理文件、调整物品位置、理线等)
10. Transport - 搬运/移动物体到较远位置(把包裹搬到架子、把碗端到水槽等)
11. Idle - 空闲/过渡/无明确操作(双手无目的性动作、等待、观察等)

注意:
- 只有真正没有任何手部操作时才标Idle
- "调整姿态"、"检查物体"等属于Arrange
- "插入"、"装入"等属于Place
- "提起并移动"如果距离短属于Grasp,距离远属于Transport
"""

# The closed label set; anything else coming back from the LLM becomes "Idle".
VALID_ACTIONS = frozenset({
    "Grasp", "Place", "Pour", "Wipe", "Fold", "OpenClose",
    "Stir", "TearCut", "Arrange", "Transport", "Idle",
})

# "MM:SS-MM:SS" (whitespace allowed around the dash) and a bare leading "MM:SS".
_SPAN_RE = re.compile(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)')
_START_RE = re.compile(r'(\d+):(\d+)')
_JSON_DECODER = json.JSONDecoder()

current_key_idx = 0  # index into API_KEYS of the key currently in use
call_count = 0       # number of successful API calls, reported at the end


def parse_span(timestamp):
    """Parse ``"MM:SS-MM:SS"`` into ``(start_sec, end_sec)``; None if it doesn't match."""
    if not isinstance(timestamp, str):
        return None
    m = _SPAN_RE.match(timestamp)
    if not m:
        return None
    start = int(m.group(1)) * 60 + int(m.group(2))
    end = int(m.group(3)) * 60 + int(m.group(4))
    return start, end


def start_seconds(timestamp):
    """Return the leading ``MM:SS`` of *timestamp* in seconds, or +inf if absent.

    Used as a sort key so that segments without a parseable start go last.
    """
    if isinstance(timestamp, str):
        m = _START_RE.match(timestamp)
        if m:
            return int(m.group(1)) * 60 + int(m.group(2))
    return float("inf")


def extract_json_array(text):
    """Return the first well-formed JSON array embedded in *text*, or None.

    Each ``[`` is tried in turn with ``JSONDecoder.raw_decode``. This tolerates
    markdown fences and trailing chatter (e.g. "... (see [note])") that a greedy
    ``\\[.*\\]`` regex would swallow into invalid JSON.
    """
    if not isinstance(text, str):
        return None
    start = text.find("[")
    while start != -1:
        try:
            value, _ = _JSON_DECODER.raw_decode(text, start)
        except json.JSONDecodeError:
            value = None
        if isinstance(value, list):
            return value
        start = text.find("[", start + 1)
    return None


def parse_classifications(items):
    """Map 1-based segment ids to action labels from a parsed LLM reply.

    Entries that are not objects, lack ``id``/``action``, or have a non-integer
    ``id`` are skipped. String ids such as ``"3"`` are accepted. Labels are NOT
    validated here.
    """
    labels = {}
    for item in items:
        if not isinstance(item, dict):
            continue
        try:
            labels[int(item["id"])] = item["action"]
        except (KeyError, TypeError, ValueError):
            continue
    return labels


def validate_gap_segments(items, gap_start, gap_end):
    """Keep LLM-proposed gap segments that are well-formed and lie inside the gap.

    An item is kept when it is a dict with ``timestamp``, ``task`` and
    ``action``, and its parsed span satisfies
    ``gap_start <= start < end <= gap_end``.
    """
    valid = []
    for item in items:
        if not isinstance(item, dict):
            continue
        if not all(key in item for key in ("timestamp", "task", "action")):
            continue
        span = parse_span(item["timestamp"])
        if span and gap_start <= span[0] < span[1] <= gap_end:
            valid.append(item)
    return valid


def make_inferred_segment(inf):
    """Build an output segment dict from one validated LLM gap inference.

    Labels outside VALID_ACTIONS are replaced by "Idle", matching how
    reclassified segments are treated.
    """
    action = inf.get("action")
    if not (isinstance(action, str) and action in VALID_ACTIONS):
        action = "Idle"
    return {
        "timestamp": inf["timestamp"],
        "task": inf["task"],
        "action_label": action,
        "source": "llm_inferred",
        "left_hand": "",
        "right_hand": "",
        "bimanual_interaction": "",
        "objects": [],
    }


def find_gaps(parsed, total_dur=None, min_gap=MIN_GAP_SEC):
    """Locate uncovered stretches of at least *min_gap* seconds.

    Coverage is tracked with a running maximum end time. Overlapping segments
    (one ending after the next one starts) therefore never produce a "gap"
    inside already-labelled time.

    Args:
        parsed: list of ``(start_sec, end_sec, segment)`` sorted by start.
        total_dur: recording length in seconds (truncated to int), or None/0
            when unknown; the trailing gap is then not checked.
        min_gap: minimum gap length in seconds.

    Returns:
        List of ``(gap_start, gap_end, before_seg, after_seg)``. ``before_seg``
        is None for a gap at the recording start; ``after_seg`` is None for one
        at the end. An empty *parsed* yields no gaps.
    """
    gaps = []
    if not parsed:
        return gaps
    first_start, covered_until, last_seg = parsed[0]
    if first_start >= min_gap:
        gaps.append((0, first_start, None, last_seg))
    for start, end, seg in parsed[1:]:
        if start - covered_until >= min_gap:
            gaps.append((covered_until, start, last_seg, seg))
        if end >= covered_until:
            covered_until, last_seg = end, seg
    if total_dur:
        recording_end = int(total_dur)
        if recording_end - covered_until >= min_gap:
            gaps.append((covered_until, recording_end, last_seg, None))
    return gaps


def call_llm(prompt, max_tokens=1000, retries=3):
    """Call the chat-completions API, rotating through API_KEYS on failure.

    Makes up to ``retries * len(API_KEYS)`` attempts. The client moves on to the
    next key after HTTP 402/429, after error text mentioning quota/limit, and
    after unrecognised errors. Timeouts pause briefly and retry the same key.

    Args:
        prompt: user message sent to the model.
        max_tokens: completion-token budget for the request.
        retries: number of passes over the key list before giving up.

    Returns:
        The assistant message content, or None if every attempt failed.
    """
    global current_key_idx, call_count
    payload = json.dumps({
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.1,
    }).encode()
    for _ in range(retries * len(API_KEYS)):
        key = API_KEYS[current_key_idx]
        req = urllib.request.Request(
            API_URL,
            data=payload,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
        )
        try:
            # The context manager closes the HTTP response even if parsing fails.
            with urllib.request.urlopen(req, timeout=30) as resp:
                result = json.loads(resp.read())
            content = result["choices"][0]["message"]["content"]
        except Exception as e:
            err = str(e)
            quota_hit = (
                (isinstance(e, urllib.error.HTTPError) and e.code in (402, 429))
                or "429" in err or "quota" in err or "limit" in err or "402" in err
            )
            # Socket timeouts stringify as "timed out" (bare or wrapped in
            # URLError), which does not contain the substring "timeout".
            lowered = err.lower()
            timed_out = (
                isinstance(e, TimeoutError) or "timed out" in lowered or "timeout" in lowered
            )
            if quota_hit:
                # Key exhausted, rotate
                print(f" Key {current_key_idx+1} exhausted, rotating...")
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
            elif timed_out:
                time.sleep(1)
            else:
                print(f" API error: {err[:100]}")
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
            time.sleep(0.5)
            continue
        call_count += 1
        return content
    print(" WARNING: All API keys failed!")
    return None


def reclassify_segments(segments, scene_id):
    """Ask the LLM to assign one action category to every segment.

    Args:
        segments: non-empty list of annotation dicts with ``timestamp`` and ``task``.
        scene_id: scene key such as ``"s1"``; unknown keys get a generic description.

    Returns:
        Dict mapping 1-based segment position to the label string the model
        returned (not yet validated), or None if the call failed or nothing
        usable could be parsed.
    """
    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")

    # Build segment list for prompt
    seg_text = "\n".join(
        f"{i}. [{seg['timestamp']}] {seg['task']}"
        for i, seg in enumerate(segments, start=1)
    )

    prompt = f"""你是一个人体动作标注专家。请为以下每个动作片段分配一个动作类别。

场景:{scene_desc}

{ACTION_CATEGORIES}

动作片段列表:
{seg_text}

请严格按以下JSON格式返回,不要添加任何额外文字:
[{{"id": 1, "action": "类别名"}}, {{"id": 2, "action": "类别名"}}, ...]

每个action必须是以下之一:Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut, Arrange, Transport, Idle"""

    # ~40 tokens per JSON entry, capped so long recordings aren't rejected outright.
    budget = min(len(segments) * 40, MAX_COMPLETION_TOKENS)
    response = call_llm(prompt, max_tokens=budget)
    if response is None:
        return None

    results = extract_json_array(response)
    if results is None:
        print(f" Parse error: no JSON array found, response: {response[:200]}")
        return None
    return parse_classifications(results) or None


def infer_gap_actions(scene_id, before_seg, after_seg, gap_start, gap_end):
    """Ask the LLM which actions most likely happened in an unlabeled gap.

    Args:
        scene_id: scene key such as ``"s1"``.
        before_seg: segment preceding the gap, or None at the recording start.
        after_seg: segment following the gap, or None at the recording end.
        gap_start, gap_end: integer gap bounds in seconds.

    Returns:
        List of dicts with ``timestamp``/``task``/``action`` whose spans lie
        inside the gap; empty on any failure.
    """
    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")
    gap_duration = gap_end - gap_start

    before_text = f"[{before_seg['timestamp']}] {before_seg['task']}" if before_seg else "(录制开始)"
    after_text = f"[{after_seg['timestamp']}] {after_seg['task']}" if after_seg else "(录制结束)"

    prompt = f"""你是一个人体动作标注专家。在一段日常活动录制中,有一段时间没有被标注。请根据场景和前后动作推断这段时间内最可能发生的动作。

场景:{scene_desc}
未标注时间段:{gap_start//60:02d}:{gap_start%60:02d} - {gap_end//60:02d}:{gap_end%60:02d}(共{gap_duration}秒)
前一个标注动作:{before_text}
后一个标注动作:{after_text}

{ACTION_CATEGORIES}

请推断这段时间内可能发生的动作序列。每个动作段落2-4秒,时间用MM:SS格式。
如果确实是空闲等待,标注为Idle。

严格按以下JSON格式返回,不要添加任何额外文字:
[{{"timestamp": "MM:SS-MM:SS", "task": "动作描述", "action": "类别名"}}]

每个action必须是以下之一:Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut, Arrange, Transport, Idle"""

    response = call_llm(prompt, max_tokens=500)
    if response is None:
        return []

    results = extract_json_array(response)
    if results is None:
        print(" Parse error: no JSON array found")
        return []
    return validate_gap_segments(results, gap_start, gap_end)


def get_recording_duration(vol, scenario):
    """Return the recording length in seconds from alignment metadata, or None.

    Reads ``DATASET_DIR/<vol>/<scenario>/alignment_metadata.json``. Prefers
    ``aligned_length_sec``. Otherwise it divides ``aligned_length_frames`` by
    100, which presumes a 100 Hz frame rate (confirm against the dataset).
    """
    meta_path = os.path.join(DATASET_DIR, vol, scenario, "alignment_metadata.json")
    if not os.path.exists(meta_path):
        return None
    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)
    if "aligned_length_sec" in meta:
        return meta["aligned_length_sec"]
    if "aligned_length_frames" in meta:
        return meta["aligned_length_frames"] / 100.0
    return None


def process_one_file(ann_path, vol, scenario):
    """Reclassify one annotation file's segments and fill its unlabeled gaps.

    Args:
        ann_path: path to the source annotation JSON (must contain ``segments``).
        vol: volume directory name, used to locate alignment metadata.
        scenario: scene id such as ``"s1"`` (the annotation file's stem).

    Returns:
        ``(result, stats)``. ``result`` is a deep copy of the input whose
        ``segments`` holds the labelled original segments plus LLM-inferred gap
        segments, sorted by start time. Segments with unparseable timestamps are
        kept and placed last. ``stats["reclassified"]`` counts segments that
        received a valid LLM label; ``stats["gaps_filled"]`` counts inferred
        segments added.
    """
    with open(ann_path, encoding="utf-8") as f:
        data = json.load(f)
    # Work on a copy so the loaded document itself is never mutated.
    result = copy.deepcopy(data)
    segments = result["segments"]
    if not segments:
        return result, {"reclassified": 0, "gaps_filled": 0}

    # Step 1: Reclassify existing segments
    print(f" Reclassifying {len(segments)} segments...")
    classifications = reclassify_segments(segments, scenario) or {}
    reclassified = 0
    for i, seg in enumerate(segments, start=1):
        action = classifications.get(i)
        if isinstance(action, str) and action in VALID_ACTIONS:
            seg["action_label"] = action
            reclassified += 1
        else:
            # No usable LLM label (missing, invalid, or the whole call failed).
            seg["action_label"] = "Idle"

    # Step 2: Find and fill gaps >= MIN_GAP_SEC seconds
    parsed, unparsed = [], []
    for seg in segments:
        span = parse_span(seg.get("timestamp"))
        if span is None:
            unparsed.append(seg)  # kept in the output instead of being dropped
        else:
            parsed.append((span[0], span[1], seg))
    # Sort on the times only: on equal spans a plain tuple sort would compare
    # the segment dicts and raise TypeError.
    parsed.sort(key=lambda item: (item[0], item[1]))

    new_segments = [seg for _, _, seg in parsed] + unparsed
    gaps_filled = 0
    total_dur = get_recording_duration(vol, scenario)
    for gap_start, gap_end, before_seg, after_seg in find_gaps(parsed, total_dur):
        if before_seg is None:
            print(f" Filling start gap {gap_start}s-{gap_end}s...")
        else:
            print(f" Filling gap {gap_start}s-{gap_end}s ({gap_end - gap_start}s)...")
        for inf in infer_gap_actions(scenario, before_seg, after_seg, gap_start, gap_end):
            new_segments.append(make_inferred_segment(inf))
            gaps_filled += 1

    # Stable sort by start time; segments without a parseable start go last.
    new_segments.sort(key=lambda seg: start_seconds(seg.get("timestamp")))
    result["segments"] = new_segments
    return result, {"reclassified": reclassified, "gaps_filled": gaps_filled}


def main():
    """Process every ``v*/s*.json`` under ANN_DIR and write results to OUTPUT_DIR."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    total_reclassified = 0
    total_gaps_filled = 0
    total_files = 0

    vol_dirs = [d for d in sorted(glob.glob(f"{ANN_DIR}/v*")) if os.path.isdir(d)]
    if not vol_dirs:
        print(f"WARNING: no volume directories found under {ANN_DIR} (is PULSE_ROOT set?)")

    for vol_dir in vol_dirs:
        vol = os.path.basename(vol_dir)
        out_vol_dir = os.path.join(OUTPUT_DIR, vol)
        os.makedirs(out_vol_dir, exist_ok=True)

        for ann_file in sorted(glob.glob(f"{vol_dir}/s*.json")):
            scenario = os.path.splitext(os.path.basename(ann_file))[0]
            print(f"\n[{vol}/{scenario}]", flush=True)

            result, stats = process_one_file(ann_file, vol, scenario)

            # Save
            out_path = os.path.join(out_vol_dir, f"{scenario}.json")
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(result, f, ensure_ascii=False, indent=2)

            total_reclassified += stats["reclassified"]
            total_gaps_filled += stats["gaps_filled"]
            total_files += 1
            print(f" Done: {stats['reclassified']} reclassified, {stats['gaps_filled']} gaps filled", flush=True)

    print(f"\n{'='*60}")
    print(f"Total: {total_files} files processed")
    print(f" Reclassified: {total_reclassified} segments")
    print(f" Gap-filled: {total_gaps_filled} new segments")
    print(f" API calls: {call_count}")
    print(f" Output: {OUTPUT_DIR}")


if __name__ == "__main__":
    main()