# AMA-bench-Leaderboard / view_samples.py
# NorahYujieZhao
# the new version
# d8b2e03
#!/usr/bin/env python3
"""
View sample records from the processed JSONL file.
"""
import json
import sys
from pathlib import Path
def _truncate(value, limit):
    """Return ``str(value)`` cut to ``limit`` characters, appending "..." only if text was cut."""
    text = str(value)
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


def _print_qa_pair(index, qa):
    """Print one QA pair: its type (plus subtype if present) and truncated Q/A text."""
    header = f"\n[{index}] Type: {qa['type']}"
    if 'sub_type' in qa:
        header += f" / Subtype: {qa['sub_type']}"
    print(header)
    print(f"Q: {_truncate(qa['question'], 120)}")
    print(f"A: {_truncate(qa['answer'], 120)}")


def print_record(data, show_full=False):
    """
    Print a single record in a readable format.

    Args:
        data: One parsed JSONL record. The keys ``episode_id``, ``task_type``,
            ``domain``, ``success``, ``num_turns`` and ``total_tokens`` are
            required. ``task``, ``qa_pairs`` and ``trajectory`` are optional
            and are skipped (or counted as empty) when missing or falsy.
            Each QA pair needs ``type``, ``question`` and ``answer`` and may
            have ``sub_type``. Each trajectory turn needs ``turn_idx``,
            ``action`` and ``observation``.
        show_full: When True, print every QA pair and the first 3 trajectory
            turns. Otherwise print only the first 2 QA pairs and the turn count.

    Long text is truncated (task: 150 chars, Q/A: 120, action/observation: 100).
    A trailing "..." marks text that was actually cut.
    """
    print("=" * 80)
    print(f"Episode ID: {data['episode_id']}")
    print(f"Task Type: {data['task_type']}")
    print(f"Domain: {data['domain']}")
    print(f"Success: {data['success']}")
    print(f"Turns: {data['num_turns']}")
    print(f"Tokens: {data['total_tokens']}")

    task = data.get('task')
    if task:
        print(f"\nTask:\n{_truncate(task, 150)}")

    # `or []` also covers an explicit null in the JSON, not just a missing key.
    qa_pairs = data.get('qa_pairs') or []
    print(f"\nQA Pairs: {len(qa_pairs)}")
    if show_full:
        print("\nAll QA Pairs:")
        shown_pairs = qa_pairs
    else:
        # Show first 2 QA pairs as preview
        print("\nSample QA Pairs (first 2):")
        shown_pairs = qa_pairs[:2]
    print("-" * 80)
    for i, qa in enumerate(shown_pairs, 1):
        _print_qa_pair(i, qa)

    trajectory = data.get('trajectory') or []
    if trajectory:
        print(f"\nTrajectory: {len(trajectory)} turns")
        if show_full:
            print("\nFirst 3 turns:")
            print("-" * 80)
            for turn in trajectory[:3]:
                print(f"\nTurn {turn['turn_idx']}:")
                action = _truncate(turn['action'], 100) if turn['action'] else "None"
                observation = _truncate(turn['observation'], 100) if turn['observation'] else "None"
                print(f" Action: {action}")
                print(f" Observation: {observation}")

    print("=" * 80)
    print()
def view_by_task_type(file_path: Path, task_type: str, count: int = 3):
    """
    View samples of a specific task type.

    Prints up to ``count`` records whose ``task_type`` equals ``task_type``,
    in file order, using the preview (non-full) layout. The file is read
    line by line and the scan stops as soon as enough matches are printed.
    Blank lines are skipped. Records without a ``task_type`` never match.

    Args:
        file_path: Path to the JSONL file (one JSON object per line).
        task_type: Exact task-type string to match.
        count: Maximum number of records to show; ``count <= 0`` shows none.
    """
    print(f"\nShowing {count} samples for task type: {task_type}\n")
    if count <= 0:
        # Without this guard the first match would be printed before the
        # `shown >= count` check below had a chance to stop the loop.
        return
    shown = 0
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Tolerate blank/trailing lines instead of failing in json.loads.
                continue
            data = json.loads(line)
            if data.get('task_type') != task_type:
                continue
            print_record(data, show_full=False)
            shown += 1
            if shown >= count:
                break
    if shown == 0:
        print(f"No records found for task type: {task_type}")
def view_by_index(file_path: Path, index: int):
    """
    View a specific record by index (0-based).

    The index counts records, i.e. non-blank lines of the JSONL file. The
    matching record is printed with ``show_full=True``. A negative index is
    rejected up front instead of scanning the whole file for a line that
    can never match.

    Args:
        file_path: Path to the JSONL file (one JSON object per line).
        index: 0-based record position.
    """
    if index < 0:
        print(f"Index {index} is invalid (must be a non-negative integer)")
        return
    with open(file_path, 'r', encoding='utf-8') as f:
        # Skip blank lines so they neither shift indices nor crash json.loads.
        records = (line for line in f if line.strip())
        for i, line in enumerate(records):
            if i == index:
                data = json.loads(line)
                print_record(data, show_full=True)
                return
    print(f"Index {index} not found (file has fewer records)")
def list_task_types(file_path: Path):
    """
    List all unique task types in the file.

    Prints the distinct ``task_type`` values in sorted order as a numbered
    list. Blank lines and records without a ``task_type`` key are ignored.

    Args:
        file_path: Path to the JSONL file (one JSON object per line).
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        task_types = {
            json.loads(line).get('task_type')
            for line in f
            if line.strip()
        }
    # Records lacking the key contribute None; drop it so sorted() never has
    # to compare None with strings.
    task_types.discard(None)
    print("\nAvailable task types:")
    print("-" * 80)
    for i, task_type in enumerate(sorted(task_types), 1):
        print(f" {i:2d}. {task_type}")
    print()
def _exit_with_error(*lines):
    """Print each message line to stderr, then terminate with exit status 1."""
    for line in lines:
        print(line, file=sys.stderr)
    sys.exit(1)


def main():
    """
    Command-line entry point for browsing ``processed_open_end.jsonl``.

    The JSONL file is expected in the same directory as this script.
    Sub-commands:

    * ``list`` - print all distinct task types.
    * ``index <n>`` - print the record at 0-based index ``n`` in full.
    * ``type <task_type> [n]`` - print ``n`` (default 3) records of a task type.

    With no arguments, usage is printed and the function returns normally.
    Every error (missing data file, missing or non-integer argument, unknown
    command) is reported on stderr and exits with status 1.
    """
    jsonl_file = Path(__file__).parent / "processed_open_end.jsonl"
    if not jsonl_file.exists():
        _exit_with_error(
            f"Error: {jsonl_file} not found!",
            "Please run process_open_end.py first.",
        )
    # Command line interface
    if len(sys.argv) < 2:
        print("Usage:")
        print(" python3 view_samples.py list # List all task types")
        print(" python3 view_samples.py index <n> # View record at index n")
        print(" python3 view_samples.py type <task_type> [n] # View n samples of task type (default 3)")
        print("\nExamples:")
        print(" python3 view_samples.py list")
        print(" python3 view_samples.py index 0")
        print(" python3 view_samples.py type text2sql/spider2 5")
        return
    command = sys.argv[1]
    if command == "list":
        list_task_types(jsonl_file)
    elif command == "index":
        if len(sys.argv) < 3:
            _exit_with_error("Error: Please specify an index")
        try:
            index = int(sys.argv[2])
        except ValueError:
            _exit_with_error("Error: Index must be an integer")
        # Deliberately outside the try: json.JSONDecodeError subclasses
        # ValueError and must not be misreported as a bad index argument.
        view_by_index(jsonl_file, index)
    elif command == "type":
        if len(sys.argv) < 3:
            _exit_with_error("Error: Please specify a task type")
        task_type = sys.argv[2]
        count = 3
        if len(sys.argv) >= 4:
            try:
                count = int(sys.argv[3])
            except ValueError:
                _exit_with_error("Error: Count must be an integer")
        view_by_task_type(jsonl_file, task_type, count)
    else:
        _exit_with_error(f"Unknown command: {command}", "Use: list, index, or type")
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()