# PULSE-code / experiments / analysis / generate_coarse_annotations.py
# NOTE: Hugging Face upload-page residue converted to comments so the file
# parses (uploader: velvet-pine-22, commit b4b2877).
#!/usr/bin/env python3
"""
Generate coarse-grained annotations by merging consecutive fine-grained segments
into composite actions (8-15s duration) using LLM.
Input: annotations_v2/ (fine-grained, ~2-3s segments, 11 classes)
Output: annotations_coarse/ (coarse-grained, ~8-15s segments, ~6 classes)
Does NOT modify annotations_v2/.
"""
import os
import json
import re
import time
import glob
import urllib.request
from collections import Counter
# Paths may reference ${PULSE_ROOT}; expandvars resolves it from the
# environment and leaves the text untouched when the variable is unset, so
# prior behavior (a literal "${PULSE_ROOT}/..." path) is preserved.
INPUT_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_v2")
OUTPUT_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_coarse")
API_URL = "https://api.chatanywhere.tech/v1/chat/completions"
# SECURITY NOTE(review): these are committed secrets and should be rotated /
# revoked. A comma-separated PULSE_API_KEYS environment variable overrides
# the hard-coded list when present.
_env_keys = [k.strip() for k in os.environ.get("PULSE_API_KEYS", "").split(",") if k.strip()]
API_KEYS = _env_keys or [
    "sk-MN5n1uEETyaky96fLJdHqZobXF1f7KmOrZHzwD3lt585asFQ",
    "sk-YnYrtPdAXwlE12hRpi6dYqlE1RRVR3LDVBka6wKaefU4iQRY",
    "sk-jOZtodDv6OxUOMu3NuJ8lzffjwBlshn9OHY5KSmqmPTtc9qs",
    "sk-qAaKTKYIRF24btu1oQWgubWG4UdA92bILNtzOkHNEPAcCxdB",
    "sk-MgCBBonblMrCFnSXd6fJZaBLTCfCJ5FjYZfSe2e46bgmyktk",
    "sk-79e30kYRgduuf2fSU0Lsc814YjNkClXXzQqIbx0iLS40IOEH",
    "sk-h9Tej4tW6AQC6fT0njfzrPKXEk6fBwpiSvvQd0aJAhw4UwLz",
    "sk-k2QNHt5wAH26Fw8hZuPWuVXw8Psd1jX09qusiA6PdBj5Vzuu",
    "sk-w7EkTblciNI44cwosHXi0PGZNUf1hnJmpzOQ85va9VPdAKbz",
    "sk-Dexs5ZF7OjFCq7CZW45wJ8EKoGtIswv6rsLUMzUXXkWBDBBJ",
]
# Chinese scene descriptions keyed by scene id; interpolated into the prompt.
SCENE_DESCRIPTIONS = {
    "s1": "办公桌面整理与工作准备",
    "s2": "快递打包发送",
    "s3": "厨房调料整理",
    "s4": "清理餐后桌面",
    "s5": "餐前桌面布置",
    "s6": "商务旅行行李箱打包",
    "s7": "冲泡咖啡/饮品",
    "s8": "晾衣架整理与衣物收纳",
}
# Coarse taxonomy (6 classes) embedded verbatim in the LLM prompt.
COARSE_CATEGORIES = """粗粒度动作类别(共6类):
1. Manipulate - 操作物体(抓取、调整、放置某个物体的完整过程,包含拿起→操作→放下的组合)
2. CleanOrganize - 清洁/整理(擦桌子、理线、整理桌面、叠衣服等持续性整理活动)
3. Transfer - 搬运/传递(将物体从一个位置搬到另一个位置的过程)
4. Assemble - 组装/连接/包装(封箱、贴胶带、盖盖子、插电源、拧瓶盖等需要精细对准的操作)
5. FoodPrep - 食物/饮品准备(倒水、倒调料、搅拌、冲泡等与食物饮品相关的操作)
6. Idle - 空闲/过渡(无明确操作的间隔)
"""
# Round-robin key-rotation index and total successful API-call counter,
# mutated by call_llm().
current_key_idx = 0
call_count = 0
def call_llm(prompt, max_tokens=1500, retries=3):
    """POST *prompt* to the chat-completions endpoint, rotating API keys.

    Tries up to retries * len(API_KEYS) attempts. Quota/rate-limit style
    failures (429/402/403/"quota"/"limit" in the error text) rotate to the
    next key immediately; any other failure sleeps 0.5s and then also
    rotates.

    Args:
        prompt: user-message text sent as a single-turn conversation.
        max_tokens: completion-length cap forwarded to the API.
        retries: per-key retry multiplier.

    Returns:
        The assistant message content string, or None when every attempt
        failed.
    """
    global current_key_idx, call_count
    for _attempt in range(retries * len(API_KEYS)):
        key = API_KEYS[current_key_idx]
        try:
            payload = json.dumps({
                "model": "gpt-4o-mini",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
                "temperature": 0.1,
            }).encode()
            req = urllib.request.Request(
                API_URL, data=payload,
                headers={"Content-Type": "application/json",
                         "Authorization": f"Bearer {key}"}
            )
            # Context manager closes the HTTP response even when JSON parsing
            # or schema access raises (the original leaked the connection).
            with urllib.request.urlopen(req, timeout=30) as resp:
                result = json.loads(resp.read())
            call_count += 1
            return result["choices"][0]["message"]["content"]
        except Exception as e:  # network, HTTP, JSON, or response-shape errors
            err = str(e)
            if any(k in err for k in ["429", "quota", "limit", "402", "403"]):
                # Exhausted/blocked key: switch immediately, no backoff.
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
            else:
                # Transient error: brief backoff, then rotate as well.
                time.sleep(0.5)
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
    return None
def parse_ts(ts_str):
    """Convert a leading 'MM:SS' in *ts_str* to seconds (0 if unparseable)."""
    parsed = re.match(r'(\d+):(\d+)', ts_str.strip())
    if parsed is None:
        return 0
    minutes, seconds = (int(part) for part in parsed.groups())
    return minutes * 60 + seconds
def format_ts(sec):
    """Render a second count as a zero-padded 'MM:SS' string."""
    minutes, seconds = divmod(sec, 60)
    return f"{minutes:02d}:{seconds:02d}"
def merge_segments_with_llm(segments, scene_id):
    """Use LLM to merge fine-grained segments into coarse composite actions.

    Args:
        segments: list of fine-segment dicts; each must carry "timestamp"
            and "task", and may carry "action_label" (defaults to "Idle").
        scene_id: scene key ("s1".."s8") used to look up the Chinese scene
            description for the prompt.

    Returns:
        A list of validated coarse-segment dicts (only the six known
        categories, each with "timestamp"/"coarse_action"/"description"),
        or None when the LLM call or response parsing failed — callers then
        fall back to rule-based merging.
    """
    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")
    # Build segment list: 1-based numbering so the model can reference
    # children by index in "fine_segments".
    seg_lines = []
    for i, seg in enumerate(segments):
        label = seg.get("action_label", "Idle")
        seg_lines.append(f"{i+1}. [{seg['timestamp']}] {label}: {seg['task']}")
    seg_text = "\n".join(seg_lines)
    # Prompt is intentionally in Chinese (matching the scene descriptions);
    # it requests a strict JSON array of merged 5-15s composite actions.
    prompt = f"""你是一个动作标注专家。以下是一段"{scene_desc}"录制中的细粒度动作序列(每个2-3秒)。
请将相关的连续动作合并为粗粒度复合动作,每个复合动作持续5-15秒。
合并规则:
- 围绕同一个物体的连续操作合并为一个(如"抓取杯子→调整→放下"合并为一个Manipulate)
- 连续的整理/清洁动作合并
- 合并后的时间范围 = 第一个子动作的开始时间 到 最后一个子动作的结束时间
- 如果中间有短暂Idle(≤3秒),可以包含进去
- 每个复合动作必须从6个类别中选一个
{COARSE_CATEGORIES}
细粒度动作序列:
{seg_text}
请严格按以下JSON格式返回,不要添加任何额外文字:
[{{"timestamp": "MM:SS-MM:SS", "coarse_action": "类别名", "description": "简要描述这段复合动作", "fine_segments": [子动作编号列表]}}]"""
    response = call_llm(prompt, max_tokens=2000)
    if response is None:
        return None
    try:
        # Grab the first bracketed span — models often wrap the JSON in
        # extra prose despite the instructions.
        match = re.search(r'\[.*\]', response, re.DOTALL)
        if match:
            results = json.loads(match.group())
            valid = []
            for r in results:
                if all(k in r for k in ["timestamp", "coarse_action", "description"]):
                    # Validate category
                    if r["coarse_action"] in {"Manipulate", "CleanOrganize", "Transfer",
                                              "Assemble", "FoodPrep", "Idle"}:
                        valid.append(r)
            return valid
        # NOTE(review): no JSON array in the reply falls through to the
        # final `return None`, which callers treat as an LLM failure.
    except (json.JSONDecodeError, KeyError) as e:
        print(f" Parse error: {e}")
    return None
def process_file(input_path, vol, scenario):
    """Merge one fine-grained annotation file into coarse segments.

    Args:
        input_path: path to a fine-grained JSON file with a "segments" list.
        vol: volume id (unused here; kept for call-site compatibility).
        scenario: scene id ("s1".."s8"), forwarded to pick the prompt's
            scene description.

    Returns:
        (result, n_coarse): result holds both the untouched "fine_segments"
        and the new "coarse_segments"; n_coarse is len(coarse_segments).
    """
    # Context manager fixes the file-handle leak of json.load(open(...))
    # and pins UTF-8 for the Chinese annotation text.
    with open(input_path, encoding="utf-8") as f:
        data = json.load(f)
    segments = data["segments"]
    if not segments:
        return {"fine_segments": segments, "coarse_segments": []}, 0
    print(f" Merging {len(segments)} fine segments...")
    coarse = merge_segments_with_llm(segments, scenario)
    if coarse is None:
        # Fallback: simple time-based merging without LLM
        print(f" LLM failed, using fallback merge")
        coarse = fallback_merge(segments)
    result = {
        "fine_segments": segments,
        "coarse_segments": coarse,
    }
    return result, len(coarse)
def fallback_merge(segments):
    """Rule-based merging used when the LLM path fails.

    Groups consecutive fine segments while the gap between neighbors is
    <= 3s and the accumulated group span is <= 15s, then collapses each
    group via _emit_group(). Segments with unparseable timestamps are kept
    with their current group rather than dropped.

    Args:
        segments: list of fine-segment dicts with "timestamp"
            ("MM:SS-MM:SS"), "task", and optional "action_label".

    Returns:
        A list of coarse-segment dicts (empty for empty input).
    """
    if not segments:
        return []
    # Hoisted: the same span pattern is matched for every neighbor pair.
    span_re = re.compile(r'(\d+:\d+)\s*-\s*(\d+:\d+)')
    coarse = []
    group = [segments[0]]
    for seg in segments[1:]:
        m1 = span_re.match(group[-1]["timestamp"])
        m2 = span_re.match(seg["timestamp"])
        if not m1 or not m2:
            # Unparseable timestamp: keep the segment with its neighbors.
            group.append(seg)
            continue
        prev_end = parse_ts(m1.group(2))
        curr_start = parse_ts(m2.group(1))
        gap = curr_start - prev_end
        # Duration the group would span if this segment were merged in.
        m0 = re.match(r'(\d+:\d+)', group[0]["timestamp"])
        # Guard: the group head can itself be malformed (it reaches the head
        # slot via the unparseable branch above); the original crashed with
        # AttributeError here. Fall back to the previous end time.
        group_start = parse_ts(m0.group(1)) if m0 else prev_end
        curr_end = parse_ts(m2.group(2))
        group_duration = curr_end - group_start
        # Merge if gap ≤ 3s and group duration ≤ 15s.
        if gap <= 3 and group_duration <= 15:
            group.append(seg)
        else:
            # Emit the finished group and start a fresh one.
            coarse.append(_emit_group(group))
            group = [seg]
    if group:
        coarse.append(_emit_group(group))
    return coarse
def _emit_group(group):
"""Create a coarse segment from a group of fine segments."""
m_start = re.match(r'(\d+:\d+)', group[0]["timestamp"])
m_end = re.match(r'\d+:\d+\s*-\s*(\d+:\d+)', group[-1]["timestamp"])
start = m_start.group(1) if m_start else "00:00"
end = m_end.group(1) if m_end else "00:00"
labels = [seg.get("action_label", "Idle") for seg in group]
label_counts = Counter(labels)
dominant = label_counts.most_common(1)[0][0]
# Map fine label to coarse
label_map = {
"Grasp": "Manipulate", "Place": "Manipulate", "Arrange": "CleanOrganize",
"Wipe": "CleanOrganize", "Fold": "CleanOrganize", "Transport": "Transfer",
"OpenClose": "Assemble", "TearCut": "Assemble",
"Pour": "FoodPrep", "Stir": "FoodPrep", "Idle": "Idle",
}
coarse_label = label_map.get(dominant, "Manipulate")
tasks = [seg["task"] for seg in group]
desc = tasks[0] if len(tasks) == 1 else f"{tasks[0]}...{tasks[-1]}"
return {
"timestamp": f"{start}-{end}",
"coarse_action": coarse_label,
"description": desc[:80],
"fine_segments": list(range(1, len(group) + 1)),
}
def main():
    """Walk every volume/scene file, write coarse annotations, print stats."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    fine_total, coarse_total, file_total = 0, 0, 0
    label_freq = Counter()
    for vol_dir in sorted(glob.glob(f"{INPUT_DIR}/v*")):
        vol = os.path.basename(vol_dir)
        out_dir = os.path.join(OUTPUT_DIR, vol)
        os.makedirs(out_dir, exist_ok=True)
        for ann_file in sorted(glob.glob(f"{vol_dir}/s*.json")):
            scenario = os.path.basename(ann_file).replace(".json", "")
            print(f"[{vol}/{scenario}]", flush=True)
            result, n_coarse = process_file(ann_file, vol, scenario)
            out_path = os.path.join(out_dir, f"{scenario}.json")
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
            n_fine = len(result["fine_segments"])
            fine_total += n_fine
            coarse_total += n_coarse
            file_total += 1
            for seg in result["coarse_segments"]:
                label_freq[seg["coarse_action"]] += 1
            print(f" {n_fine} fine → {n_coarse} coarse segments", flush=True)
    # Summary report (max(...,1) guards the divisions when nothing matched).
    print(f"\n{'='*60}")
    print(f"Total: {file_total} files")
    print(f" Fine segments: {fine_total}")
    print(f" Coarse segments: {coarse_total}")
    print(f" Compression: {fine_total/max(coarse_total,1):.1f}x")
    print(f" API calls: {call_count}")
    print(f"\n Coarse label distribution:")
    for label, count in label_freq.most_common():
        print(f" {label:<20} {count:>5} ({count/max(coarse_total,1)*100:.1f}%)")
    print(f"\n Output: {OUTPUT_DIR}")
if __name__ == "__main__":
    main()