File size: 17,799 Bytes
af68acb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
"""
Dynamic Load Balancer for SACCP Network
Distributes tasks across different node types based on availability, capacity, and performance
"""

import time
import heapq
from typing import Dict, List, Optional, Any, Tuple
from enum import Enum
from dataclasses import dataclass
from datetime import datetime, timedelta
import threading
import random


class TaskPriority(Enum):
    """Relative scheduling priority of a task.

    Higher numeric value is dispatched first: the balancer orders its
    queue by the negated value, so CRITICAL (4) beats LOW (1).
    """
    LOW = 1
    NORMAL = 2
    HIGH = 3
    CRITICAL = 4


class NodeType(Enum):
    """Categories of nodes in the SACCP network.

    The string value is the identifier matched against task resource
    requirements such as ``preferred_node_type`` and
    ``compatible_node_types``.
    """
    HEAD = "head"
    RAM = "ram"
    DISK = "disk"
    COMPUTE = "compute"
    GPU = "gpu"
    TPU = "tpu"
    NPU = "npu"


@dataclass
class Task:
    """Represents a task to be distributed"""
    task_id: str  # unique identifier; used as the assignment key
    task_type: str  # free-form category label; not interpreted by the balancer
    priority: TaskPriority  # scheduling priority (higher wins)
    resource_requirements: Dict[str, Any]  # CPU, memory, etc.
    estimated_duration: float  # in seconds; feeds the node load estimate
    created_at: float  # epoch timestamp; FIFO tie-break within a priority
    assigned_node: Optional[str] = None  # node_id once assigned, else None
    assigned_at: Optional[float] = None  # epoch timestamp of assignment, else None


@dataclass
class Node:
    """Represents a node in the network"""
    node_id: str  # unique identifier
    node_type: NodeType  # category used for type-compatibility checks
    capabilities: Dict[str, Any]  # CPU, memory, etc.
    current_load: float  # 0.0-1.0 estimated utilization
    tasks_queued: int  # tasks assigned but not yet completed
    tasks_completed: int  # lifetime successful completions
    tasks_failed: int  # lifetime failures
    last_heartbeat: float  # epoch timestamp of the most recent heartbeat
    performance_score: float  # 0.0-1.0 based on historical performance
    is_available: bool = True  # cleared when the heartbeat goes stale
    max_concurrent_tasks: int = 10  # hard cap on simultaneous tasks
    current_tasks: int = 0  # tasks currently running on this node


class LoadBalancer:
    """
    Dynamic load balancer that distributes tasks across node types.

    Thread safety: every public method acquires ``self.lock``.  The lock
    is a ``threading.RLock`` because public methods invoke other locking
    methods internally (e.g. ``submit_task`` -> ``_find_suitable_node`` ->
    ``_assign_task_to_node``); a non-reentrant ``threading.Lock`` would
    deadlock on the first nested acquisition.
    """

    def __init__(self):
        self.nodes: Dict[str, Node] = {}
        # Priority queue entries: (-priority, creation_time, seq, task).
        # ``seq`` is a monotonically increasing tie-breaker so heapq never
        # has to compare two Task objects (Task is not orderable; without
        # it, equal priority + timestamp entries raise TypeError).
        self.task_queue: List[Tuple[int, float, int, Task]] = []
        self.assigned_tasks: Dict[str, str] = {}  # task_id -> node_id
        self.node_stats: Dict[str, Dict[str, Any]] = {}
        self._tasks: Dict[str, Task] = {}  # task_id -> Task, O(1) lookup
        self._seq = 0  # heap insertion counter (tie-breaker)
        # RLock: see class docstring — internal helpers re-enter the lock.
        self.lock = threading.RLock()

        # Configuration
        self.heartbeat_timeout = 90  # seconds without heartbeat -> node unavailable
        self.task_timeout = 300  # seconds a queued task may wait before expiring
        self.load_balancing_algorithm = "weighted_least_connections"

    def register_node(self, node_id: str, node_type: NodeType, capabilities: Dict[str, Any]) -> bool:
        """Register (or re-register) a node with the load balancer.

        Args:
            node_id: Unique node identifier.
            node_type: Category of the node.
            capabilities: Resource description; ``max_concurrent_tasks``
                is honored if present (default 10).

        Returns:
            True (registration always succeeds).
        """
        with self.lock:
            self.nodes[node_id] = Node(
                node_id=node_id,
                node_type=node_type,
                capabilities=capabilities,
                current_load=0.0,
                tasks_queued=0,
                tasks_completed=0,
                tasks_failed=0,
                last_heartbeat=time.time(),
                performance_score=0.8,  # neutral default until real stats accrue
                max_concurrent_tasks=capabilities.get("max_concurrent_tasks", 10)
            )

            # Fresh rolling statistics consumed by _update_node_performance_score().
            self.node_stats[node_id] = {
                "avg_task_duration": 0,
                "success_rate": 1.0,
                "response_time_avg": 0.1
            }

            return True

    def heartbeat_node(self, node_id: str) -> bool:
        """Record a heartbeat; marks the node available again."""
        with self.lock:
            node = self.nodes.get(node_id)
            if node is None:
                return False
            node.last_heartbeat = time.time()
            node.is_available = True
            return True

    def heartbeat_batch_nodes(self, node_ids: List[str]) -> int:
        """Update heartbeats for multiple nodes; returns how many were known."""
        return sum(1 for node_id in node_ids if self.heartbeat_node(node_id))

    def deregister_node(self, node_id: str) -> bool:
        """Remove a node, requeueing any tasks that were assigned to it."""
        with self.lock:
            if node_id not in self.nodes:
                return False
            # Remove the node BEFORE reassigning so _find_suitable_node
            # cannot hand a task straight back to the departing node.
            del self.nodes[node_id]
            self.node_stats.pop(node_id, None)
            self._reassign_node_tasks(node_id)
            return True

    def submit_task(self, task: Task) -> Optional[str]:
        """Submit a task for distribution.

        Returns:
            The node_id if the task could be assigned immediately,
            otherwise None (the task is queued for later assignment).
        """
        with self.lock:
            self._tasks[task.task_id] = task

            # Fast path: assign right away if capacity exists.
            node_id = self._find_suitable_node(task)
            if node_id and self._assign_task_to_node(task.task_id, node_id):
                return node_id

            # No capacity: queue it. Higher priority first, then oldest
            # first; ``seq`` breaks exact ties (see __init__).
            self._seq += 1
            heapq.heappush(
                self.task_queue,
                (-task.priority.value, task.created_at, self._seq, task),
            )
            return None  # queued but not yet assigned

    def get_task_assignment(self, task_id: str) -> Optional[str]:
        """Return the node_id a task is assigned to, or None."""
        with self.lock:
            return self.assigned_tasks.get(task_id)

    def complete_task(self, task_id: str, node_id: str, success: bool = True, duration: float = 0) -> bool:
        """Mark a task as completed (or failed) on ``node_id``.

        Updates node counters and rolling statistics, releases the load
        the task was holding, refreshes the node's performance score, and
        then tries to assign queued tasks to the freed capacity.
        """
        with self.lock:
            task = self._tasks.pop(task_id, None)
            node = self.nodes.get(node_id)
            if node is not None:
                if success:
                    node.tasks_completed += 1
                else:
                    node.tasks_failed += 1
                # Clamp at zero so duplicate completions can't go negative.
                node.current_tasks = max(0, node.current_tasks - 1)
                node.tasks_queued = max(0, node.tasks_queued - 1)

                # Release the load estimate added at assignment time so a
                # node's load does not grow without bound.
                if task is not None:
                    estimated_load = min(0.2, task.estimated_duration / 3600.0)
                    node.current_load = max(0.0, node.current_load - estimated_load)

                stats = self.node_stats.get(node_id)
                if stats is not None:
                    if success and duration > 0:
                        # Exponential moving average of task duration.
                        if stats["avg_task_duration"] == 0:
                            stats["avg_task_duration"] = duration
                        else:
                            stats["avg_task_duration"] = (
                                stats["avg_task_duration"] * 0.7 + duration * 0.3
                            )
                    # Refresh success rate on BOTH success and failure
                    # (previously failures never lowered it).
                    total_tasks = node.tasks_completed + node.tasks_failed
                    if total_tasks > 0:
                        stats["success_rate"] = node.tasks_completed / total_tasks

                self._update_node_performance_score(node_id)

            # Drop the assignment and try to fill the freed capacity.
            if self.assigned_tasks.pop(task_id, None) is not None:
                self._attempt_task_assignments()

            return True

    def _find_suitable_node(self, task: Task) -> Optional[str]:
        """Pick the best available node for ``task``, or None."""
        with self.lock:
            candidates = [
                node for node in self.nodes.values()
                if self._is_node_suitable(node, task)
            ]
            if not candidates:
                return None

            algorithm = self.load_balancing_algorithm
            if algorithm == "weighted_response_time":
                # Best historical performance first, lower load breaks ties.
                def key(n: Node):
                    return (-n.performance_score,
                            n.current_tasks / n.max_concurrent_tasks)
            elif algorithm == "node_type_priority":
                # Preferred node type first, then lower load, then performance.
                preferred = task.resource_requirements.get("preferred_node_type")

                def key(n: Node):
                    return (0 if n.node_type.value == preferred else 1,
                            n.current_tasks / n.max_concurrent_tasks,
                            -n.performance_score)
            else:
                # "weighted_least_connections" and the default fall-through:
                # lowest load factor first, higher performance breaks ties.
                def key(n: Node):
                    return (n.current_tasks / n.max_concurrent_tasks,
                            -n.performance_score)

            # min() is O(n) and equivalent to sorting and taking element 0.
            return min(candidates, key=key).node_id

    def _is_node_suitable(self, node: Node, task: Task) -> bool:
        """Availability, heartbeat freshness, type, resource and capacity checks."""
        if not node.is_available:
            return False

        # Stale heartbeat: mark the node down as a side effect.
        if time.time() - node.last_heartbeat > self.heartbeat_timeout:
            node.is_available = False
            return False

        reqs = task.resource_requirements

        # Node-type compatibility (empty list means any type is acceptable).
        required_types = reqs.get("compatible_node_types", [])
        if required_types and node.node_type.value not in required_types:
            return False

        caps = node.capabilities

        # Memory requirement.
        if reqs.get("memory_required", 0) > caps.get("memory_gb", 0):
            return False

        # GPU requirement.
        if reqs.get("needs_gpu", False) and not caps.get("gpu_available", False):
            return False

        # Concurrency cap.
        if node.current_tasks >= node.max_concurrent_tasks:
            return False

        # Load cap: refuse nodes over 90% loaded.
        if node.current_load > 0.9:
            return False

        return True

    def _assign_task_to_node(self, task_id: str, node_id: str) -> bool:
        """Record the assignment of ``task_id`` to ``node_id``."""
        with self.lock:
            node = self.nodes.get(node_id)
            task = self._get_task_by_id(task_id)
            if node is None or task is None:
                return False

            node.current_tasks += 1
            node.tasks_queued += 1

            self.assigned_tasks[task_id] = node_id
            task.assigned_node = node_id
            task.assigned_at = time.time()

            # Estimated load contribution, capped at 20% for long tasks;
            # released again in complete_task().
            estimated_load = min(0.2, task.estimated_duration / 3600.0)
            node.current_load = min(1.0, node.current_load + estimated_load)

            return True

    def _get_task_by_id(self, task_id: str) -> Optional[Task]:
        """O(1) lookup in the task registry (replaces the linear queue scan)."""
        return self._tasks.get(task_id)

    def _reassign_node_tasks(self, node_id: str):
        """Requeue every task currently assigned to ``node_id``."""
        orphaned = [tid for tid, nid in self.assigned_tasks.items() if nid == node_id]
        for task_id in orphaned:
            # Clear the assignment first so submit_task treats the task as free.
            del self.assigned_tasks[task_id]
            task = self._get_task_by_id(task_id)
            if task:
                task.assigned_node = None
                task.assigned_at = None
                self.submit_task(task)

    def _attempt_task_assignments(self):
        """Drain the queue, assigning what we can and requeueing the rest."""
        with self.lock:
            pending = []
            now = time.time()

            while self.task_queue:
                entry = heapq.heappop(self.task_queue)
                task = entry[3]

                # Stale heap entry for a task that is already running:
                # drop it rather than double-assign.
                if task.task_id in self.assigned_tasks:
                    continue

                # Expired: drop from queue and registry.
                if now - task.created_at > self.task_timeout:
                    self._tasks.pop(task.task_id, None)
                    continue

                node_id = self._find_suitable_node(task)
                if node_id and self._assign_task_to_node(task.task_id, node_id):
                    continue  # assigned; stays in the registry until completed

                pending.append(entry)  # no capacity yet; retry later

            # Put unassigned tasks back in the queue.
            for entry in pending:
                heapq.heappush(self.task_queue, entry)

    def _update_node_performance_score(self, node_id: str):
        """Blend success rate (60%), response time (25%) and load (15%)."""
        node = self.nodes.get(node_id)
        stats = self.node_stats.get(node_id)
        if node is None or stats is None:
            return

        success_weight = 0.6
        response_weight = 0.25
        load_weight = 0.15

        # Success rate contribution (0.0 to 1.0).
        success_score = stats["success_rate"]

        # Response time contribution: shorter average duration -> closer to 1.0.
        response_score = 1.0 / (1.0 + stats["avg_task_duration"] / 100.0)

        # Load contribution: lightly loaded nodes score higher, so
        # high-performing nodes are not overloaded.
        load_score = 1.0 - min(1.0, node.current_load)

        performance_score = (
            success_score * success_weight +
            response_score * response_weight +
            load_score * load_weight
        )
        node.performance_score = min(1.0, max(0.0, performance_score))

    def get_node_loads(self) -> Dict[str, float]:
        """Snapshot of the current load of every node."""
        with self.lock:
            return {node_id: node.current_load for node_id, node in self.nodes.items()}

    def get_node_status(self) -> List[Dict[str, Any]]:
        """Comprehensive status of all nodes (also refreshes availability)."""
        with self.lock:
            now = time.time()
            status_list = []
            for node in self.nodes.values():
                # Availability is recomputed from heartbeat age on each call.
                is_active = now - node.last_heartbeat < self.heartbeat_timeout
                node.is_available = is_active

                status_list.append({
                    "node_id": node.node_id,
                    "node_type": node.node_type.value,
                    "is_available": is_active,
                    "current_load": node.current_load,
                    "current_tasks": node.current_tasks,
                    "tasks_queued": node.tasks_queued,
                    "tasks_completed": node.tasks_completed,
                    "tasks_failed": node.tasks_failed,
                    "performance_score": node.performance_score,
                    "max_concurrent_tasks": node.max_concurrent_tasks,
                    "capabilities": node.capabilities,
                    "last_heartbeat": node.last_heartbeat
                })

            return status_list

    def get_task_queue_status(self) -> Dict[str, Any]:
        """Queue length, per-priority breakdown and average wait time."""
        with self.lock:
            counts = {p: 0 for p in TaskPriority}
            for entry in self.task_queue:
                counts[entry[3].priority] += 1
            return {
                "total_queued_tasks": len(self.task_queue),
                "priority_distribution": {
                    "critical": counts[TaskPriority.CRITICAL],
                    "high": counts[TaskPriority.HIGH],
                    "normal": counts[TaskPriority.NORMAL],
                    "low": counts[TaskPriority.LOW]
                },
                "average_wait_time": self._calculate_avg_wait_time()
            }

    def _calculate_avg_wait_time(self) -> float:
        """Mean age in seconds of the tasks currently queued (0 if empty)."""
        if not self.task_queue:
            return 0

        current_time = time.time()
        total_wait = sum(current_time - entry[3].created_at for entry in self.task_queue)
        return total_wait / len(self.task_queue)


# Global instance
# Module-level singleton created eagerly at import time with default
# configuration; all importers of this module share this one balancer.
load_balancer = LoadBalancer()