File size: 10,754 Bytes
5c93746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
"""
Model data transfer abstraction layer.

This module provides high-level interfaces for transferring model data
between distributed ranks during inference.
"""

import logging
import time
from typing import List, Tuple, Optional, Any

import torch

from .distributed_communicator import DistributedCommunicator
from .buffer_manager import BufferManager
from .kv_cache_manager import KVCacheManager
from .data_containers import LatentData, CommunicationConfig, PerformanceMetrics
from .utils import CommunicationTimer


class ModelDataTransfer:
    """
    High-level interface for model data transfer operations.

    This class encapsulates all model-related data transfer operations,
    providing a clean interface for sending and receiving latent data,
    KV caches, and other model state between ranks.

    Latent send/receive calls are counted in ``transfer_count`` and the
    wall-clock time spent inside those calls is accumulated in
    ``total_transfer_time``. For the async send this presumably covers only
    issuing the operations, not their completion (TODO confirm against
    ``DistributedCommunicator``).
    """

    def __init__(self, communicator: DistributedCommunicator,
                 buffer_manager: BufferManager,
                 kv_cache_manager: Optional[KVCacheManager] = None,
                 config: Optional[CommunicationConfig] = None):
        """
        Initialize the model data transfer manager.

        Args:
            communicator: Distributed communicator instance (must expose ``rank``)
            buffer_manager: Buffer manager for tensor allocation
            kv_cache_manager: KV cache manager (optional; required only by the
                KV-cache methods)
            config: Communication configuration; a default
                ``CommunicationConfig()`` is created when omitted
        """
        self.comm = communicator
        self.buffer_mgr = buffer_manager
        self.kv_cache_mgr = kv_cache_manager
        self.config = config or CommunicationConfig()

        # Setup logging: one named logger per rank. propagate=False keeps
        # records from also being emitted by ancestor handlers.
        self.logger = logging.getLogger(f"ModelDataTransfer_rank_{communicator.rank}")
        self.logger.propagate = False
        # getLogger returns the same object for the same name, so only attach
        # a handler once to avoid duplicated output across instances.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                f'[Rank {communicator.rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # Performance tracking (latent send/receive operations only)
        self.transfer_count = 0
        self.total_transfer_time = 0.0
        # Set by cleanup() so that __del__ does not repeat an explicit cleanup.
        self._cleaned_up = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _record_transfer(self, elapsed: float) -> None:
        """Count one latent transfer and add its elapsed seconds to the total."""
        self.transfer_count += 1
        self.total_transfer_time += elapsed

    def _average_transfer_time(self) -> float:
        """Return mean seconds per recorded transfer (0.0 when none recorded)."""
        return self.total_transfer_time / max(1, self.transfer_count)

    def _require_kv_cache_mgr(self) -> KVCacheManager:
        """Return the KV cache manager or raise if it was not provided."""
        if self.kv_cache_mgr is None:
            raise RuntimeError("KV cache manager not initialized")
        return self.kv_cache_mgr

    # ------------------------------------------------------------------
    # Latent data
    # ------------------------------------------------------------------

    def send_latent_data_async(self, chunk_idx: int, latents: torch.Tensor,
                             original_latents: torch.Tensor, patched_x_shape: torch.Tensor,
                             current_start: torch.Tensor, current_end: torch.Tensor,
                             current_step: int) -> List[Any]:
        """
        Asynchronously send latent data to the next rank.

        Args:
            chunk_idx: Chunk index
            latents: Latent tensor
            original_latents: Original latent tensor
            patched_x_shape: Patched x shape tensor
            current_start: Current start indices
            current_end: Current end indices
            current_step: Current step

        Returns:
            List of work objects for all send operations, as returned by the
            communicator.
        """
        start = time.perf_counter()
        with CommunicationTimer(f"send_latent_data_async chunk_{chunk_idx}", self.logger):
            work_objects = self.comm.send_latent_data_async(
                chunk_idx=chunk_idx,
                latents=latents,
                original_latents=original_latents,
                patched_x_shape=patched_x_shape,
                current_start=current_start,
                current_end=current_end,
                current_step=current_step
            )
        # Only reached if the send call did not raise, matching the original
        # "count on success" semantics.
        self._record_transfer(time.perf_counter() - start)

        self.logger.debug("Sent latent data for chunk %s", chunk_idx)
        return work_objects

    def receive_latent_data_async(self, num_steps: int) -> LatentData:
        """
        Asynchronously receive latent data from the previous rank.

        Args:
            num_steps: Number of denoising steps (forwarded to the communicator)

        Returns:
            LatentData object containing all received data. Its tensors come
            from the buffer manager and can be handed back with
            ``release_latent_data``.
        """
        start = time.perf_counter()
        with CommunicationTimer("receive_latent_data_async", self.logger):
            chunk_idx, latents, original_latents, current_start, current_end, current_step, patched_x_shape = \
                self.comm.recv_latent_data_async(num_steps, self.buffer_mgr)
        self._record_transfer(time.perf_counter() - start)

        self.logger.debug("Received latent data for chunk %s", chunk_idx)

        return LatentData(
            chunk_idx=chunk_idx,
            latents=latents,
            original_latents=original_latents,
            current_start=current_start,
            current_end=current_end,
            current_step=current_step,
            patched_x_shape=patched_x_shape
        )

    def release_latent_data(self, latent_data: Optional[LatentData]) -> None:
        """
        Return received latent-data buffers to the buffer pool.

        No-op when ``latent_data`` is None or there is no buffer manager.
        ``current_step`` is not returned because it is not a pooled buffer.
        """
        if latent_data is None or self.buffer_mgr is None:
            return

        self.buffer_mgr.return_buffer(latent_data.latents, "latent")
        self.buffer_mgr.return_buffer(latent_data.original_latents, "origin")
        self.buffer_mgr.return_buffer(latent_data.patched_x_shape, "misc")
        self.buffer_mgr.return_buffer(latent_data.current_start, "misc")
        self.buffer_mgr.return_buffer(latent_data.current_end, "misc")

    # ------------------------------------------------------------------
    # Prompts
    # ------------------------------------------------------------------

    def send_prompt_async(self, prompt: str, device: torch.device) -> List[Any]:
        """Send ``prompt`` via the communicator and return its work objects."""
        return self.comm.send_prompt_async(prompt, device)

    def recv_prompt_async(self) -> str:
        """Receive a prompt string via the communicator."""
        return self.comm.recv_prompt_async()

    # ------------------------------------------------------------------
    # KV cache
    # ------------------------------------------------------------------

    def send_kv_cache_blocks(self, block_indices: List[int], donor_rank: int) -> None:
        """
        Send KV cache blocks to all ranks.

        Args:
            block_indices: List of block indices to send
            donor_rank: Rank that owns the KV cache data

        Raises:
            RuntimeError: If no KV cache manager was provided.
        """
        kv_mgr = self._require_kv_cache_mgr()

        with CommunicationTimer(f"send_kv_cache_blocks {len(block_indices)} blocks", self.logger):
            kv_mgr.broadcast_kv_blocks(block_indices, donor_rank)

        self.logger.debug("Sent KV cache blocks %s from rank %s", block_indices, donor_rank)

    def rebalance_kv_cache(self, old_intervals: torch.Tensor,
                          new_intervals: torch.Tensor, total_blocks: int) -> None:
        """
        Rebalance KV cache ownership based on new block intervals.

        Args:
            old_intervals: Previous block intervals [world_size, 2]
            new_intervals: New block intervals [world_size, 2]
            total_blocks: Total number of blocks

        Raises:
            RuntimeError: If no KV cache manager was provided.
        """
        kv_mgr = self._require_kv_cache_mgr()

        with CommunicationTimer("rebalance_kv_cache", self.logger):
            kv_mgr.rebalance_kv_cache_by_diff(old_intervals, new_intervals, total_blocks)

        self.logger.info("Rebalanced KV cache ownership")

    # ------------------------------------------------------------------
    # Collectives
    # ------------------------------------------------------------------

    def broadcast_tensor(self, tensor: torch.Tensor, src: int) -> None:
        """
        Broadcast a tensor from source to all ranks.

        Args:
            tensor: Tensor to broadcast
            src: Source rank
        """
        with CommunicationTimer(f"broadcast_tensor from rank {src}", self.logger):
            self.comm.broadcast_tensor(tensor, src)

        self.logger.debug("Broadcasted tensor from rank %s, shape: %s", src, tensor.shape)

    def all_gather_tensors(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        """
        Gather tensors from all ranks.

        Args:
            tensor: Local tensor to gather

        Returns:
            List of tensors from all ranks
        """
        with CommunicationTimer("all_gather_tensors", self.logger):
            gather_list = self.comm.all_gather_tensors(tensor)

        self.logger.debug("Gathered tensors from all ranks, local shape: %s", tensor.shape)
        return gather_list

    def wait_for_outstanding(self, max_outstanding: Optional[int] = None) -> None:
        """
        Wait for outstanding operations to complete.

        Args:
            max_outstanding: Maximum number of outstanding operations to keep
        """
        with CommunicationTimer("wait_for_outstanding", self.logger):
            self.comm.wait_for_outstanding(max_outstanding)

    def barrier(self) -> None:
        """Synchronize all ranks."""
        with CommunicationTimer("barrier", self.logger):
            self.comm.barrier()

    # ------------------------------------------------------------------
    # Metrics / statistics
    # ------------------------------------------------------------------

    def get_performance_metrics(self) -> PerformanceMetrics:
        """
        Get performance metrics for data transfer operations.

        Returns:
            PerformanceMetrics whose ``communication_time`` is the mean seconds
            per recorded latent transfer; the other fields are left at 0.0 for
            the caller / other components to fill in.
        """
        return PerformanceMetrics(
            dit_time=0.0,  # Would be filled by caller
            total_time=0.0,  # Would be filled by caller
            communication_time=self._average_transfer_time(),
            buffer_allocation_time=0.0  # Would be tracked by buffer manager
        )

    def get_statistics(self) -> dict:
        """
        Get transfer statistics.

        Returns:
            Dictionary with transfer count, total/average transfer time, and
            the communicator's and buffer manager's own statistics (the latter
            is None when there is no buffer manager).
        """
        return {
            "transfer_count": self.transfer_count,
            "total_transfer_time": self.total_transfer_time,
            "avg_transfer_time": self._average_transfer_time(),
            "communicator_stats": self.comm.get_statistics(),
            "buffer_manager_stats": self.buffer_mgr.get_statistics() if self.buffer_mgr else None
        }

    def print_statistics(self) -> None:
        """Log transfer statistics at INFO level, expanding nested stat dicts."""
        stats = self.get_statistics()
        self.logger.info("Model Data Transfer Statistics:")
        for key, value in stats.items():
            if key in ("communicator_stats", "buffer_manager_stats"):
                # Nested stats: skip when empty/None, otherwise one line per entry.
                if value:
                    self.logger.info(f"  {key}:")
                    for sub_key, sub_value in value.items():
                        self.logger.info(f"    {sub_key}: {sub_value}")
            else:
                self.logger.info(f"  {key}: {value}")

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def cleanup(self) -> None:
        """Clear the buffer manager's buffers (if any) and mark this object cleaned up."""
        if self.buffer_mgr:
            self.buffer_mgr.clear_buffers()
        self._cleaned_up = True
        self.logger.info("Model data transfer cleanup completed")

    def __del__(self):
        """Best-effort cleanup on destruction, skipped if cleanup() already ran."""
        # getattr: __init__ may have failed before the flag was set.
        if getattr(self, "_cleaned_up", False):
            return
        try:
            self.cleanup()
        except Exception:
            # Destructors must never raise; during interpreter shutdown the
            # logger or buffer manager may already be torn down.
            pass