# errm / archive / old_scripts / api_batch.py
# Uploaded by yuffish: "Add files using upload-large-folder tool"
# (commit 517964a, verified)
import os
import json
import base64
from pathlib import Path
from tqdm import tqdm
from dotenv import load_dotenv
from decord import VideoReader
from openai import OpenAI
from PIL import Image
import io
# load_dotenv() runs first so that OPENAI_API_KEY is available in the process
# environment by the time the client below reads it.
load_dotenv()

# Module-wide OpenAI client shared by every request made in this script.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# -------- CONFIG --------
# Directory that is searched recursively for recordings.
ROOT = Path("/playpen-ssd/dataset/droid_raw/1.0.1/AUTOLab/failure")
# File name of the camera view to extract from each recording folder.
TARGET_NAME = "22008760.mp4" # the view you want to extract
# JSONL file the per-step labels are written to (one JSON object per line).
OUTPUT_FILE = "./output/labels_batch_1111.jsonl"
# Target sampling rate (frames per second); default `fps` of extract_frames.
FPS_SAMPLE = 2
# Model name passed to the chat completions endpoint.
MODEL_NAME = "gpt-5-mini"
# Maximum number of videos to process in this run.
MAX_VIDEOS = 10 # set to None to process all videos
# Index into the sorted list of matching videos at which this run starts.
START_INDEX = 20
# -------- PROMPT --------
# System prompt sent with every request (see call_model). It defines the
# success criteria, the expected task phases, and the per-time-step JSON
# schema, and instructs the model to answer with a JSON list only.
# NOTE: the string is runtime text; edit it only to change model behaviour.
PROMPT = """
You are a robot manipulation evaluator analyzing the video step-by-step.
The task is not only to complete the movement, but also to ensure correct handling of the object.
A task is only considered SUCCESSFUL if:
- The object is securely grasped (not slipping),
- It is moved without spilling, dropping, or losing control,
- And it is placed correctly and stably at the target location.
If the object is spilled, dropped, placed incorrectly, tipped, or ends up unstable → the episode is FAILURE even if the robot completed the motions.
The task typically progresses through these phases:
1) reach : robot moves toward object
2) grasp : robot attempts to secure object
3) up : robot lifts object
4) move : robot carries object toward goal
5) place : robot releases or places object carefully
6) return : optional return to neutral state
For each time step, output:
{
"t": <index>,
"stage": "reach" | "grasp" | "up" | "move" | "place" | "return",
"reward": <0 to 1>,
"delta": <-1 to 1>,
"success_prob": <0 to 1>,
"failure": <0 or 1>,
"explanation": "<brief reasoning>"
}
Rules:
- Stage should progress logically unless failure occurs.
- reward increases as progress improves and decreases when mistakes occur.
- If object is dropped, spilled, crushed, knocked over, or final state is unstable → failure = 1.
- For the LAST time step:
If success_prob is low OR the object is not placed correctly/stably,
FORCE failure = 1.
Output JSON LIST only. No extra commentary.
"""
# -------- FUNCTIONS --------
def extract_frames(video_path, fps=FPS_SAMPLE):
    """Sample frames from a video at roughly ``fps`` frames per second.

    Frames are taken every ``int(native_fps / fps)`` frames (at least every
    frame), starting at frame 0. The final frame of the video is always
    included so the end state, which decides success or failure, is never
    dropped by the stride.

    Args:
        video_path: Path handed to ``VideoReader``.
        fps: Target sampling rate; must be positive.

    Returns:
        A ``(frames, total_frames, idxs)`` tuple where ``frames`` is the
        result of ``vr.get_batch(idxs).asnumpy()``, ``total_frames`` is the
        number of frames in the video, and ``idxs`` is the ascending list of
        original frame indices that were sampled.

    Raises:
        ValueError: If ``fps`` is not positive or the video has no frames.
    """
    if fps <= 0:
        # Previously fps == 0 surfaced as a bare ZeroDivisionError and a
        # negative fps silently sampled every frame.
        raise ValueError(f"fps must be positive, got {fps}")

    vr = VideoReader(video_path)
    total_frames = len(vr)
    if total_frames == 0:
        # Without this guard the "force last frame" step would request
        # frame index -1 from an empty video.
        raise ValueError(f"video has no frames: {video_path}")

    native_fps = vr.get_avg_fps()
    step = max(int(native_fps / fps), 1)
    idxs = list(range(0, total_frames, step))

    # The range is ascending and unique, so comparing against its last
    # element is enough to know whether the final frame is already present.
    last_index = total_frames - 1
    if idxs[-1] != last_index:
        idxs.append(last_index)

    frames = vr.get_batch(idxs).asnumpy()
    return frames, total_frames, idxs
def encode_image(image_array):
    """Return ``image_array`` encoded as a base64 JPEG string.

    The array is converted with ``Image.fromarray``, saved as JPEG into an
    in-memory buffer, and the resulting bytes are base64-encoded and decoded
    to ``str``.
    """
    with io.BytesIO() as jpeg_buffer:
        Image.fromarray(image_array).save(jpeg_buffer, format="JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()
    return base64.b64encode(jpeg_bytes).decode("utf-8")
def call_model(frames):
    """Send the sampled frames to the model and return its parsed JSON reply.

    Every frame is JPEG/base64-encoded via ``encode_image`` and attached as
    an ``image_url`` content part of a single user message; ``PROMPT`` is
    sent as the system message.

    Args:
        frames: Iterable of image arrays accepted by ``encode_image``.

    Returns:
        The value decoded from the model's reply with ``json.loads``
        (``PROMPT`` asks for a JSON list with one object per time step).

    Raises:
        ValueError: If the reply is empty, or is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    image_parts = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encode_image(frame)}"},
        }
        for frame in frames
    ]
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": PROMPT},
            {"role": "user", "content": image_parts},
        ],
    )

    content = response.choices[0].message.content
    if not content or not content.strip():
        # json.loads(None) would raise an opaque TypeError; report the real
        # problem instead.
        raise ValueError("model returned an empty response")

    text = content.strip()
    if text.startswith("```"):
        # The model may wrap its answer in a Markdown code fence
        # (```json ... ```) despite the prompt; drop the opening fence line
        # and the closing fence before parsing.
        newline_at = text.find("\n")
        text = text[newline_at + 1:] if newline_at != -1 else text[3:]
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]

    return json.loads(text)
# -------- FIND ALL TARGET VIDEOS --------
# Every TARGET_NAME recording found under ROOT, in sorted path order so that
# START_INDEX refers to the same video on every run.
video_files = sorted(ROOT.rglob(f"*/recordings/MP4/{TARGET_NAME}"))
total_videos = len(video_files)

# Refuse to continue when the requested start lies past the end of the list.
if total_videos <= START_INDEX:
    raise ValueError(f"START_INDEX {START_INDEX} 超出视频总数 {total_videos}")

# Keep only the requested window; a stop of None means "through the end".
video_files = video_files[
    START_INDEX : (None if MAX_VIDEOS is None else START_INDEX + MAX_VIDEOS)
]

print(f"Found {len(video_files)} videos to process (from index {START_INDEX}):")
for v in video_files:
    print(" -", v)
# -------- PROCESS LOOP --------
# Label every selected video and write one JSON line per time step.
os.makedirs(Path(OUTPUT_FILE).parent, exist_ok=True)

# NOTE: mode "w" truncates any existing file at OUTPUT_FILE.
with open(OUTPUT_FILE, "w", encoding="utf-8") as fout:
    for vid_path in tqdm(video_files, desc="Processing videos"):
        try:
            frames, total_frames, idxs = extract_frames(str(vid_path))
            result = call_model(frames)

            if not isinstance(result, list):
                # Iterating a dict would yield its keys and fail later with
                # a confusing TypeError; fail early with the real reason.
                raise ValueError(
                    f"expected a JSON list from the model, "
                    f"got {type(result).__name__}"
                )
            if len(result) != len(idxs):
                # Previously a longer result crashed on idxs[i] after some
                # rows were already written. Steps are now paired with
                # frames positionally and any surplus on either side is
                # dropped.
                print(
                    f"[WARN] {vid_path}: model returned {len(result)} steps "
                    f"for {len(idxs)} sampled frames"
                )

            # Build all rows for this video before writing, so that a bad
            # step cannot leave a half-labelled video in the output file.
            rows = []
            for i, (frame_index, step_data) in enumerate(zip(idxs, result)):
                entry = {
                    "video_path": str(vid_path),  # original video path
                    "video_id": vid_path.stem,
                    "t": i,
                    "frame_index": int(frame_index),  # index in the source video
                    "total_frames": int(total_frames),
                    # Unpacked last: keys supplied by the model (e.g. "t")
                    # override the values set above.
                    **step_data,
                }
                rows.append(json.dumps(entry) + "\n")

            fout.writelines(rows)
            # Flush per video so finished results survive a later crash.
            fout.flush()
        except Exception as e:
            # Best-effort batch: report the failure and move on to the next
            # video instead of aborting the whole run.
            print(f"[ERROR] {vid_path}: {e}")