File size: 11,018 Bytes
61d29fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
"""
Deployment scripts for Databricks Agent Bricks.

Handles:
- Unity Catalog registration
- Model Serving deployment
- Endpoint management
- Version promotion
"""
import os
from typing import Optional, List
import mlflow
from mlflow.tracking import MlflowClient
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
    ServedEntityInput,
    EndpointCoreConfigInput,
    TrafficConfig,
    Route
)
from databricks.sdk.service.catalog import (
    ModelVersionInfoStatus
)
from loguru import logger

from config import settings


class AgentDeploymentManager:
    """
    Manages deployment of agents to Databricks Model Serving.

    Workflow:
    1. Register agent to Unity Catalog (``register_agent``)
    2. Create/update serving endpoint (``deploy_agent``,
       ``create_multi_agent_endpoint``)
    3. Monitor deployment status (``get_endpoint_status``, ``test_endpoint``)
    4. Promote versions (dev -> staging -> prod) -- listed as a goal, but no
       method of this class implements promotion.
    """

    def __init__(self):
        """Create the MLflow and Databricks workspace clients from ``settings``."""
        self.client = MlflowClient()
        self.w = WorkspaceClient(
            host=settings.databricks_host,
            token=settings.databricks_token
        )
        self.catalog = settings.catalog_name
        self.schema = settings.schema_name

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _full_model_name(self, agent_name: str) -> str:
        """Return the three-level ``catalog.schema.agent_name`` model name."""
        return f"{self.catalog}.{self.schema}.{agent_name}"

    @staticmethod
    def _invocation_url(endpoint_name: str) -> str:
        """Return the invocation URL of ``endpoint_name`` on the configured host."""
        # Strip a trailing slash so the URL never contains "//serving-endpoints".
        host = str(settings.databricks_host).rstrip("/")
        return f"{host}/serving-endpoints/{endpoint_name}/invocations"

    def _latest_version(self, full_model_name: str) -> str:
        """
        Return the highest registered version of ``full_model_name`` as a string.

        Uses ``search_model_versions`` rather than ``get_latest_versions``
        because the latter is stage-based and is rejected for Unity Catalog
        models.

        Raises:
            ValueError: If the model has no registered versions.
        """
        versions = self.client.search_model_versions(f"name='{full_model_name}'")
        if not versions:
            raise ValueError(
                f"No registered versions found for model {full_model_name}"
            )
        # Compare numerically: as strings, "10" would sort before "2".
        return str(max(int(mv.version) for mv in versions))

    @staticmethod
    def _validate_traffic_split(agents: List[tuple]) -> bool:
        """
        Validate the traffic percentages of ``agents``.

        Args:
            agents: List of (agent_name, version, traffic_percentage) tuples.

        Returns:
            True when explicit percentages were given (and are valid), False
            when every percentage is None, meaning no explicit traffic
            configuration should be sent.

        Raises:
            ValueError: If ``agents`` is empty, a percentage is not an int in
                [0, 100], or the percentages do not sum to 100.
        """
        if not agents:
            raise ValueError(
                "agents must contain at least one "
                "(agent_name, version, traffic_percentage) tuple"
            )

        percentages = [pct for _, _, pct in agents]
        if all(pct is None for pct in percentages):
            return False

        for pct in percentages:
            is_int = isinstance(pct, int) and not isinstance(pct, bool)
            if not is_int or not 0 <= pct <= 100:
                raise ValueError(
                    f"traffic_percentage must be an integer in [0, 100], got {pct!r}"
                )

        total = sum(percentages)
        if total != 100:
            raise ValueError(
                f"traffic percentages must sum to 100, got {total}"
            )
        return True

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def register_agent(
        self,
        agent_class,
        agent_name: str,
        description: str,
        tags: Optional[dict] = None
    ) -> str:
        """
        Register an agent to Unity Catalog.

        Instantiates ``agent_class``, logs it as an MLflow pyfunc model inside
        a new run and registers it under ``catalog.schema.agent_name``.

        Args:
            agent_class: Agent class to instantiate (called with no arguments).
                The instance must provide ``_get_example_input()`` and
                ``predict(context, model_input)``.
            agent_name: Name for the registered model (e.g., "policy_classifier")
            description: Model description; skipped when empty
            tags: Optional tags set on the MLflow run

        Returns:
            Model version string
        """
        full_model_name = self._full_model_name(agent_name)

        logger.info(f"Registering agent: {full_model_name}")

        # Instantiate agent
        agent = agent_class()

        # Log to MLflow and register
        with mlflow.start_run(run_name=f"register_{agent_name}"):
            # Log agent metadata
            mlflow.log_param("agent_name", agent_name)
            mlflow.log_param("catalog", self.catalog)
            mlflow.log_param("schema", self.schema)

            if tags:
                mlflow.set_tags(tags)

            # Run the agent once on its own example to infer the signature
            example_input = agent._get_example_input()
            example_output = agent.predict(None, example_input)
            signature = mlflow.models.infer_signature(example_input, example_output)

            # Log model
            model_info = mlflow.pyfunc.log_model(
                artifact_path="agent",
                python_model=agent,
                signature=signature,
                registered_model_name=full_model_name,
                pip_requirements=self._get_requirements(agent_class)
            )

        # Prefer the version reported by the log call itself; older MLflow
        # releases do not expose it, so fall back to a registry lookup.
        version = getattr(model_info, "registered_model_version", None)
        if version is None:
            version = self._latest_version(full_model_name)
        version = str(version)

        # Update model description
        if description:
            self.client.update_registered_model(
                name=full_model_name,
                description=description
            )

        logger.info(f"✅ Registered {full_model_name} version {version}")

        return version

    def deploy_agent(
        self,
        agent_name: str,
        endpoint_name: str,
        version: Optional[str] = None,
        workload_size: str = "Small",
        scale_to_zero: bool = True,
        min_replicas: int = 1,
        max_replicas: int = 10
    ) -> str:
        """
        Deploy agent to Model Serving endpoint.

        Tries to create the endpoint; if creation fails with an
        "already exists" error, the existing endpoint's config is updated
        instead. Both paths block until the SDK's wait call returns.

        Args:
            agent_name: Registered model name (without catalog/schema prefix)
            endpoint_name: Serving endpoint name
            version: Model version (defaults to the highest registered version)
            workload_size: Endpoint size (Small, Medium, Large)
            scale_to_zero: Enable scale-to-zero
            min_replicas: Minimum replicas (forced to 0 when scale_to_zero is set)
            max_replicas: Maximum replicas

        Returns:
            Endpoint URL

        Raises:
            ValueError: If ``version`` is None and the model has no versions.
            Exception: Any endpoint-creation error other than "already exists"
                is re-raised unchanged.
        """
        full_model_name = self._full_model_name(agent_name)

        # Get version if not specified
        if version is None:
            version = self._latest_version(full_model_name)

        logger.info(f"Deploying {full_model_name} v{version} to {endpoint_name}")

        # Configure served entity
        # NOTE(review): confirm that the installed databricks-sdk's
        # ServedEntityInput accepts min_replicas/max_replicas -- TODO verify.
        served_entity = ServedEntityInput(
            entity_name=full_model_name,
            entity_version=version,
            workload_size=workload_size,
            scale_to_zero_enabled=scale_to_zero,
            min_replicas=min_replicas if not scale_to_zero else 0,
            max_replicas=max_replicas
        )

        # Create or update endpoint
        try:
            self.w.serving_endpoints.create_and_wait(
                name=endpoint_name,
                config=EndpointCoreConfigInput(
                    served_entities=[served_entity]
                )
            )
            logger.info(f"✅ Created endpoint: {endpoint_name}")

        except Exception as e:
            # Only an "already exists" failure is recoverable (by updating);
            # everything else propagates to the caller.
            if "already exists" not in str(e).lower():
                raise

            # Update existing endpoint
            self.w.serving_endpoints.update_config_and_wait(
                name=endpoint_name,
                served_entities=[served_entity]
            )
            logger.info(f"✅ Updated endpoint: {endpoint_name}")

        # Return invocation URL
        endpoint_url = self._invocation_url(endpoint_name)
        logger.info(f"   Endpoint URL: {endpoint_url}")

        return endpoint_url

    def create_multi_agent_endpoint(
        self,
        endpoint_name: str,
        agents: List[tuple],  # [(agent_name, version, traffic_percentage)]
        workload_size: str = "Medium"
    ) -> str:
        """
        Create endpoint serving multiple agents with traffic splitting.

        Each agent becomes one served entity named ``{agent_name}-{version}``
        and receives a route carrying its traffic percentage.

        Args:
            endpoint_name: Endpoint name
            agents: List of (agent_name, version, traffic_percentage) tuples.
                Percentages must be integers summing to 100. If every
                percentage is None, no explicit traffic config is sent.
            workload_size: Endpoint size

        Returns:
            Endpoint URL

        Raises:
            ValueError: If ``agents`` is empty or the percentages are invalid.
        """
        # Fail fast, before any remote call, on an invalid split.
        explicit_split = self._validate_traffic_split(agents)

        logger.info(f"Creating multi-agent endpoint: {endpoint_name}")

        # Build served entities and their matching routes
        served_entities = []
        routes = []
        for agent_name, version, traffic_pct in agents:
            # Name the entity explicitly so the route can reference it.
            served_name = f"{agent_name}-{version}"

            served_entities.append(
                ServedEntityInput(
                    name=served_name,
                    entity_name=self._full_model_name(agent_name),
                    entity_version=version,
                    workload_size=workload_size,
                    scale_to_zero_enabled=True
                )
            )
            if explicit_split:
                routes.append(
                    Route(
                        served_model_name=served_name,
                        traffic_percentage=traffic_pct
                    )
                )

        traffic_config = TrafficConfig(routes=routes) if routes else None

        # Create endpoint
        self.w.serving_endpoints.create_and_wait(
            name=endpoint_name,
            config=EndpointCoreConfigInput(
                served_entities=served_entities,
                traffic_config=traffic_config
            )
        )

        logger.info(f"✅ Created multi-agent endpoint with {len(agents)} agents")

        return self._invocation_url(endpoint_name)

    def test_endpoint(
        self,
        endpoint_name: str,
        test_input: dict,
        timeout: float = 300.0
    ) -> dict:
        """
        Test a deployed endpoint.

        Sends ``test_input`` as a single ``dataframe_records`` row to the
        endpoint's invocation URL, authenticating with the configured token.

        Args:
            endpoint_name: Endpoint name
            test_input: Test input data (one record)
            timeout: Seconds to wait for the HTTP response before giving up

        Returns:
            Prediction result (decoded JSON response body)

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
            requests.Timeout: If no response arrives within ``timeout``.
        """
        import requests

        headers = {
            "Authorization": f"Bearer {settings.databricks_token}",
            "Content-Type": "application/json"
        }

        # Without a timeout, requests would wait forever on a stuck endpoint.
        response = requests.post(
            self._invocation_url(endpoint_name),
            headers=headers,
            json={"dataframe_records": [test_input]},
            timeout=timeout
        )

        response.raise_for_status()
        return response.json()

    def get_endpoint_status(self, endpoint_name: str) -> dict:
        """
        Get endpoint status and its served entities.

        Args:
            endpoint_name: Endpoint name

        Returns:
            Dict with ``name``, ``state`` (the endpoint's ``config_update``
            state, or "Unknown" when no state is reported) and
            ``served_entities`` (name/version/state per entity; empty when the
            endpoint reports no config).
        """
        endpoint = self.w.serving_endpoints.get(name=endpoint_name)

        # Guard against a missing config instead of raising AttributeError.
        config = endpoint.config
        entities = (config.served_entities if config else None) or []

        return {
            "name": endpoint.name,
            "state": endpoint.state.config_update if endpoint.state else "Unknown",
            "served_entities": [
                {
                    "name": entity.name,
                    "version": entity.entity_version,
                    "state": entity.state
                }
                for entity in entities
            ]
        }

    def _get_requirements(self, agent_class) -> List[str]:
        """
        Get pip requirements for an agent.

        ``agent_class`` is currently ignored: the same fixed list is returned
        for every agent.
        """
        return [
            "mlflow>=2.10.0",
            "databricks-agents>=0.1.0",
            "langchain>=0.1.0",
            "openai>=1.6.0",
            "anthropic>=0.8.0",
            "pydantic>=2.5.0",
            "loguru>=0.7.0"
        ]


def deploy_all_agents():
    """
    Register, deploy and smoke-test the policy classifier agent.

    Steps, each reported on stdout:
    1. Register ``PolicyClassifierAgent`` as "policy_classifier".
    2. Deploy that version to the "policy-classifier-prod" endpoint
       (Small workload, scale-to-zero enabled).
    3. Print a ready-to-use ``curl`` command for the endpoint.
    4. Send one sample document to the endpoint and print the result.

    Usage:
        python -m databricks.deployment
    """
    # Imported lazily, as in the original, so importing this module does not
    # require the agents package.
    from agents.mlflow_classifier import PolicyClassifierAgent

    agent_name = "policy_classifier"
    endpoint_name = "policy-classifier-prod"
    sample_document = {
        "document_id": "test_001",
        "title": "City Council Meeting",
        "content": "Discussion on water fluoridation program"
    }

    deployer = AgentDeploymentManager()

    # Step 1: register
    print("\n📦 Deploying Policy Classifier Agent...")
    registered_version = deployer.register_agent(
        agent_class=PolicyClassifierAgent,
        agent_name=agent_name,
        description="Classifies government meeting documents for oral health policy topics",
        tags={"team": "advocacy", "domain": "oral_health"}
    )

    # Step 2: deploy the version that was just registered
    invocation_url = deployer.deploy_agent(
        agent_name=agent_name,
        endpoint_name=endpoint_name,
        version=registered_version,
        workload_size="Small",
        scale_to_zero=True
    )

    # Step 3: tell the operator how to call the endpoint by hand
    print("\n✅ Deployment Complete!")
    print(f"   Endpoint: {invocation_url}")
    print("\n🧪 Test with:")
    print(f"""
    curl -X POST {invocation_url} \\
      -H "Authorization: Bearer $DATABRICKS_TOKEN" \\
      -H "Content-Type: application/json" \\
      -d '{{"dataframe_records": [{{"document_id": "test", "title": "Meeting", "content": "Fluoride discussion..."}}]}}'
    """)

    # Step 4: smoke test
    print("\n🧪 Testing endpoint...")
    prediction = deployer.test_endpoint(
        endpoint_name=endpoint_name,
        test_input=sample_document
    )
    print(f"   Result: {prediction}")


# Script entry point: run the register/deploy/test flow only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    deploy_all_agents()