|
|
|
|
|
""" |
|
|
Inspect the format of the ``function_calls`` field in the Dolci tool-use SFT dataset (``allenai/Dolci-Instruct-SFT-Tool-Use``).
|
|
""" |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
def main(dataset=None, tail_offset=20000, scan_limit=1000):
    """Print how ``function_calls`` is stored in Dolci tool-use samples.

    Two passes are made over the data:

    1. For the sample ``tail_offset`` rows from the end, print the role, the
       Python type and the content of the first message whose
       ``function_calls`` value is truthy.
    2. Scan up to ``scan_limit`` samples starting at that row and print the
       first ``function_calls`` string that begins with ``[`` (heuristic for
       a serialized list of several calls), truncated to 500 characters.

    Args:
        dataset: Optional pre-loaded, indexable collection of samples, each a
            mapping with a ``messages`` list of message dicts. When None, the
            ``train`` split of ``allenai/Dolci-Instruct-SFT-Tool-Use`` is
            loaded with ``load_dataset``.
        tail_offset: Number of rows from the end at which to start. The start
            index is clamped into ``[0, len(dataset) - 1]``.
        scan_limit: Maximum number of samples examined in the second pass.
    """
    preview_chars = 500

    if dataset is None:
        print("Loading dataset...")
        dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")

    total = len(dataset)
    if total == 0:
        print("Dataset is empty; nothing to inspect.")
        return

    # Clamp so datasets shorter than ``tail_offset`` start at row 0 instead of
    # using a negative (wrapping or out-of-range) index.
    start_idx = min(total - 1, max(0, total - tail_offset))
    # Never scan past the last row.
    end_idx = min(total, start_idx + scan_limit)

    separator = '=' * 60

    print(f"\n{separator}")
    print(f"Sample {start_idx}:")
    print(separator)

    for i, msg in enumerate(dataset[start_idx]['messages']):
        role = msg.get('role', '')
        function_calls = msg.get('function_calls')
        if function_calls:
            print(f"\nMessage {i} (role={role}):")
            print(f"function_calls type: {type(function_calls)}")
            print(f"function_calls content:\n{function_calls}")
            break
    else:
        # Loop finished without ``break``: report it rather than printing nothing.
        print(f"\nNo message with function_calls in sample {start_idx}.")

    print(f"\n{separator}")
    print("Looking for sample with multiple function calls...")
    print(separator)

    for idx in range(start_idx, end_idx):
        for msg in dataset[idx]['messages']:
            function_calls = msg.get('function_calls')
            # The field's type is not guaranteed (pass 1 exists to discover it),
            # so only prefix-check real strings; others would raise AttributeError.
            if isinstance(function_calls, str) and function_calls.lstrip().startswith('['):
                print(f"\nFound at index {idx}:")
                # Only mark truncation when something was actually cut off.
                ellipsis = "..." if len(function_calls) > preview_chars else ""
                print(f"function_calls:\n{function_calls[:preview_chars]}{ellipsis}")
                return

    print(f"\nNo function_calls list found in samples [{start_idx}, {end_idx}).")
|
|
|
|
|
if __name__ == "__main__":
    # Only run the inspection when executed as a script, not on import.
    main()
|
|
|