File size: 2,280 Bytes
62cbd63 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import json, os
# Path to the formatted SFT training set (one JSON object per line).
# It can be overridden with the SFT_DATA_PATH environment variable, so the same
# checks can run against other splits without editing this file.
data_path = os.environ.get(
    "SFT_DATA_PATH",
    "/workspace/rl4phyx/RL4Phyx/SFT/sft_train/sft_train_formatted.jsonl",
)
# Read as UTF-8 explicitly, because the platform default encoding may differ.
# Skip whitespace-only lines (e.g. a trailing empty line), which json.loads
# would reject.
with open(data_path, encoding="utf-8") as f:
    records = [json.loads(line) for line in f if line.strip()]
print(f"Total samples: {len(records)}")

# Format check: validate each record's chat structure, image references and
# assistant answer, and tally simple chain-of-thought (CoT) indicators.


def _answer_text(message):
    """Return the ``"text"`` of *message*'s first content part, or None.

    Only list-of-parts content is accepted. Any other shape (plain string,
    empty list, non-dict first part, missing or non-string text) yields None.
    """
    content = message.get("content")
    if not isinstance(content, list) or not content:
        return None
    first = content[0]
    text = first.get("text") if isinstance(first, dict) else None
    return text if isinstance(text, str) else None


format_ok = 0
format_bad = 0
missing_img = 0   # image parts whose file does not exist on disk
no_img = 0        # well-formed samples with no existing image at all
no_cot = 0        # well-formed samples without any step/boxed/think marker
short_answer = 0
answer_lens = []
has_think = 0
has_step = 0
for r in records:
    # Check structure: a user turn followed by an assistant turn.
    messages = r.get("messages") if isinstance(r, dict) else None
    if not isinstance(messages, list) or len(messages) < 2:
        format_bad += 1
        continue
    user = messages[0]
    asst = messages[1]
    if not isinstance(user, dict) or not isinstance(asst, dict):
        format_bad += 1
        continue
    if user.get("role") != "user" or asst.get("role") != "assistant":
        format_bad += 1
        continue
    # A record without answer text is malformed. Count it as such instead of
    # letting an IndexError/KeyError abort the whole scan.
    ans = _answer_text(asst)
    if ans is None:
        format_bad += 1
        continue
    # Check images: count every referenced file that is missing, and track
    # whether the sample has at least one image that exists.
    user_content = user.get("content", [])
    if not isinstance(user_content, list):
        user_content = []
    img_found = False
    for c in user_content:
        if isinstance(c, dict) and c.get("type") == "image":
            img_path = str(c.get("image") or "")
            # Strip only a leading URI scheme, not occurrences elsewhere.
            if img_path.startswith("file://"):
                img_path = img_path[len("file://"):]
            if os.path.isfile(img_path):
                img_found = True
            else:
                missing_img += 1
    if not img_found:
        no_img += 1
    # Check answer length.
    answer_lens.append(len(ans))
    if len(ans) < 50:
        short_answer += 1
    # Check CoT indicators. This is a heuristic: "step" also matches words
    # such as "steps".
    ans_lower = ans.lower()
    has_think_tag = "<think>" in ans_lower
    if has_think_tag or "\\boxed" in ans or "step" in ans_lower:
        has_step += 1
    else:
        no_cot += 1
    if has_think_tag:
        has_think += 1
    format_ok += 1
print(f"\nFormat OK: {format_ok}/{len(records)}")
print(f"Format bad: {format_bad}")
print(f"Missing images: {missing_img}")
print(f"Samples without a valid image: {no_img}")
print(f"Short answers (<50 chars): {short_answer}")
print(f"Has step/boxed/think: {has_step}/{len(records)}")
print(f"No step/boxed/think: {no_cot}")
print(f"Has <think> tag: {has_think}/{len(records)}")

import statistics

# Summary statistics over the lengths (in characters) of the collected
# answers. Guard the empty case: min()/max() raise ValueError, and
# statistics.mean()/median() raise StatisticsError, on an empty sequence.
print("\nAnswer length stats:")
if answer_lens:
    print(f" Min: {min(answer_lens)}")
    print(f" Max: {max(answer_lens)}")
    print(f" Mean: {int(statistics.mean(answer_lens))}")
    print(f" Median: {int(statistics.median(answer_lens))}")
else:
    print(" (no well-formed answers)")

# Show up to 3 random samples for a manual spot check. The fixed seed keeps
# the selection reproducible between runs.
import random


def _sample_texts(record):
    """Return ``(question, answer)`` strings for *record*, or None.

    The question is the first ``"text"`` part of the first message. The
    answer is the ``"text"`` of the second message's first content part.
    Records without this shape yield None instead of raising.
    """
    try:
        messages = record["messages"]
        ans = messages[1]["content"][0]["text"]
        q = [c["text"] for c in messages[0]["content"] if c["type"] == "text"][0]
    except (KeyError, IndexError, TypeError):
        return None
    if not isinstance(q, str) or not isinstance(ans, str):
        return None
    return q, ans


random.seed(42)
# Draw only from records that can be displayed, and never request more
# samples than exist (random.sample raises ValueError otherwise). When every
# record is displayable, candidates equals range(len(records)) element for
# element, so the seeded selection matches sampling the index range directly.
candidates = [i for i, r in enumerate(records) if _sample_texts(r) is not None]
samples = random.sample(candidates, min(3, len(candidates)))
for idx in samples:
    q, ans = _sample_texts(records[idx])
    print(f"\n{'='*50}")
    print(f"Sample {idx}")
    print(f"Q: {q[:120]}...")
    print(f"A ({len(ans)} chars): {ans[:300]}...")
|