File size: 6,247 Bytes
f2bab5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""

Scaling manager for horizontal scaling of cloud agents.

"""
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple

import ray

from .agent import Agent
from .config import settings
from .couchdb_client import CouchDBClient

# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)

class ScalingManager:
    """Manager for horizontal scaling of cloud agents.

    Agent records come from ``CouchDBClient.get_active_agents()``. Agents
    whose heartbeat is older than ``heartbeat_timeout`` are removed. Ray
    agent actors are then added or removed so that the busy fraction stays
    between ``scale_down_threshold`` and ``scale_up_threshold``, while the
    agent count stays within ``[max(min_agents, 1), max_agents]``.
    """

    def __init__(
        self,
        *,
        min_agents: int = 2,
        max_agents: int = 10,
        scale_up_threshold: float = 0.8,
        scale_down_threshold: float = 0.3,
        heartbeat_timeout: timedelta = timedelta(minutes=5),
        check_interval: float = 60,
        retry_interval: float = 5,
    ):
        """Validate settings, create the CouchDB client and ensure Ray is up.

        All arguments are keyword-only and default to the values that used
        to be hard-coded here, so ``ScalingManager()`` behaves as before.

        Args:
            min_agents: Lower bound on running agents. At least one agent is
                always kept, even when this is 0.
            max_agents: Upper bound on agents started by scale-up.
            scale_up_threshold: Busy fraction at or above which agents are added.
            scale_down_threshold: Busy fraction at or below which an idle agent
                is removed.
            heartbeat_timeout: Heartbeat age after which an agent is treated
                as dead.
            check_interval: Seconds to sleep after a successful monitoring
                iteration.
            retry_interval: Seconds to sleep after a failed iteration.

        Raises:
            ValueError: If the agent bounds or thresholds are inconsistent.
                This is checked before any connection is attempted.
        """
        if min_agents < 0 or max_agents < max(min_agents, 1):
            raise ValueError(
                f"invalid agent bounds: min_agents={min_agents}, max_agents={max_agents}"
            )
        if scale_down_threshold >= scale_up_threshold:
            raise ValueError(
                "scale_down_threshold must be lower than scale_up_threshold"
            )

        self.db_client = CouchDBClient()
        self._initialize_ray()
        self.min_agents = min_agents
        self.max_agents = max_agents
        self.scale_up_threshold = scale_up_threshold  # Scale up when this fraction of agents is busy
        self.scale_down_threshold = scale_down_threshold  # Scale down at or below this busy fraction
        self.heartbeat_timeout = heartbeat_timeout
        self.check_interval = check_interval
        self.retry_interval = retry_interval
        # Handles of the actors started by this manager, keyed by agent id.
        self.agent_refs: Dict[str, ray.actor.ActorHandle] = {}

    def _initialize_ray(self):
        """Connect to the Ray head node via Ray Client unless Ray is already initialized."""
        if not ray.is_initialized():
            ray.init(address=f"ray://{settings.COORDINATOR_HOST}:{settings.RAY_HEAD_PORT}")

    @property
    def _agent_floor(self) -> int:
        """Fewest agents to keep running: ``min_agents``, but never below one."""
        return max(self.min_agents, 1)

    @staticmethod
    def _workload(agents: List[Dict[str, Any]]) -> Tuple[int, int]:
        """Return ``(total, busy)`` where busy agents have a non-None ``current_job``."""
        total = len(agents)
        busy = sum(1 for a in agents if a['current_job'] is not None)
        return total, busy

    @staticmethod
    def _parse_heartbeat(raw: str) -> datetime:
        """Parse an ISO-8601 heartbeat string into a timezone-aware UTC datetime.

        Naive timestamps are treated as UTC, which matches the previous
        comparison against ``datetime.utcnow()``. Offset-aware timestamps are
        converted to UTC, so either form can be compared safely. A trailing
        ``Z`` is accepted even on Python versions before 3.11, whose
        ``fromisoformat`` rejects it.
        """
        if raw.endswith("Z"):
            raw = raw[:-1] + "+00:00"
        parsed = datetime.fromisoformat(raw)
        if parsed.tzinfo is None:
            return parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)

    async def monitor_and_scale(self):
        """Run the health-check and scaling loop forever.

        Each iteration removes stale agents, then rebalances the cluster.
        After success it sleeps ``check_interval`` seconds. Any error is
        logged with its traceback, and the loop retries after
        ``retry_interval`` seconds instead of terminating.
        """
        while True:
            try:
                await self._check_agent_health()
                await self._scale_cluster()
            except Exception as e:
                logger.exception("Error in monitor and scale loop: %s", e)
                await asyncio.sleep(self.retry_interval)
            else:
                await asyncio.sleep(self.check_interval)

    async def _check_agent_health(self):
        """Remove agents whose last heartbeat is older than ``heartbeat_timeout``.

        A malformed record or a failed removal is logged and skipped, so one
        bad agent cannot block the check for all the others.

        NOTE(review): ``_remove_agent`` does not update the database record,
        so a dead agent may keep being reported by ``get_active_agents()``.
        Confirm how records are retired.

        Raises:
            Exception: Re-raised if the agent list cannot be fetched.
        """
        try:
            active_agents = self.db_client.get_active_agents()
        except Exception as e:
            logger.error(f"Error checking agent health: {e}")
            raise

        now = datetime.now(timezone.utc)
        for agent in active_agents:
            try:
                last_heartbeat = self._parse_heartbeat(agent['last_heartbeat'])
                if now - last_heartbeat > self.heartbeat_timeout:
                    # Agent is considered dead
                    logger.warning(f"Agent {agent['_id']} appears to be dead. Removing...")
                    await self._remove_agent(agent['_id'])
            except Exception as e:
                logger.error(f"Error checking health of agent {agent.get('_id')}: {e}")

    async def _scale_cluster(self):
        """Scale the cluster based on workload.

        Decisions, in order:

        - Fewer than ``max(min_agents, 1)`` agents: add agents up to that
          floor and stop for this round.
        - Busy fraction >= ``scale_up_threshold``: add up to two agents,
          without exceeding ``max_agents``.
        - Busy fraction <= ``scale_down_threshold``: remove one idle agent,
          without going below the floor.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            active_agents = self.db_client.get_active_agents()
            total_agents, busy_agents = self._workload(active_agents)
            floor = self._agent_floor

            if total_agents < floor:
                # Enforce the configured minimum (previously only one agent
                # was ever guaranteed, so min_agents was not a real floor).
                num_to_add = floor - total_agents
                logger.info(f"Scaling up: Adding {num_to_add} agents")
                for _ in range(num_to_add):
                    await self._add_agent()
                return

            # total_agents >= floor >= 1 here, so the division is safe.
            utilization = busy_agents / total_agents

            # Scale up if needed
            if utilization >= self.scale_up_threshold and total_agents < self.max_agents:
                num_to_add = min(2, self.max_agents - total_agents)  # Add up to 2 agents at a time
                logger.info(f"Scaling up: Adding {num_to_add} agents")
                for _ in range(num_to_add):
                    await self._add_agent()

            # Scale down if needed
            elif utilization <= self.scale_down_threshold and total_agents > floor:
                num_to_remove = min(1, total_agents - floor)  # Remove 1 agent at a time
                logger.info(f"Scaling down: Removing {num_to_remove} agents")
                idle_agents = [a for a in active_agents if a['current_job'] is None]
                for _ in range(num_to_remove):
                    if idle_agents:
                        await self._remove_agent(idle_agents.pop()['_id'])

        except Exception as e:
            logger.error(f"Error scaling cluster: {e}")
            raise

    async def _add_agent(self):
        """Start a new agent actor with Ray and track it in ``agent_refs``.

        Returns:
            The id reported by the new actor's ``get_id`` method.

        Raises:
            Exception: Any startup failure is logged and re-raised. The
                half-started actor is killed and removed from
                ``agent_refs`` before re-raising.
        """
        agent_ref = None
        agent_id = None
        try:
            # Create new agent actor using Ray
            agent_ref = ray.remote(Agent).options(
                num_cpus=1,
                # NOTE(review): ray.get_gpu_ids() reports GPUs assigned to the
                # calling process. On a driver this is often empty, so agents
                # may never request a GPU. Confirm this is intended.
                num_gpus=0.5 if ray.get_gpu_ids() else 0,
            ).remote()

            # Store reference
            agent_id = await agent_ref.get_id.remote()
            self.agent_refs[agent_id] = agent_ref

            # Start agent. Awaiting the ObjectRef (rather than ray.get) keeps
            # the asyncio event loop responsive while the call runs.
            # NOTE(review): assumes Agent.run returns once the agent is started.
            await agent_ref.run.remote()

            logger.info(f"Added new agent {agent_id}")
            return agent_id

        except Exception as e:
            logger.error(f"Error adding agent: {e}")
            if agent_ref is not None:
                # Do not leak a half-started actor.
                if agent_id is not None:
                    self.agent_refs.pop(agent_id, None)
                try:
                    ray.kill(agent_ref)
                except Exception:
                    logger.exception("Failed to kill agent actor after failed start")
            raise

    async def _remove_agent(self, agent_id: str):
        """Shut down and kill the actor tracked for ``agent_id``, if any.

        Graceful ``shutdown`` is attempted first. If it fails (for example
        because the actor is already dead), the failure is logged and the
        actor is still force-killed and untracked. Ids not present in
        ``agent_refs`` are only logged.

        Raises:
            Exception: Errors from ``ray.kill`` are logged and re-raised. The
                ref then stays tracked so a later call can retry.
        """
        try:
            agent_ref = self.agent_refs.get(agent_id)
            if agent_ref is not None:
                try:
                    await agent_ref.shutdown.remote()
                except Exception as e:
                    logger.warning(f"Graceful shutdown of agent {agent_id} failed: {e}")
                ray.kill(agent_ref)
                self.agent_refs.pop(agent_id, None)

            logger.info(f"Removed agent {agent_id}")

        except Exception as e:
            logger.error(f"Error removing agent: {e}")
            raise

    def get_cluster_status(self) -> Dict[str, Any]:
        """Summarize the agent records currently reported as active.

        Returns:
            A dict with ``total_agents``, ``busy_agents``, ``idle_agents``,
            ``utilization`` (busy fraction, 0 when there are no agents),
            ``can_scale_up`` and ``can_scale_down``. The last two use the
            same bounds as ``_scale_cluster``.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            active_agents = self.db_client.get_active_agents()
            total_agents, busy_agents = self._workload(active_agents)

            return {
                'total_agents': total_agents,
                'busy_agents': busy_agents,
                'idle_agents': total_agents - busy_agents,
                'utilization': busy_agents / total_agents if total_agents > 0 else 0,
                'can_scale_up': total_agents < self.max_agents,
                'can_scale_down': total_agents > self._agent_floor,
            }

        except Exception as e:
            logger.error(f"Error getting cluster status: {e}")
            raise