# errm / archive / old_scripts / annotate_claude.py
# (uploaded via the upload-large-folder tool; commit 517964a)
import os
import json
import base64
from pathlib import Path
from tqdm import tqdm
from dotenv import load_dotenv
from decord import VideoReader
from openai import OpenAI
from PIL import Image
import io
load_dotenv()

# Client for a local OpenAI-compatible proxy that fronts Claude.
# SECURITY: the proxy API key was previously hard-coded in source. Read both
# the base URL and the key from the environment, falling back to the legacy
# values so existing deployments keep working. The old key should be rotated.
# client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
client = OpenAI(
    base_url=os.getenv("LOCAL_LLM_BASE_URL", "http://localhost:9950/v1"),
    api_key=os.getenv("LOCAL_LLM_API_KEY", "Er21DDnVjatPRb4yE1713Pt8XOmvUA51jSewMBFa91w"),
)

# -------- CONFIG --------
ROOT = Path("/home/jqliu/Myprojects/rewardmodel/video")  # dataset root to search
TARGET_NAME = "22008760.mp4"  # camera view (file name) to extract
OUTPUT_FILE = "./output/labels_batch_improved.jsonl"
FPS_SAMPLE = 2  # frames sampled per second (was 10)
MODEL_NAME = "anthropic.claude-3-7-sonnet-20250219-v1:0"
MAX_VIDEOS = 2  # set to None to process all videos
START_INDEX = 0

# Sliding-window frame configuration
USE_SLIDING_WINDOW = True  # whether to analyze with a sliding window
WINDOW_SIZE = 5  # consecutive frames per analysis window (captures dynamics)
WINDOW_STRIDE = 3  # window stride
# -------- PROMPT TEMPLATE --------
# System prompt sent with every request. {task_description} is substituted via
# str.format, so all literal JSON braces in the template are escaped as {{ }}.
# The model is instructed to return a JSON LIST with one entry per frame:
# stage index/name, five reward components, overall reward/delta,
# success probability, a failure flag, and a short explanation.
PROMPT_TEMPLATE = """
You are a robot manipulation evaluator analyzing the video step-by-step.
**TASK DESCRIPTION**: {task_description}
The task is not only to complete the movement, but also to ensure correct handling of the object.
A task is only considered SUCCESSFUL if:
- The object is securely grasped (not slipping),
- It is moved without spilling, dropping, or losing control,
- And it is placed correctly and stably at the target location.
If the object is spilled, dropped, placed incorrectly, tipped, or ends up unstable → the episode is FAILURE even if the robot completed the motions.
The task typically progresses through these phases (stage index):
0) reach : robot moves toward object (reaching out)
1) grasp : robot attempts to secure object (grasping)
2) lift : robot lifts object upward
3) move : robot carries object toward goal location
4) place : robot releases or places object carefully
5) retract : robot returns to neutral/home position
For each time step shown in the video frames, output:
{{
"t": <frame_index>,
"stage": <0-5, integer stage index>,
"stage_name": "reach" | "grasp" | "lift" | "move" | "place" | "retract",
// Reward components (0.0 to 1.0, aligned with simulation metrics)
"reachout": <0.0 to 1.0, progress in reaching toward object>,
"grasp": <0.0 to 1.0, quality of grasp (0=no contact, 1=secure)>,
"collision": <0.0 to 1.0, penalty for collision (0=no collision, 1=severe collision)>,
"fall": <0.0 to 1.0, penalty for dropping/falling (0=stable, 1=fell/dropped)>,
"smooth": <0.0 to 1.0, smoothness of motion (0=jerky/unstable, 1=very smooth)>,
// Overall metrics
"reward": <0.0 to 1.0, overall reward combining above factors>,
"delta": <-1.0 to 1.0, change in reward from previous step>,
"success_prob": <0.0 to 1.0, probability of eventual success>,
"failure": <0 or 1, 1 if task has failed at this point>,
"explanation": "<brief reasoning about current state, progress, and any issues>"
}}
**IMPORTANT INSTRUCTIONS**:
1. You are given MULTIPLE CONSECUTIVE FRAMES. Analyze the temporal progression:
- Is the motion smooth or jerky?
- Is the object stable or wobbling?
- Is progress being made or has it stalled/reversed?
2. Stage progression should be monotonic (increasing) unless failure occurs.
3. Reward components:
- reachout: Increases as robot approaches object, maxes at 1.0 when contact made
- grasp: 0 until contact, then increases with grasp quality (finger closure, stability)
- collision: Usually 0, increases if robot collides with table/obstacles
- fall: 0 if object stable, 1.0 if object falls/drops
- smooth: High (0.8-1.0) for smooth motion, low (0.0-0.3) for jerky/sudden movements
4. Overall reward should be a weighted combination:
reward ≈ (reachout + grasp + smooth - collision - fall) / 3.0, clamped to [0, 1]
5. For the LAST frame in the sequence:
- If object is dropped, unstable, or incorrectly placed → failure = 1
- If success_prob < 0.5 → consider failure = 1
6. Output a JSON LIST with one entry per frame shown. No extra commentary.
"""
# -------- FUNCTIONS --------
def find_metadata_json(video_path):
    """Locate the metadata JSON belonging to a video file.

    Expected trajectory layout:
        <trajectory_dir>/recordings/MP4/<name>.mp4
        <trajectory_dir>/metadata_*.json

    Returns:
        Path to the first matching metadata file, or None if none exists.
    """
    clip = Path(video_path)
    # The trajectory root sits three directory levels above the .mp4 file.
    trajectory_root = clip.parent.parent.parent
    return next(trajectory_root.glob("metadata_*.json"), None)
def extract_task_description(metadata_path):
    """Read the task description from a metadata JSON file.

    Args:
        metadata_path: Path to a metadata_*.json file, or None when no
            metadata was found for the trajectory.

    Returns:
        The 'current_task' field of the JSON, or "Unknown task" when the
        file is missing, unreadable, not valid JSON, or lacks that key.
    """
    if metadata_path is None or not metadata_path.exists():
        return "Unknown task"
    try:
        # Explicit UTF-8: task descriptions may contain non-ASCII text.
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        return metadata.get('current_task', 'Unknown task')
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[WARNING] Failed to read metadata {metadata_path}: {e}")
        return "Unknown task"
def extract_frames_basic(video_path, fps=FPS_SAMPLE):
    """Uniformly sample frames from a video at roughly `fps` frames/second.

    The last frame is always force-included so the success/failure-deciding
    moment at the end of the episode is never dropped.

    Returns:
        (frames, total_frames, sampled_indices, native_fps)
    """
    reader = VideoReader(video_path)
    n_frames = len(reader)
    native_fps = reader.get_avg_fps()
    sample_step = max(int(native_fps / fps), 1)

    picked = set(range(0, n_frames, sample_step))
    picked.add(n_frames - 1)  # never lose the final frame
    ordered = sorted(picked)

    frames = reader.get_batch(ordered).asnumpy()
    return frames, n_frames, ordered, native_fps
def extract_frame_windows(video_path, fps=FPS_SAMPLE, window_size=WINDOW_SIZE, stride=WINDOW_STRIDE):
    """Sample frames, then slice them into (possibly overlapping) windows.

    Each window holds up to `window_size` consecutive sampled frames, and
    successive windows start `stride` samples apart, so the model sees a
    short action sequence and can judge motion smoothness and trends.

    Returns:
        windows: list of (frame_array, frame_indices) pairs
        total_frames: frame count of the source video
        native_fps: the video's native frame rate
    """
    reader = VideoReader(video_path)
    n_frames = len(reader)
    native_fps = reader.get_avg_fps()
    sample_step = max(int(native_fps / fps), 1)

    # Down-sample first, always keeping the final frame.
    picked = set(range(0, n_frames, sample_step))
    picked.add(n_frames - 1)
    sampled = sorted(picked)

    # Cut the sampled index list into sliding windows.
    windows = []
    for start in range(0, len(sampled), stride):
        chunk = sampled[start:start + window_size]
        if chunk:
            windows.append((reader.get_batch(chunk).asnumpy(), chunk))
    return windows, n_frames, native_fps
def encode_image(image_array):
    """JPEG-encode a numpy HxWx3 frame and return it as a base64 string."""
    buffer = io.BytesIO()
    Image.fromarray(image_array).save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def call_model_basic(frames, task_description):
    """Annotate a whole video with a single request containing all frames.

    Args:
        frames: iterable of HxWx3 uint8 numpy frames.
        task_description: task text substituted into PROMPT_TEMPLATE.

    Returns:
        The model's JSON list of per-frame annotation dicts.

    Raises:
        json.JSONDecodeError: if the reply is not valid JSON even after
            stripping an optional markdown code fence.
    """
    imgs = [
        {
            "type": "image",
            # Anthropic-style base64 image block; the local proxy at
            # base_url accepts this shape (rather than an OpenAI data URL).
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": encode_image(f),
            },
        }
        for f in frames
    ]
    prompt = PROMPT_TEMPLATE.format(task_description=task_description)
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": imgs},
        ],
        # temperature=0.1,  # lower temperature for more consistent labels
    )
    content = response.choices[0].message.content.strip()
    # Robustness: models frequently wrap JSON output in a ```json ... ```
    # fence; strip it so json.loads does not fail on otherwise-valid output.
    if content.startswith("```"):
        parts = content.split("```")
        if len(parts) > 1:
            content = parts[1]
        if content.lstrip().startswith("json"):
            content = content.lstrip()[len("json"):]
    return json.loads(content)
def call_model_window(window_frames, window_idxs, task_description, context=""):
    """Annotate one sliding window of consecutive frames.

    Args:
        window_frames: frame arrays for this window.
        window_idxs: original frame indices matching `window_frames`.
        task_description: task text substituted into PROMPT_TEMPLATE.
        context: optional summary carried over from the previous window,
            appended to the system prompt.

    Returns:
        The model's JSON list of per-frame annotation dicts.

    Raises:
        json.JSONDecodeError: if the reply is not valid JSON even after
            stripping an optional markdown code fence.
    """
    imgs = [
        {
            "type": "image",
            # Anthropic-style base64 image block; the local proxy at
            # base_url accepts this shape (rather than an OpenAI data URL).
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": encode_image(f),
            },
        }
        for f in window_frames
    ]
    prompt = PROMPT_TEMPLATE.format(task_description=task_description)
    if context:
        prompt += f"\n\n**CONTEXT FROM PREVIOUS WINDOW**:\n{context}\n"
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": [
                {"type": "text", "text": f"Analyzing frames at indices: {window_idxs}"},
                *imgs
            ]}
        ],
        # temperature=0.1,
    )
    content = response.choices[0].message.content.strip()
    # Robustness: models frequently wrap JSON output in a ```json ... ```
    # fence; strip it so json.loads does not fail on otherwise-valid output.
    if content.startswith("```"):
        parts = content.split("```")
        if len(parts) > 1:
            content = parts[1]
        if content.lstrip().startswith("json"):
            content = content.lstrip()[len("json"):]
    return json.loads(content)
def process_video_basic(vid_path, task_description):
    """Basic mode: annotate all sampled frames with one model call."""
    frames, total_frames, idxs, native_fps = extract_frames_basic(str(vid_path))
    per_step = call_model_basic(frames, task_description)

    entries = []
    for t, step_data in enumerate(per_step):
        # Guard against the model returning more entries than frames.
        frame_idx = int(idxs[t]) if t < len(idxs) else idxs[-1]
        entries.append({
            "video_path": str(vid_path),
            "video_id": vid_path.stem,
            "task": task_description,
            "t": t,
            "frame_index": frame_idx,
            "total_frames": int(total_frames),
            "native_fps": float(native_fps),
            **step_data,
        })
    return entries
def process_video_sliding_window(vid_path, task_description):
    """Sliding-window mode: annotate consecutive-frame windows in order,
    carrying a short textual summary from each window into the next."""
    windows, total_frames, native_fps = extract_frame_windows(str(vid_path))

    all_entries = []
    context = ""
    for window_idx, (window_frames, window_idxs) in enumerate(windows):
        try:
            result = call_model_window(window_frames, window_idxs, task_description, context)

            # Carry the tail annotation forward as context for the next window.
            if result and len(result) > 0:
                tail = result[-1]
                context = (
                    f"Previous stage: {tail.get('stage_name', 'unknown')}, "
                    f"Success prob: {tail.get('success_prob', 0):.2f}, "
                    f"Explanation: {tail.get('explanation', '')}"
                )

            # One JSONL entry per frame actually shown in this window.
            for i, step_data in enumerate(result):
                if i >= len(window_idxs):
                    continue  # ignore surplus entries beyond this window
                all_entries.append({
                    "video_path": str(vid_path),
                    "video_id": vid_path.stem,
                    "task": task_description,
                    "t": len(all_entries),  # global time step
                    "frame_index": int(window_idxs[i]),
                    "total_frames": int(total_frames),
                    "native_fps": float(native_fps),
                    "window_idx": window_idx,
                    **step_data,
                })
        except Exception as e:
            # Best-effort: one failed window must not abort the whole video.
            print(f"[WARNING] Window {window_idx} failed: {e}")
            continue
    return all_entries
# -------- FIND ALL TARGET VIDEOS --------
# Recursively collect every file matching TARGET_NAME under ROOT (one file
# per trajectory, all sharing the same camera-view file name).
video_files = sorted(ROOT.rglob(f"{TARGET_NAME}"))
print(video_files)
total_videos = len(video_files)
if START_INDEX >= total_videos:
    # Message reads: "START_INDEX exceeds the total number of videos".
    raise ValueError(f"START_INDEX {START_INDEX} 超出视频总数 {total_videos}")

# Take the configured slice of the video list.
if MAX_VIDEOS is None:
    video_files = video_files[START_INDEX:]
else:
    video_files = video_files[START_INDEX:START_INDEX + MAX_VIDEOS]

print(f"Found {len(video_files)} videos to process (from index {START_INDEX}):")
for v in video_files[:5]:
    print(" -", v)
if len(video_files) > 5:
    print(f" ... and {len(video_files) - 5} more")
print(f"\nProcessing mode: {'Sliding Window' if USE_SLIDING_WINDOW else 'Basic (all frames)'}")
if USE_SLIDING_WINDOW:
    print(f"Window size: {WINDOW_SIZE} frames, Stride: {WINDOW_STRIDE}")

# -------- PROCESS LOOP --------
os.makedirs(Path(OUTPUT_FILE).parent, exist_ok=True)
with open(OUTPUT_FILE, "w") as fout:
    for vid_path in tqdm(video_files, desc="Processing videos"):
        try:
            # Locate trajectory metadata and pull the task description.
            metadata_path = find_metadata_json(vid_path)
            task_description = extract_task_description(metadata_path)
            print(f"\n[INFO] Processing: {vid_path.name}")
            print(f"[INFO] Task: {task_description}")
            print(f"[INFO] Metadata: {metadata_path}")
            # Choose the processing mode from the module-level config.
            if USE_SLIDING_WINDOW:
                entries = process_video_sliding_window(vid_path, task_description)
            else:
                entries = process_video_basic(vid_path, task_description)
            # Write one JSONL line per annotated frame.
            for entry in entries:
                fout.write(json.dumps(entry) + "\n")
            print(f"[SUCCESS] Processed {len(entries)} annotations")
        except Exception as e:
            # Best-effort batch run: log the failure and move on to the next video.
            print(f"[ERROR] {vid_path}: {e}")
            import traceback
            traceback.print_exc()
print(f"\n✓ Processing complete! Results saved to: {OUTPUT_FILE}")