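"""Reconstruct a pipeline-parallelism schedule from per-rank event-time JSON
dumps and launch the interactive Dash visualization.

Example invocation (the script name here is illustrative):

    python visualize_event_times.py --data-dir ./event_logs --pp-size 4 --vpp-size 2
"""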
import json
import os
import argparse
import re
from collections import defaultdict

from src.execution_model import Schedule, ScheduleConfig, Operation
from src.visualizer import visualize_pipeline_parallelism_dash


def is_valid_event_filename(filename, pp_size, vpp_size):
    """
    Check if filename matches the expected format:
    event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_{rank}_pp_rank_{pp_rank}_rank_{final_rank}.json

    Returns True if valid, False otherwise.
    """
    # Build a regex for the expected format
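    # Example of a matching name (values illustrative):
    #   event_times_PP4_VPP2_TPxCPxDP_rank_0_pp_rank_1_rank_9.json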
    pattern = rf"^event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_\d+_pp_rank_\d+_rank_\d+\.json$"
    return bool(re.match(pattern, filename))


def parse_event_filename(filename):
    """
    Parse the event filename and extract rank information.

    Expected format:
    event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_{rank}_pp_rank_{pp_rank}_rank_{final_rank}.json

    Returns: (tpxcpxdp_rank, pp_rank, global_rank) or None if parsing fails
    """
    try:
        # Remove the .json extension
        name_without_ext = filename.replace(".json", "")
        parts = name_without_ext.split("_")
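        # For "event_times_PP4_VPP2_TPxCPxDP_rank_0_pp_rank_1_rank_9.json"
        # (illustrative), parts == ["event", "times", "PP4", "VPP2", "TPxCPxDP",
        # "rank", "0", "pp", "rank", "1", "rank", "9"]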
        # Find the TPxCPxDP part and the rank values
        tpxcpxdp_rank = None
        pp_rank = None
        global_rank = None

        for i, part in enumerate(parts):
            # Look for the TPxCPxDP token followed by "rank"
            if part.startswith("TP") and "CP" in part and part.endswith("DP"):
                if i + 2 < len(parts) and parts[i + 1] == "rank":
                    tpxcpxdp_rank = int(parts[i + 2])
            # Look for the "pp_rank" token pair
            elif part == "pp" and i + 2 < len(parts) and parts[i + 1] == "rank":
                pp_rank = int(parts[i + 2])
            # The final "rank" (second-to-last token) carries the global rank
            elif part == "rank" and i == len(parts) - 2:
                global_rank = int(parts[i + 1])

        if tpxcpxdp_rank is None or pp_rank is None or global_rank is None:
            return None
        return (tpxcpxdp_rank, pp_rank, global_rank)
    except (ValueError, IndexError):
        return None


def load_event_times_from_json(data_dir, pp_size, vpp_size):
    """Load event times from JSON files in the specified directory."""
    all_files = [f for f in os.listdir(data_dir) if f.endswith(".json")]

    # Keep only the files that match the expected format
    event_files = [
        f for f in all_files if is_valid_event_filename(f, pp_size, vpp_size)
    ]

    if len(event_files) == 0:
        print(f"Available files in {data_dir}:")
        for f in all_files[:10]:  # Show the first 10 files for debugging
            print(f"  {f}")
        raise ValueError(
            f"No event files found matching pattern event_times_PP{pp_size}_VPP{vpp_size}_*.json"
        )
    print(f"Found {len(event_files)} matching event files")

    event_times = {}
    for file_name in event_files:
        parsed_result = parse_event_filename(file_name)
        if parsed_result is None:
            print(f"Warning: Could not parse filename {file_name}")
            continue
        tpxcpxdp_rank, pp_rank, global_rank = parsed_result
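        # Keep one representative per pipeline rank: timings are assumed to be
        # (near-)identical across the TPxCPxDP replicas, so only rank 0 is loaded.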
        if tpxcpxdp_rank == 0:
            try:
                with open(os.path.join(data_dir, file_name), "r") as f:
                    event_data = json.load(f)
                event_times[(pp_rank, tpxcpxdp_rank)] = event_data
                print(
                    f"Loaded data from {file_name}: global_rank={global_rank}, pp_rank={pp_rank}, tpxcpxdp_rank={tpxcpxdp_rank}"
                )
            except Exception as e:
                print(f"Error loading {file_name}: {e}")
    return event_times


def create_pp_schedule_from_event_times(event_times, pp_size):
    """Create a Schedule object from event times data."""
    # One device per pipeline rank (1:1 device-to-stage mapping)
    num_devices = pp_size

    # Find the maximum batch ID by parsing event names
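    # PP event keys look like "forward-{batch}-start" / "forward-{batch}-end"
    # (and likewise for "backward-...")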
    max_batch_id = 0
    for events in event_times.values():
        for event_name in events:
            if event_name.startswith(("forward-", "backward-")):
                parts = event_name.split("-")
                if len(parts) >= 2 and parts[1].isdigit():
                    batch_id = int(parts[1])
                    max_batch_id = max(max_batch_id, batch_id)
    num_batches = max_batch_id + 1

    # Create a simple config (actual times will come from the event data)
    config = ScheduleConfig(
        num_devices=num_devices,
        num_stages=num_devices,  # Assuming 1:1 mapping of devices to stages
        num_batches=num_batches,
        p2p_latency=0,  # Will be implicit in the event timing
        op_times={},  # Not needed as we'll use real timing data
        placement_strategy="standard",
    )

    # Create a schedule
    schedule = Schedule(config)

    # Populate the schedule with operations based on event times
    for (pp_rank, tpxcpxdp_rank), events in event_times.items():
        # Process forward passes
        for batch_id in range(num_batches):
            forward_start_key = f"forward-{batch_id}-start"
            forward_end_key = f"forward-{batch_id}-end"
            if forward_start_key in events and forward_end_key in events:
                # Create an operation and set its timing directly
                forward_op = Operation(batch_id, pp_rank, "forward")
                forward_op.execution_time = (
                    events[forward_end_key] - events[forward_start_key]
                )
                forward_op.start_time = events[forward_start_key]
                forward_op.end_time = events[forward_end_key]

                # Add to the schedule
                schedule.ops[(batch_id, pp_rank, "forward")] = forward_op
                schedule.device_queues[pp_rank].add_operation(forward_op)

        # Process backward passes
        for batch_id in range(num_batches):
            backward_start_key = f"backward-{batch_id}-start"
            backward_end_key = f"backward-{batch_id}-end"
            if backward_start_key in events and backward_end_key in events:
                # Create an operation and set its timing directly
                backward_op = Operation(batch_id, pp_rank, "backward")
                backward_op.execution_time = (
                    events[backward_end_key] - events[backward_start_key]
                )
                backward_op.start_time = events[backward_start_key]
                backward_op.end_time = events[backward_end_key]

                # Add to the schedule
                schedule.ops[(batch_id, pp_rank, "backward")] = backward_op
                schedule.device_queues[pp_rank].add_operation(backward_op)
    return schedule


def create_vpp_schedule_from_event_times(event_times, pp_size, vpp_size):
    """Create a virtual-pipeline (VPP) Schedule object from event times data."""
    # Find the maximum batch ID by parsing event names. VPP event keys carry a
    # model-chunk index: "{forward|backward}-{chunk}-{batch}-{start|end}"
    max_batch_id = 0
    for events in event_times.values():
        for event_name in events:
            if event_name.startswith(("forward-", "backward-")):
                parts = event_name.split("-")
                assert len(parts) == 4
                assert parts[0] in ["forward", "backward"]
                assert parts[1].isdigit() and parts[2].isdigit()
                assert parts[3] in ["start", "end"]
                batch_id = int(parts[2])  # e.g. "backward-0-19-end" -> batch 19
                max_batch_id = max(max_batch_id, batch_id)
    num_batches = max_batch_id + 1

    # Create a simple config (actual times will come from the event data)
    config = ScheduleConfig(
        num_devices=pp_size,
        num_stages=pp_size * vpp_size,
        num_batches=num_batches,
        p2p_latency=0,  # Will be implicit in the event timing
        op_times={},  # Not needed as we'll use real timing data
        placement_strategy="interleave",
    )

    # Create a schedule
    schedule = Schedule(config)

    # Populate the schedule with operations based on event times
    for (pp_rank, tpxcpxdp_rank), events in event_times.items():
        # Process forward passes
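        # With the "interleave" placement, device pp_rank hosts the stages
        # pp_rank, pp_rank + pp_size, ..., i.e. stage_id = pp_size * chunk + pp_rank.
        # Example: pp_size=2, vpp_size=2 -> device 0 runs stages 0 and 2.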
        for model_chunk_id in range(vpp_size):
            for batch_id in range(num_batches):
                forward_start_key = f"forward-{model_chunk_id}-{batch_id}-start"
                forward_end_key = f"forward-{model_chunk_id}-{batch_id}-end"
                stage_id = pp_size * model_chunk_id + pp_rank
                if forward_start_key in events and forward_end_key in events:
                    # Create an operation and set its timing directly
                    forward_op = Operation(
                        batch_id, stage_id=stage_id, op_type="forward"
                    )
                    forward_op.execution_time = (
                        events[forward_end_key] - events[forward_start_key]
                    )
                    forward_op.start_time = events[forward_start_key]
                    forward_op.end_time = events[forward_end_key]

                    # Add to the schedule
                    schedule.ops[(batch_id, stage_id, "forward")] = forward_op
                    schedule.device_queues[pp_rank].add_operation(forward_op)

        # Process backward passes
        for model_chunk_id in range(vpp_size):
            for batch_id in range(num_batches):
                backward_start_key = f"backward-{model_chunk_id}-{batch_id}-start"
                backward_end_key = f"backward-{model_chunk_id}-{batch_id}-end"
                stage_id = pp_size * model_chunk_id + pp_rank
                if backward_start_key in events and backward_end_key in events:
                    # Create an operation and set its timing directly
                    backward_op = Operation(
                        batch_id, stage_id=stage_id, op_type="backward"
                    )
                    backward_op.execution_time = (
                        events[backward_end_key] - events[backward_start_key]
                    )
                    backward_op.start_time = events[backward_start_key]
                    backward_op.end_time = events[backward_end_key]

                    # Add to the schedule
                    schedule.ops[(batch_id, stage_id, "backward")] = backward_op
                    schedule.device_queues[pp_rank].add_operation(backward_op)
    return schedule


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(
        description="Visualize pipeline parallelism from event data"
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="Directory containing event_times_*.json files",
    )
    parser.add_argument(
        "--pp-size", type=int, required=True, help="Pipeline parallelism size"
    )
    parser.add_argument(
        "--vpp-size", type=int, required=True, help="Virtual pipeline parallelism size"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8050,
        help="Port for the visualization dashboard (default: 8050)",
    )
    args = parser.parse_args()

    # Load event times from JSON files
    event_times = load_event_times_from_json(
        args.data_dir, args.pp_size, args.vpp_size
    )

    # Create a schedule from the event times
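    # vpp_size == 1 -> flat "forward-{batch}-..." keys; otherwise the keys carry
    # a model-chunk index, "forward-{chunk}-{batch}-..."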
    if args.vpp_size == 1:
        schedule = create_pp_schedule_from_event_times(event_times, args.pp_size)
    else:
        schedule = create_vpp_schedule_from_event_times(
            event_times, args.pp_size, args.vpp_size
        )

    # Calculate and print execution metrics
    total_execution_time = max(
        op.end_time for op in schedule.ops.values() if op.end_time is not None
    )
    print(f"Total execution time: {total_execution_time:.2f} ms")

    # Calculate bubble time percentage
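    # bubble % = idle fraction of the makespan:
    # (total_execution_time - device active time) / total_execution_time * 100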
    device_times = defaultdict(float)
    for device_id, device_queue in enumerate(schedule.device_queues):
        for op in device_queue.ops:
            if op.start_time is not None and op.end_time is not None:
                device_times[device_id] += op.end_time - op.start_time

    # Print the bubble percentage for each device
    for device_id, active_time in device_times.items():
        bubble_percentage = (
            (total_execution_time - active_time) / total_execution_time * 100
        )
        print(f"Device {device_id} bubble: {bubble_percentage:.2f}%")

    # Visualize the schedule
    print("Launching visualization...")
    visualize_pipeline_parallelism_dash(
        schedule, schedule_type="1F1B-Imported", port=args.port
    )


if __name__ == "__main__":
    main()