File size: 2,280 Bytes
62cbd63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

import json, os

data_path = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train/sft_train_formatted.jsonl"
# Decode explicitly as UTF-8 (JSON text is UTF-8; the default encoding is
# platform/locale dependent) and skip blank lines, which json.loads rejects
# (e.g. a stray empty line at the end of the file).
with open(data_path, encoding="utf-8") as f:
    records = [json.loads(line) for line in f if line.strip()]

print(f"Total samples: {len(records)}")

# Format check: validate each record's chat structure, its assistant answer
# and its image references, then print a summary of the counters.
import statistics

# Counters are per record unless noted otherwise.
format_ok = 0      # records that passed every structural check
format_bad = 0     # records with bad/missing messages, roles or answer text
missing_img = 0    # per image entry: local path is not an existing file
no_img = 0         # well-formed records with no existing image file at all
short_answer = 0   # answers shorter than 50 characters
answer_lens = []   # character length of each well-formed answer
has_think = 0      # answers containing a "<think>" tag (case-insensitive)
has_step = 0       # answers containing "<think>", "\boxed" or "step"

for r in records:
    # Check structure: the first two messages must be user, then assistant.
    if "messages" not in r or len(r["messages"]) < 2:
        format_bad += 1
        continue
    user = r["messages"][0]
    asst = r["messages"][1]
    if user.get("role") != "user" or asst.get("role") != "assistant":
        format_bad += 1
        continue

    # Check answer: the text of the assistant's first content part. A missing
    # or non-string answer is counted as a format error instead of aborting
    # the whole validation run.
    try:
        ans = asst["content"][0]["text"]
    except (KeyError, IndexError, TypeError):
        ans = None
    if not isinstance(ans, str):
        format_bad += 1
        continue

    # Check images: strip only a leading "file://" scheme, then test on disk.
    img_found = False
    for c in user.get("content", []):
        if isinstance(c, dict) and c.get("type") == "image":
            img_path = c.get("image") or ""
            if img_path.startswith("file://"):
                img_path = img_path[len("file://"):]
            if os.path.isfile(img_path):
                img_found = True
            else:
                missing_img += 1
    if not img_found:
        no_img += 1

    answer_lens.append(len(ans))
    if len(ans) < 50:
        short_answer += 1

    # Check CoT indicators; the \boxed macro is matched case-sensitively.
    ans_lower = ans.lower()
    think = "<think>" in ans_lower
    if think:
        has_think += 1
    if think or "\\boxed" in ans or "step" in ans_lower:
        has_step += 1

    format_ok += 1

print(f"\nFormat OK: {format_ok}/{len(records)}")
print(f"Format bad: {format_bad}")
print(f"Missing images: {missing_img}")
print(f"Samples without an existing image: {no_img}")
print(f"Short answers (<50 chars): {short_answer}")
# CoT indicators are only inspected on well-formed records, so report them
# relative to format_ok rather than the total record count.
print(f"Has step/boxed/think: {has_step}/{format_ok}")
print(f"Has <think> tag: {has_think}/{format_ok}")

print("\nAnswer length stats:")
if answer_lens:
    print(f"  Min: {min(answer_lens)}")
    print(f"  Max: {max(answer_lens)}")
    print(f"  Mean: {int(statistics.mean(answer_lens))}")
    print(f"  Median: {int(statistics.median(answer_lens))}")
else:
    # min()/max()/statistics.mean() raise on an empty sequence.
    print("  (no well-formed answers)")

# Show up to 3 random samples (fixed seed, so the same records are shown on
# every run) as a quick eyeball check of question/answer content.
import random

random.seed(42)
# random.sample raises ValueError if asked for more items than exist.
samples = random.sample(range(len(records)), min(3, len(records)))


def _clip(text, limit):
    """Return *text* cut to *limit* characters, appending '...' only if cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def _sample_texts(record):
    """Return (question, answer) strings for *record*, or None if malformed.

    The answer is the text of the assistant message's first content part; the
    question is the first user content part whose type is "text" ("" if none).
    """
    try:
        ans = record["messages"][1]["content"][0]["text"]
        q = next(
            (c["text"] for c in record["messages"][0]["content"]
             if isinstance(c, dict) and c.get("type") == "text"),
            "",
        )
    except (KeyError, IndexError, TypeError):
        return None
    if not isinstance(ans, str) or not isinstance(q, str):
        return None
    return q, ans


for idx in samples:
    print(f"\n{'='*50}")
    texts = _sample_texts(records[idx])
    if texts is None:
        print(f"Sample {idx}: malformed record, skipped")
        continue
    q, ans = texts
    print(f"Sample {idx}")
    print(f"Q: {_clip(q, 120)}")
    print(f"A ({len(ans)} chars): {_clip(ans, 300)}")