Spaces:
Running on Zero
Running on Zero
| """ | |
| Check that answer messages correctly reuse tool responses and don't duplicate thinking. | |
| This script validates two key properties of the distilled dataset: | |
| 1. Answer Reuse: Final answers MUST reference the immediately preceding tool response, | |
| not be manually recalculated | |
| 2. Unique Thinking: Final answer messages should NOT duplicate the thinking process | |
| from the tool-calling message (indicates distillation issues) | |
| Examples flagged: | |
| - Answer doesn't match tool response value | |
| - Previous message is NOT a tool response | |
| - Identical <think> blocks in tool-calling and final answer messages | |
| """ | |
| import copy | |
| import json | |
| import re | |
| import shutil | |
| import sys | |
| from io import StringIO | |
| from pathlib import Path | |
| from typing import Any | |
| import yaml | |
| from datasets import Dataset, DownloadMode, load_dataset, load_from_disk | |
| from transformers import AutoTokenizer | |
| from linalg_zero.grpo.verifiers.xml_parser import XMLParser | |
| from linalg_zero.grpo.verify import parse_string, verify_answers | |
# In-memory mirror of everything printed through print_to_both(); its contents
# are written to a report file at the end of the script.
output_buffer = StringIO()
# Directory holding the on-disk copy of the distilled dataset.
local_dataset_path = "atomwalk12/linalgzero-distilled-local"
# First run only: fetch the train split and store it under local_dataset_path.
# NOTE(review): this executes at import time, so importing this module may
# trigger a download.
if not Path(local_dataset_path).exists():
    dataset = load_dataset("atomwalk12/linalgzero-distilled", "default", split="train")
    dataset.save_to_disk(local_dataset_path)
def parse_messages(messages_field: Any) -> list[dict]:
    """
    Normalise the dataset's ``messages`` field into decoded message dicts.

    Args:
        messages_field: Either a JSON-encoded string or an already-decoded value

    Returns:
        The decoded JSON when a string is given, otherwise the input unchanged
    """
    if isinstance(messages_field, str):
        return json.loads(messages_field)
    return messages_field
def get_message_content(message: dict) -> str:
    """
    Return the whitespace-stripped ``content`` of a message.

    Args:
        message: Message dictionary; ``content`` may be missing or falsy

    Returns:
        The stripped content, or an empty string when there is no content
    """
    raw = message.get("content")
    if not raw:
        return ""
    return raw.strip()
def load_tokenizer(tokenizer_name: str) -> AutoTokenizer:
    """
    Instantiate the tokenizer registered under ``tokenizer_name``.

    Args:
        tokenizer_name: Identifier handed to ``AutoTokenizer.from_pretrained``

    Returns:
        Whatever ``AutoTokenizer.from_pretrained`` returns for that identifier
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return tokenizer
def count_tokens_in_message(content: str, tokenizer: AutoTokenizer) -> int:
    """Return the length of the sequence ``tokenizer.encode`` produces for ``content``."""
    token_ids = tokenizer.encode(content)
    return len(token_ids)
def get_max_assistant_tokens(messages: list[dict], tokenizer: AutoTokenizer) -> int:
    """
    Find the largest token count among the assistant messages.

    Args:
        messages: List of message dictionaries
        tokenizer: Tokenizer used to count tokens

    Returns:
        The highest per-message token count over messages whose role is
        "assistant", or 0 when there is no such message
    """
    per_message_counts = (
        count_tokens_in_message(get_message_content(message), tokenizer)
        for message in messages
        if message.get("role") == "assistant"
    )
    return max(per_message_counts, default=0)
def print_to_both(*args, **kwargs):
    """
    Print to the console and mirror the same text into ``output_buffer``.

    Args:
        *args: Positional arguments forwarded to ``print``
        **kwargs: Keyword arguments forwarded to ``print``. A caller-supplied
            ``file`` only affects the first (console) print; the mirrored copy
            always goes to ``output_buffer``.
    """
    print(*args, **kwargs)
    if output_buffer is not None:
        # Override rather than add ``file`` so a caller-supplied ``file=`` does not
        # raise "got multiple values for keyword argument 'file'".
        print(*args, **{**kwargs, "file": output_buffer})
def extract_answer_value(content: str, parser: XMLParser) -> str | None:
    """Return the last ``<answer>`` value the parser extracts from ``content``, or None if it finds none."""
    found = parser.extract_tag_contents(content, "answer")
    if not found:
        return None
    # When several answers are present, the final one wins.
    return found[-1]
def check_tool_call_is_skipped_due_to_simple_op(
    messages: list[dict], parser: XMLParser, minimal_dataset: bool = False
) -> tuple[bool, str]:
    """
    Compare the final answer with the tool response that immediately precedes it.

    Only the last two messages are inspected. When the second-to-last message is
    not a tool response, nothing is compared and no issue is reported. A mismatch
    after a "determinant" tool call is tolerated when the answer is 1, 2 or 3 and
    ``minimal_dataset`` is False; every other mismatch is reported.

    Returns:
        (has_issue, reason): bool indicating if there's an issue, and string with reason
    """
    final_msg = messages[-1]
    previous_msg = messages[-2]

    if previous_msg["role"] != "tool":
        return False, "No tool response before final answer"

    answer = parse_string(parser.extract_last_answer(final_msg["content"]))
    tool_response = parse_string(previous_msg["content"])

    if verify_answers(tool_response, answer):
        return False, "Answer matches tool response"

    tool_name = previous_msg["name"]
    if tool_name != "determinant":
        return (
            True,
            f"Answer does not match tool response: {answer} (final answer) != {tool_response} (final tool response -- {tool_name})\n",
        )

    if answer in [1, 2, 3] and not minimal_dataset:
        return False, ""

    return (
        True,
        "Interesting case: using determinant as the last tool call, but final answer is not a possible rank",
    )
def check_repeated_tool_calls(messages: list[dict], parser: XMLParser) -> tuple[bool, list[dict]]:
    """
    Detect tool functions that are invoked more than once in a conversation.

    ``parser`` is accepted for signature parity with the other checks and is not used.

    Returns:
        (has_repeated_calls, repeated_calls_info):
        - bool indicating if there are repeated calls
        - List of dicts with tool_name, positions, and count
    """
    positions_by_tool: dict[str, list[int]] = {}

    for position, message in enumerate(messages):
        if message["role"] != "assistant":
            continue
        calls = message.get("tool_calls")
        if not calls:
            continue
        assert len(calls) == 1, "Expected only one tool call per assistant message"
        positions_by_tool.setdefault(calls[0]["function"]["name"], []).append(position)

    # Keep only the tools that show up at two or more message positions
    repeated = [
        {"tool_name": name, "positions": where, "count": len(where)}
        for name, where in positions_by_tool.items()
        if len(where) > 1
    ]
    return (bool(repeated), repeated)
def check_all_think_duplicates(messages: list[dict], parser: XMLParser) -> tuple[list[dict], list[dict]]:
    """
    Find ALL duplicate <think> blocks across all assistant message pairs (including non-adjacent).

    Only the first <think> block of each assistant message is compared, after
    stripping whitespace. The earliest occurrence of a given text is kept; in every
    later message with the same text the <think>...</think> spans are replaced by a
    fixed placeholder.

    Args:
        messages: Conversation messages; the list and its dicts are not mutated
        parser: Object providing ``extract_tag_contents(text, tag)``

    Returns:
        (duplicates, modified_messages):
        - One dict per matching pair with "positions" (earlier, later),
          "adjacent" (True when the positions are exactly two apart, i.e.
          assistant -> tool -> assistant) and "content_preview" (first 100 chars)
        - Copy of the messages with the replacements applied
    """
    duplicates = []
    # Shallow per-message copies are enough: only the top-level "content" key is rebound.
    modified_messages = [msg.copy() for msg in messages]

    # Collect (position, stripped first <think> text) for each assistant message.
    thinks = []
    for position, msg in enumerate(messages):
        if msg["role"] != "assistant":
            continue
        # Assistant messages may carry no content at all; treat that as "no think block".
        think_content = parser.extract_tag_contents(msg.get("content") or "", "think")
        if think_content and think_content[0].strip():
            thinks.append((position, think_content[0].strip()))

    # Compare all pairs (including non-adjacent). The later member of every matching
    # pair is a repeat; a set keeps each position from being rewritten more than once.
    positions_to_replace = set()
    for i, (pos_i, content_i) in enumerate(thinks):
        for pos_j, content_j in thinks[i + 1 :]:
            if content_i != content_j:
                continue
            duplicates.append({
                "positions": (pos_i, pos_j),
                "adjacent": (pos_j - pos_i == 2),  # assistant → tool → assistant
                "content_preview": content_i[:100],
            })
            positions_to_replace.add(pos_j)

    # Replace duplicate think blocks with placeholder
    for pos in sorted(positions_to_replace):
        modified_messages[pos]["content"] = re.sub(
            r"<think>.*?</think>",
            "<think>Finalise based on tool result</think>",
            modified_messages[pos]["content"],
            flags=re.DOTALL,
        )

    return duplicates, modified_messages
def check_exact_sequence(
    messages: list[dict], parser: XMLParser, stepwise_ground_truths: list[dict]
) -> tuple[bool, list[dict]]:
    """
    Check if the sequence of tool calls and responses matches the stepwise ground truths.

    Step ``k`` is expected to be the tool call at message index ``2 + 2k`` followed
    by its tool response at the next index. At most one issue is recorded per step.

    Args:
        messages: Conversation messages
        parser: Unused; kept for signature parity with the other checks
        stepwise_ground_truths: One dict per step mapping acceptable tool names to
            the value expected from that tool

    Returns:
        (has_issue, issues_info): bool indicating if there are mismatches,
        and list of dicts with detailed error information
    """
    issues = []

    for step_idx, ground_truth in enumerate(stepwise_ground_truths):
        msg_idx = 2 + (step_idx * 2)  # Tool call messages are at indices 2, 4, 6, ...

        if msg_idx >= len(messages):
            issues.append({
                "step": step_idx,
                "error": "Missing tool call",
                "expected_tools": list(ground_truth.keys()),
                "actual_tool": None,
                "message_index": msg_idx,
            })
            continue

        # A missing key, None and an empty list all mean "no tool call here";
        # this matches the truthiness test used by check_repeated_tool_calls.
        tool_calls = messages[msg_idx].get("tool_calls")
        if not tool_calls:
            issues.append({
                "step": step_idx,
                "error": "Expected tool call but message has no tool_calls",
                "expected_tools": list(ground_truth.keys()),
                "actual_tool": None,
                "message_index": msg_idx,
            })
            continue

        tool_name = tool_calls[0]["function"]["name"]

        # Any tool listed in the step's ground truth is acceptable
        if tool_name not in ground_truth:
            issues.append({
                "step": step_idx,
                "error": "Tool name mismatch",
                "expected_tools": list(ground_truth.keys()),
                "actual_tool": tool_name,
                "message_index": msg_idx,
            })
            continue

        expected_answer = ground_truth[tool_name]
        tool_response_idx = msg_idx + 1
        if tool_response_idx >= len(messages):
            issues.append({
                "step": step_idx,
                "error": "Missing tool response",
                "tool_name": tool_name,
                "expected_value": expected_answer,
                "actual_value": None,
                "message_index": tool_response_idx,
            })
            continue

        tool_answer = parse_string(messages[tool_response_idx]["content"])
        if not verify_answers(expected_answer, tool_answer):
            issues.append({
                "step": step_idx,
                "error": "Tool response value mismatch",
                "tool_name": tool_name,
                "expected_value": expected_answer,
                "actual_value": tool_answer,
                "message_index": tool_response_idx,
            })

    return (len(issues) > 0, issues)
def matches_exact_message_count(messages: list[dict], ground_truths: list[dict]) -> tuple[bool, str]:
    """
    Tell whether the conversation has exactly the expected number of messages.

    Returns:
        (matches, description): True when the count is as expected, plus a
        message stating the expected and the actual count
    """
    actual = len(messages)
    # system + user, one (tool call, tool response) pair per ground truth, one final answer
    expected = 3 + 2 * len(ground_truths)
    return actual == expected, f"Expected {expected} messages, got {actual}"
def print_messages_above_token_threshold(
    tokenizer: AutoTokenizer,
    ds: Dataset,
    token_threshold: int,
):
    """
    Print all non-system messages with token counts above the specified threshold.

    Messages are listed from largest to smallest token count; ties keep dataset order.

    Args:
        tokenizer: Tokenizer used to count tokens
        ds: Dataset whose rows carry a ``messages`` field
        token_threshold: Only messages with strictly more tokens are printed
    """
    # Entries are (token_count, sample_idx, msg_idx, role, content)
    high_token_messages = []
    for sample_idx, row in enumerate(ds):
        for msg_idx, msg in enumerate(parse_messages(row["messages"])):
            role = msg.get("role", "unknown")
            if role == "system":
                continue
            content = get_message_content(msg)
            token_count = count_tokens_in_message(content, tokenizer)
            if token_count > token_threshold:
                high_token_messages.append((token_count, sample_idx, msg_idx, role, content))

    # Stable sort: equal counts stay in the order they were encountered
    sorted_messages = sorted(high_token_messages, key=lambda entry: entry[0], reverse=True)

    print_to_both(f"\n{'=' * 80}")
    print_to_both(f"MESSAGES WITH MORE THAN {token_threshold} TOKENS (system messages excluded)")
    print_to_both(f"{'=' * 80}\n")
    print_to_both(f"Found {len(sorted_messages)} messages above {token_threshold} tokens\n")

    for rank, (token_count, sample_idx, msg_idx, role, content) in enumerate(sorted_messages, 1):
        print_to_both(f"\n{'─' * 80}")
        print_to_both(
            f"Rank #{rank} | Tokens: {token_count} | Sample: {sample_idx} | Message: {msg_idx} | Role: {role}"
        )
        print_to_both(f"{'─' * 80}")
        print_to_both(content)

    print_to_both(f"\n{'=' * 80}")
    print_to_both(f"SUMMARY: {len(sorted_messages)} messages found with more than {token_threshold} tokens")
    print_to_both(f"{'=' * 80}")
def _percent(count: int, total: int) -> float:
    """Return ``count`` as a percentage of ``total``, or 0.0 when ``total`` is zero."""
    return count / total * 100 if total else 0.0


def analyze_dataset(  # noqa: C901
    dataset_name: str,
    config: str,
    split: str,
    _load_from_disk: bool = True,
    verbose: bool = True,
    minimal_dataset: bool = False,
    assistant_token_threshold: int | None = None,
    tokenizer_name: str = "Qwen/Qwen2.5-3B-Instruct",
    push_to_hub: bool = False,
):
    """
    Analyze dataset for answer reuse and duplicated thinking issues.

    Every example is run through six checks (answer vs. last tool response,
    repeated tool calls, ground-truth sequence, message count, assistant token
    threshold, duplicated <think> blocks). A report is printed, the flagged
    examples are dropped via ``remove_issues`` and the result is optionally pushed.

    Args:
        dataset_name: Path for ``load_from_disk`` or name for ``load_dataset``
        config: Dataset configuration name (only used when not loading from disk)
        split: Dataset split (only used when not loading from disk)
        _load_from_disk: Load with ``load_from_disk`` instead of ``load_dataset``
        verbose: Print whole conversations for duplicate findings and list the
            messages above the token threshold
        minimal_dataset: Record ground-truth sequence mismatches as issues and
            forward the flag to the answer/tool-response check
        assistant_token_threshold: Flag examples whose longest assistant message
            exceeds this many tokens; None disables the check
        tokenizer_name: Tokenizer loaded when the token check is enabled
        push_to_hub: Push the cleaned dataset to ``f"{dataset_name}-clean"``

    Returns:
        (reuse_issues, all_think_duplicates, all_issues): the answer-mismatch
        findings, the duplicate-think findings, and the set of all flagged indices
    """
    print_to_both(f"Loading dataset: {dataset_name}/{config} ({split})")
    if _load_from_disk:
        ds = load_from_disk(dataset_name)
    else:
        ds = load_dataset(dataset_name, config, split=split, download_mode=DownloadMode.FORCE_REDOWNLOAD)

    total = len(ds)
    print_to_both(f"Dataset size: {total}")
    print_to_both("Checking for answer reuse and duplicated thinking issues...")
    print_to_both()

    parser = XMLParser()

    # Load tokenizer if token threshold is specified
    tokenizer = None
    if assistant_token_threshold is not None:
        print_to_both(f"Loading tokenizer for token counting: {tokenizer_name}")
        tokenizer = load_tokenizer(tokenizer_name)

    # Track issues
    reuse_issues = []
    all_think_duplicates = []
    exact_sequence_issues = []
    repeated_tool_calls = []
    message_count_issues = []
    token_threshold_issues = []

    for idx in range(total):
        row = ds[idx]  # read the row once instead of once per field
        messages = parse_messages(row["messages"])
        stepwise_ground_truths = json.loads(row["stepwise_ground_truths"])

        # Check 1: Answer must reference previous tool response
        has_reuse_issue, reuse_reason = check_tool_call_is_skipped_due_to_simple_op(
            messages, parser, minimal_dataset=minimal_dataset
        )
        if has_reuse_issue:
            reuse_issues.append((idx, reuse_reason))

        # Check 2: Repeated tool calls
        has_repeated, repeated_info = check_repeated_tool_calls(messages, parser)
        if has_repeated:
            repeated_tool_calls.append((idx, repeated_info))

        # Check 3: Exact sequence based on stepwise ground truths (only recorded for minimal datasets)
        has_sequence_issue, exact_sequence_info = check_exact_sequence(messages, parser, stepwise_ground_truths)
        if has_sequence_issue and minimal_dataset:
            exact_sequence_issues.append((idx, exact_sequence_info))

        # Check 4: Message count validation (the helper returns True on a match)
        count_matches, count_reason = matches_exact_message_count(messages, stepwise_ground_truths)
        if not count_matches:
            message_count_issues.append((idx, count_reason))

        # Check 5: Assistant token threshold
        if assistant_token_threshold is not None and tokenizer is not None:
            max_tokens = get_max_assistant_tokens(messages, tokenizer)
            if max_tokens > assistant_token_threshold:
                token_threshold_issues.append((idx, max_tokens))

        # Check 6: ALL duplicate thinking across conversation (including non-adjacent)
        duplicates, modified_messages = check_all_think_duplicates(messages, parser)
        if duplicates:
            # Print before/after for verification
            print_to_both(f"\n{'=' * 80}")
            print_to_both(f"Example {idx}: Found {len(duplicates)} duplicate(s)")
            print_to_both(f"{'=' * 80}")

            for dup in duplicates:
                pos_i, pos_j = dup["positions"]
                is_answer = pos_j == len(messages) - 1
                print_to_both(
                    f"\nDuplicate between positions {pos_i} and {pos_j} (adjacent: {dup['adjacent']}, last: {is_answer})"
                )

                if verbose:
                    analyse_indices("analysis_verbose", msgs=messages)
                else:
                    # Show the second occurrence (which will be replaced)
                    print_to_both(f"\n--- BEFORE (Message {pos_j}) ---")
                    print_to_both(messages[pos_j]["content"])
                    print_to_both(f"\n--- AFTER (Message {pos_j}) ---")
                    print_to_both(modified_messages[pos_j]["content"])

                all_think_duplicates.append((idx, dup, is_answer))

    # Report results
    print_to_both("=" * 80)
    print_to_both("ANALYSIS RESULTS")
    print_to_both("=" * 80)

    print_to_both(
        f"\n1. Answer-Tool Response Mismatch Issues: {len(reuse_issues)} ({_percent(len(reuse_issues), total):.2f}%)"
    )
    for idx, reason in reuse_issues:
        print_to_both(f"   - Example {idx}: {reason}")

    print_to_both(
        f"\n2. Repeated Tool Calls: {len(repeated_tool_calls)} ({_percent(len(repeated_tool_calls), total):.2f}%)"
    )
    for idx, repeated_info in repeated_tool_calls:
        for tool_info in repeated_info:
            print_to_both(
                f"   - Example {idx}: '{tool_info['tool_name']}' called {tool_info['count']} times at positions {tool_info['positions']}"
            )

    print_to_both(
        f"\n3. Exact Ground Truth Sequence Mismatches: {len(exact_sequence_issues)} ({_percent(len(exact_sequence_issues), total):.2f}%)"
    )
    for idx, issues_list in exact_sequence_issues:
        for issue in issues_list:
            print_to_both(
                f"   - Example {idx} step {issue['step']}: {issue['error']} "
                f"(expected: {issue.get('expected_tools') or issue.get('expected_value')}, "
                f"got: {issue.get('actual_tool') or issue.get('actual_value')})"
            )

    print_to_both(
        f"\n4. Message Count Mismatches: {len(message_count_issues)} ({_percent(len(message_count_issues), total):.2f}%)"
    )
    for idx, reason in message_count_issues:
        print_to_both(f"   - Example {idx}: {reason}")

    print_to_both(
        f"\n5. Assistant Token Threshold Exceeded: {len(token_threshold_issues)} ({_percent(len(token_threshold_issues), total):.2f}%)"
    )
    if token_threshold_issues:
        if assistant_token_threshold is not None:
            print_to_both(f"   Threshold: {assistant_token_threshold} tokens")
        for idx, max_tokens in token_threshold_issues:
            print_to_both(f"   - Example {idx}: max {max_tokens} tokens")
        # Issues can only exist when the check ran, so tokenizer and threshold are set here
        if verbose:
            print_messages_above_token_threshold(tokenizer, ds, assistant_token_threshold)

    # Comprehensive duplicate check
    adjacent_dups = [d for d in all_think_duplicates if d[1]["adjacent"]]
    non_adjacent_dups = [d for d in all_think_duplicates if not d[1]["adjacent"]]

    print_to_both("\n6. ALL Think Block Duplicates (comprehensive check):")
    print_to_both(f"   Total: {len(all_think_duplicates)}")
    print_to_both(f"   Adjacent (i → tool → j): {len(adjacent_dups)}")
    print_to_both(f"   Non-adjacent: {len(non_adjacent_dups)}")

    if adjacent_dups:
        print_to_both("\n   Adjacent duplicates:")
        for idx, dup, is_answer in adjacent_dups:
            print_to_both(
                f"     Example {idx}, msgs {dup['positions']}, is answer: {is_answer}: {dup['content_preview']}..."
            )

    if non_adjacent_dups:
        print_to_both("\n   Non-adjacent duplicates:")
        for idx, dup, is_answer in non_adjacent_dups:
            print_to_both(
                f"     Example {idx}, msgs {dup['positions']}, is answer: {is_answer}: {dup['content_preview']}..."
            )

    print_to_both(f"Number of adjacent duplicates with final answer: {len([d for d in adjacent_dups if d[2]])}")

    # Combined issues: an example is removed if any single check flagged it
    all_issues = (
        {idx for idx, _ in reuse_issues}
        | {idx for idx, _, _ in all_think_duplicates}
        | {idx for idx, _ in repeated_tool_calls}
        | {idx for idx, _ in exact_sequence_issues}
        | {idx for idx, _ in message_count_issues}
        | {idx for idx, _ in token_threshold_issues}
    )
    print_to_both(
        f"\n7. Total Unique Examples with Issues: {len(all_issues)} ({_percent(len(all_issues), total):.2f}%)"
    )

    print_to_both("\n" + "=" * 80)

    cleaned_ds = remove_issues(all_issues, ds)

    # Save or push cleaned dataset
    print_to_both("\n" + "=" * 80)
    print_to_both("Cleaned dataset ready.")
    print_to_both(f"Total examples removed: {len(all_issues)}")

    if push_to_hub:
        print_to_both(f"Pushing dataset to https://huggingface.co/datasets/{dataset_name}-clean")
        print_to_both(f"Dataset size: {len(cleaned_ds)}")
        cleaned_ds.push_to_hub(f"{dataset_name}-clean")

    print_to_both("=" * 80)

    return reuse_issues, all_think_duplicates, all_issues
def remove_issues(all_issues, ds):
    """
    Drop the examples whose index is listed in ``all_issues``.

    Returns ``ds`` itself when nothing is flagged, otherwise the result of
    ``ds.select`` over the remaining indices. Sizes are reported via print_to_both.
    """
    if len(all_issues) == 0:
        print_to_both("✅ No problematic examples found! Dataset is clean.")
        return ds

    # Indices that survive the filter, in their original order
    surviving = [position for position in range(len(ds)) if position not in all_issues]
    cleaned_ds = ds.select(surviving)

    print_to_both(f"\nInitial dataset size: {len(ds)}")
    print_to_both(f"\nCleaned dataset size: {len(cleaned_ds)}")
    print_to_both(f"Removed: {len(ds) - len(cleaned_ds)} examples")
    return cleaned_ds
def analyse_indices(  # noqa: C901
    analysis_type: str, indices: list[int] | None = None, ds: Dataset | None = None, msgs: list[dict] | None = None
) -> Dataset | None:
    """
    Print conversations, either for selected dataset rows or for one message list.

    Args:
        analysis_type: Label shown in the header when printing dataset rows
        indices: Row indices to print (used together with ``ds``)
        ds: Dataset the indices refer to
        msgs: A single conversation to print when no indices/dataset are given

    Returns:
        ``ds`` when rows were printed from a dataset, None when ``msgs`` was printed

    Raises:
        ValueError: If neither (``indices`` and ``ds``) nor ``msgs`` is provided
    """

    def print_msgs(messages: list[dict]):
        for i, msg in enumerate(messages):
            role = msg["role"]
            print_to_both(f"\n--- Message {i} ({role}) ---")

            if role == "user":
                print_to_both(msg["content"])
            elif role == "assistant":
                # Content is printed in full
                print_to_both(msg["content"])

                # Print tool calls if present
                if msg.get("tool_calls"):
                    print_to_both("\nTOOL CALLS:")
                    for tc in msg["tool_calls"]:
                        print_to_both(f"  - {tc['function']['name']}: {tc['function']['arguments']}")
            elif role == "tool":
                print_to_both(f"RESULT: {msg['content']}")
            # Other roles (e.g. system) only get the header line
        print_to_both("\n")

    if indices is not None and ds is not None:
        print_to_both("=" * 80)
        print_to_both(f"Analysing indices: {indices} for analysis type: {analysis_type}")
        print_to_both("=" * 80)
        for idx in indices:
            print_to_both("=" * 80)
            print_to_both(f"EXAMPLE {idx}")
            print_to_both("=" * 80)

            # Shared helper handles both JSON-string and already-decoded message fields
            print_msgs(parse_messages(ds[idx]["messages"]))
        return ds

    if msgs is not None:
        print_msgs(msgs)
        return None

    raise ValueError("Either indices and ds or msgs must be provided")
def clean_dataset(dataset, settings):
    """
    Apply the configured tag replacements and message removals to a dataset.

    Replacements run first and removals second, so the message indices used in
    ``settings`` always refer to the original, unmodified conversation.

    Args:
        dataset: Object exposing ``map(fn, with_indices=True)`` whose examples store
            ``messages`` as a JSON string
        settings: Mapping from example index to its operations: ``initial_msg_count``,
            ``expected_final_msg_count`` and optionally ``to_replace``
            ({message index: {"content": {tag: replacement text}}}) and/or
            ``to_remove`` (message indices) with ``remove_reason``

    Returns:
        The dataset produced by the two ``map`` passes

    Raises:
        AssertionError: If an example that no longer has its initial message count
            does not have the expected final count either, or if removing messages
            drops an unexpected number of them
    """
    print_to_both("=" * 80)
    print_to_both("Cleaning dataset...")
    print_to_both("=" * 80)

    def already_cleaned(idx, current_count):
        """Return True when the example no longer has its configured initial message count."""
        if settings[idx]["initial_msg_count"] == current_count:
            return False
        assert settings[idx]["expected_final_msg_count"] == current_count, (
            f"Index {idx}: Expected final message count does not match got message count"
        )
        return True

    def remove_messages(example, idx):
        if idx not in settings or "to_remove" not in settings[idx]:
            return example

        msgs = json.loads(example["messages"])
        current_count = len(msgs)
        if already_cleaned(idx, current_count):
            print_to_both(
                f"INFO: Did not modify index {idx} as it has already been cleaned. "
                f"Config initial message count: {settings[idx]['initial_msg_count']}. "
                f"Current message count matches expected message count: {current_count} == {settings[idx]['expected_final_msg_count']}."
            )
            return example

        to_remove = settings[idx]["to_remove"]
        clean_msgs = [msg for i, msg in enumerate(msgs) if i not in to_remove]

        # Validation: every configured index must have matched exactly one message
        assert len(clean_msgs) == len(msgs) - len(to_remove), (
            f"Index {idx}: Number of messages to remove does not match"
        )

        remove_reason = settings[idx]["remove_reason"]
        print_to_both(f"""NOTE: Message {idx}.""")
        print_to_both(f"""Removed messages with indices: {to_remove}""")
        print_to_both(f"Remove reason: {remove_reason}")

        updated_example = example.copy()
        updated_example["messages"] = json.dumps(clean_msgs)
        return updated_example

    def update_messages(example, idx):
        if idx not in settings or "to_replace" not in settings[idx]:
            return example

        messages = json.loads(example["messages"])
        if already_cleaned(idx, len(messages)):
            return example

        for index, replacement in settings[idx]["to_replace"].items():
            msg_idx = int(index)
            current_content = messages[msg_idx]["content"]

            for tag, new_content in replacement["content"].items():
                # The whole <tag>...</tag> span, tags included, is swapped for new_content.
                tag_pattern = rf"<{re.escape(tag)}>\s*.*?\s*</{re.escape(tag)}>"
                # A callable replacement inserts the text literally; a plain string would
                # be parsed as a template and mangle or reject backslashes (e.g. LaTeX).
                current_content = re.sub(
                    tag_pattern, lambda _match, text=new_content: text, current_content, flags=re.DOTALL
                )

            messages[msg_idx]["content"] = current_content

        example["messages"] = json.dumps(messages)
        return example

    # We update first, then remove redundant messages. This ensures the indices
    # remain consistent in the yaml config file.
    dataset = dataset.map(update_messages, with_indices=True)
    dataset = dataset.map(remove_messages, with_indices=True)

    return dataset
def check_integrity(dataset: Dataset, indices: list[int], cleaning_config: dict):
    """
    Assert that the given examples are well-formed after cleaning.

    For every index, no two consecutive messages may share a role and the number
    of messages must equal the configured ``expected_final_msg_count``.
    """
    for idx in indices:
        conversation = json.loads(dataset[idx]["messages"])

        roles_alternate = all(
            current["role"] != previous["role"]
            for previous, current in zip(conversation, conversation[1:])
        )
        assert roles_alternate, "Messages are not in the correct order"

        expected_count = cleaning_config[idx]["expected_final_msg_count"]
        assert len(conversation) == expected_count, (
            f"Index {idx}: Number of final messages does not match expected number"
        )
def load_cleaning_config(config_path: Path) -> dict:
    """
    Load cleaning configuration from YAML file.

    Args:
        config_path: Location of the YAML file

    Returns:
        The parsed mapping, or an empty dict for an empty file. Inside every
        entry that has ``remove_reason``, keys such as ``"2,3"`` are converted to
        integer tuples such as ``(2, 3)``.
    """
    with open(config_path, encoding="utf-8") as f:
        # safe_load yields None for an empty document
        config = yaml.safe_load(f) or {}

    # Convert string keys back to tuples for remove_reason
    for operations in config.values():
        if "remove_reason" in operations:
            operations["remove_reason"] = {
                # str() also covers a single bare index, which YAML parses as an int
                tuple(int(part) for part in str(key).split(",")): reason
                for key, reason in operations["remove_reason"].items()
            }

    return config
def check_should_print_info(dataset: Dataset, cleaning_config: dict) -> tuple[list[int], list[int]]:
    """
    Split the configured indices into all of them and those still awaiting cleaning.

    An index counts as awaiting cleaning ("dirty") when its current message count
    still equals the configured ``initial_msg_count``.

    Returns:
        (all_indices, dirty_indices)
    """
    all_indices = list(cleaning_config)
    dirty_indices = [
        idx
        for idx in all_indices
        if cleaning_config[idx]["initial_msg_count"] == len(json.loads(dataset[idx]["messages"]))
    ]

    print_to_both(f"Processing {len(dirty_indices)} indices for cleaning...")
    print_to_both(f"Will print info only for indices: {dirty_indices}")
    print_to_both(f"All indices: {all_indices}")
    return all_indices, dirty_indices
def process_dataset_cleaning(dataset, config_path: Path):
    """
    Run the configured cleaning pipeline over ``dataset``.

    Loads the config, prints every configured example, cleans the dataset, prints
    the examples that were still pending, and finally verifies all configured
    indices.

    Returns:
        (cleaned_dataset, cleaning_config)
    """
    cleaning_config = load_cleaning_config(config_path)
    every_index, pending_indices = check_should_print_info(dataset, cleaning_config)

    # Show all configured examples as they are before cleaning
    untouched = analyse_indices("original", indices=every_index, ds=dataset)
    cleaned = clean_dataset(untouched, cleaning_config)

    # After cleaning, only the examples that were still pending are printed
    cleaned = analyse_indices("cleaned", indices=pending_indices, ds=cleaned)

    # Integrity is verified for ALL configured indices, not just the pending ones
    check_integrity(cleaned, indices=every_index, cleaning_config=cleaning_config)

    return cleaned, cleaning_config
def save_dataset_to_disk(dataset: Dataset, path: str):
    """
    Save the dataset to ``path``, replacing whatever is stored there.

    The data is first written to ``f"{path}_temp"`` and then moved into place, so
    the old directory is only deleted once the new copy has been fully written.

    Args:
        dataset: Object exposing ``save_to_disk(directory)``
        path: Target directory; it does not have to exist beforehand
    """
    temp_path = f"{path}_temp"

    # Clear leftovers from an interrupted earlier run before writing the new copy
    if Path(temp_path).exists():
        shutil.rmtree(temp_path)
    dataset.save_to_disk(temp_path)

    # Remove old dataset (if any) and move the fresh copy into place
    if Path(path).exists():
        shutil.rmtree(path)
    shutil.move(temp_path, path)
def inspect(indices: list[int], ds: Dataset):
    """Print the conversations stored at ``indices`` of ``ds`` under the "original" label."""
    analyse_indices(analysis_type="original", indices=indices, ds=ds)
if __name__ == "__main__":
    # TODO: check for duplicate function calls and fix them
    # Manual switches selecting what this run does.
    inspect_indices = False
    analyse = False
    local_dataset = False
    commit = False

    dataset_path = local_dataset_path if local_dataset else "atomwalk12/linalgzero-distilled"

    if inspect_indices:
        # The local copy is written with save_to_disk, so it has to be read back with
        # load_from_disk; load_dataset is only for the Hub name.
        if local_dataset:
            dataset = load_from_disk(dataset_path)
        else:
            dataset = load_dataset(dataset_path, "default", split="train")
        inspect(indices=[177, 285], ds=dataset)
        sys.exit()

    # Set to None to disable token threshold check, or set to a number (e.g., 1000) to enable
    assistant_token_threshold = 800

    if analyse:
        output_file = "analyse_indices.txt"
        if local_dataset:
            dataset = load_from_disk(dataset_path)
        else:
            dataset = load_dataset(dataset_path, "default", split="train")
        cleaned_dataset, config = process_dataset_cleaning(dataset, Path("linalg_zero/config/cleaning_config.yaml"))
        assert cleaned_dataset is not None, "Cleaned dataset is None"

        if commit:
            save_dataset_to_disk(cleaned_dataset, path=local_dataset_path)
        else:
            print_to_both("Not committing dataset to disk.")
    else:
        output_file = "analyse_dataset.txt"
        reuse_issues, all_think_duplicates, all_issues = analyze_dataset(
            dataset_path,
            "default",
            "train",
            _load_from_disk=local_dataset,
            verbose=True,
            minimal_dataset=True,
            assistant_token_threshold=assistant_token_threshold,
            push_to_hub=False,
        )

    print_to_both(f"Wrote data to file: {output_file}")

    # Write buffer to file; the report contains non-ASCII characters, so the
    # encoding is fixed instead of depending on the platform default.
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(output_buffer.getvalue())