# open-navigator / databricks / deployment.py
# jcbowyer's picture
# Clean HuggingFace deployment without binary files
# 61d29fc
"""
Deployment scripts for Databricks Agent Bricks.
Handles:
- Unity Catalog registration
- Model Serving deployment
- Endpoint management
- Version promotion
"""
import os
from typing import Optional, List
import mlflow
from mlflow.tracking import MlflowClient
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
ServedEntityInput,
EndpointCoreConfigInput,
TrafficConfig,
Route
)
from databricks.sdk.service.catalog import (
ModelVersionInfoStatus
)
from loguru import logger
from config import settings
class AgentDeploymentManager:
    """
    Manages deployment of agents to Databricks Model Serving.

    Workflow:
        1. Register agent to Unity Catalog
        2. Create/update serving endpoint
        3. Monitor deployment status
        4. Promote versions (dev -> staging -> prod) -- no method of this
           class implements promotion yet.
    """

    def __init__(self):
        """Create the MLflow and Databricks workspace clients from ``settings``."""
        # NOTE(review): three-level Unity Catalog model names require the
        # MLflow registry URI to be "databricks-uc"; presumably that is
        # configured outside this class -- confirm.
        self.client = MlflowClient()
        self.w = WorkspaceClient(
            host=settings.databricks_host,
            token=settings.databricks_token,
        )
        self.catalog = settings.catalog_name
        self.schema = settings.schema_name

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _full_model_name(self, agent_name: str) -> str:
        """Return the three-level ``catalog.schema.agent_name`` model name."""
        return f"{self.catalog}.{self.schema}.{agent_name}"

    @staticmethod
    def _invocation_url(endpoint_name: str) -> str:
        """Return the invocation URL of *endpoint_name* on the configured host."""
        # Strip a trailing slash so the path never starts with "//".
        host = str(settings.databricks_host).rstrip("/")
        return f"{host}/serving-endpoints/{endpoint_name}/invocations"

    def _latest_version(self, full_model_name: str) -> str:
        """
        Return the highest registered version number of a model.

        ``search_model_versions`` is used rather than the stage-based
        ``get_latest_versions`` lookup, which relies on model stages.

        Raises:
            LookupError: If the model has no registered versions.
        """
        versions = self.client.search_model_versions(f"name='{full_model_name}'")
        if not versions:
            raise LookupError(f"No registered versions found for {full_model_name}")
        # Compare numerically: as strings, "9" would sort above "10".
        return str(max(int(mv.version) for mv in versions))

    @staticmethod
    def _replica_kwargs(scale_to_zero: bool, min_replicas: int, max_replicas: int) -> dict:
        """
        Return the replica-bound keyword arguments ``ServedEntityInput`` accepts.

        NOTE(review): current databricks-sdk releases do not appear to declare
        ``min_replicas`` / ``max_replicas`` on ``ServedEntityInput`` -- confirm.
        Passing an undeclared field to the dataclass raises ``TypeError``, so
        each bound is forwarded only when the installed SDK declares it.
        """
        from dataclasses import fields

        declared = {f.name for f in fields(ServedEntityInput)}
        requested = {
            # Scale-to-zero implies a floor of zero replicas.
            "min_replicas": 0 if scale_to_zero else min_replicas,
            "max_replicas": max_replicas,
        }
        accepted = {key: val for key, val in requested.items() if key in declared}
        ignored = sorted(set(requested) - set(accepted))
        if ignored:
            logger.warning(
                f"ServedEntityInput does not support {ignored}; replica bounds ignored"
            )
        return accepted

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def register_agent(
        self,
        agent_class,
        agent_name: str,
        description: str,
        tags: Optional[dict] = None
    ) -> str:
        """
        Register an agent to Unity Catalog.

        The agent is instantiated, logged as an MLflow pyfunc model inside a
        new run, and registered under ``catalog.schema.agent_name``.

        Args:
            agent_class: Agent class to instantiate (called with no arguments).
                Instances must provide ``_get_example_input()`` and
                ``predict(context, model_input)``.
            agent_name: Name for the registered model (e.g., "policy_classifier")
            description: Model description; skipped when empty.
            tags: Optional tags, set on the MLflow run.

        Returns:
            Model version string

        Raises:
            LookupError: If the registered version cannot be determined.
        """
        full_model_name = self._full_model_name(agent_name)
        logger.info(f"Registering agent: {full_model_name}")

        # Instantiate agent
        agent = agent_class()

        # Log to MLflow and register
        with mlflow.start_run(run_name=f"register_{agent_name}"):
            # Log agent metadata
            mlflow.log_param("agent_name", agent_name)
            mlflow.log_param("catalog", self.catalog)
            mlflow.log_param("schema", self.schema)
            if tags:
                mlflow.set_tags(tags)

            # Run the agent once on its own example to infer the signature.
            example_input = agent._get_example_input()
            example_output = agent.predict(None, example_input)
            signature = mlflow.models.infer_signature(example_input, example_output)

            # Log model
            model_info = mlflow.pyfunc.log_model(
                artifact_path="agent",
                python_model=agent,
                signature=signature,
                registered_model_name=full_model_name,
                pip_requirements=self._get_requirements(agent_class),
            )

        # Prefer the version reported by log_model itself (exactly the version
        # this call created); older MLflow releases may not expose it, in which
        # case fall back to the highest version in the registry.
        registered = getattr(model_info, "registered_model_version", None)
        if registered is not None:
            version = str(registered)
        else:
            version = self._latest_version(full_model_name)

        # Update model description
        if description:
            self.client.update_registered_model(
                name=full_model_name,
                description=description,
            )

        logger.info(f"✅ Registered {full_model_name} version {version}")
        return version

    def deploy_agent(
        self,
        agent_name: str,
        endpoint_name: str,
        version: Optional[str] = None,
        workload_size: str = "Small",
        scale_to_zero: bool = True,
        min_replicas: int = 1,
        max_replicas: int = 10
    ) -> str:
        """
        Deploy agent to Model Serving endpoint.

        Creates the endpoint, or updates its config when creation fails with
        an "already exists" error. Both paths block until the call returns.

        Args:
            agent_name: Registered model name (without catalog/schema prefix)
            endpoint_name: Serving endpoint name
            version: Model version (defaults to the highest registered one)
            workload_size: Endpoint size (Small, Medium, Large)
            scale_to_zero: Enable scale-to-zero
            min_replicas: Minimum replicas (forced to 0 with scale-to-zero)
            max_replicas: Maximum replicas

        Returns:
            Endpoint URL

        Raises:
            LookupError: If ``version`` is omitted and no version is registered.
        """
        full_model_name = self._full_model_name(agent_name)

        # Get version if not specified
        if version is None:
            version = self._latest_version(full_model_name)

        logger.info(f"Deploying {full_model_name} v{version} to {endpoint_name}")

        # Configure served entity
        served_entity = ServedEntityInput(
            entity_name=full_model_name,
            entity_version=version,
            workload_size=workload_size,
            scale_to_zero_enabled=scale_to_zero,
            **self._replica_kwargs(scale_to_zero, min_replicas, max_replicas),
        )

        # Create or update endpoint
        try:
            self.w.serving_endpoints.create_and_wait(
                name=endpoint_name,
                config=EndpointCoreConfigInput(served_entities=[served_entity]),
            )
        except Exception as e:
            # Only a name collision is recoverable; anything else propagates.
            if "already exists" not in str(e).lower():
                raise
            # Update existing endpoint
            self.w.serving_endpoints.update_config_and_wait(
                name=endpoint_name,
                served_entities=[served_entity],
            )
            logger.info(f"✅ Updated endpoint: {endpoint_name}")
        else:
            logger.info(f"✅ Created endpoint: {endpoint_name}")

        # Return invocation URL
        endpoint_url = self._invocation_url(endpoint_name)
        logger.info(f" Endpoint URL: {endpoint_url}")
        return endpoint_url

    def create_multi_agent_endpoint(
        self,
        endpoint_name: str,
        agents: List[tuple],  # [(agent_name, version, traffic_percentage)]
        workload_size: str = "Medium"
    ) -> str:
        """
        Create endpoint serving multiple agents with traffic splitting.

        Each agent becomes a served entity named ``<agent_name>-<version>``,
        and a matching traffic route carries its percentage.

        Args:
            endpoint_name: Endpoint name
            agents: List of (agent_name, version, traffic_percentage) tuples;
                the percentages must sum to 100.
            workload_size: Endpoint size, applied to every served entity

        Returns:
            Endpoint URL

        Raises:
            ValueError: If ``agents`` is empty or the percentages do not sum
                to 100.
        """
        # Validate up front so a bad split never reaches the serving API.
        if not agents:
            raise ValueError(
                "agents must contain at least one "
                "(agent_name, version, traffic_percentage) tuple"
            )
        total_traffic = sum(pct for _, _, pct in agents)
        if total_traffic != 100:
            raise ValueError(
                f"Traffic percentages must sum to 100, got {total_traffic}"
            )

        logger.info(f"Creating multi-agent endpoint: {endpoint_name}")

        # Build served entities and the routes that point at them
        served_entities = []
        routes = []
        for agent_name, version, traffic_pct in agents:
            # Explicit name so the route can reference this served entity.
            served_name = f"{agent_name}-{version}"
            served_entities.append(
                ServedEntityInput(
                    name=served_name,
                    entity_name=self._full_model_name(agent_name),
                    entity_version=version,
                    workload_size=workload_size,
                    scale_to_zero_enabled=True,
                )
            )
            routes.append(
                Route(served_model_name=served_name, traffic_percentage=traffic_pct)
            )

        # Create endpoint
        self.w.serving_endpoints.create_and_wait(
            name=endpoint_name,
            config=EndpointCoreConfigInput(
                served_entities=served_entities,
                traffic_config=TrafficConfig(routes=routes),
            ),
        )
        logger.info(f"✅ Created multi-agent endpoint with {len(agents)} agents")
        return self._invocation_url(endpoint_name)

    def test_endpoint(
        self,
        endpoint_name: str,
        test_input: dict,
        timeout: float = 60.0
    ) -> dict:
        """
        Test a deployed endpoint.

        Posts ``test_input`` as a single ``dataframe_records`` row to the
        endpoint's invocation URL.

        Args:
            endpoint_name: Endpoint name
            test_input: Test input data
            timeout: Seconds to wait for the HTTP request before giving up

        Returns:
            Prediction result (decoded JSON response body)

        Raises:
            requests.HTTPError: If the endpoint responds with an error status.
            requests.Timeout: If no response arrives within ``timeout``.
        """
        import requests

        headers = {
            "Authorization": f"Bearer {settings.databricks_token}",
            "Content-Type": "application/json",
        }
        response = requests.post(
            self._invocation_url(endpoint_name),
            headers=headers,
            json={"dataframe_records": [test_input]},
            # Without a timeout a stalled endpoint blocks the caller forever.
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()

    def get_endpoint_status(self, endpoint_name: str) -> dict:
        """
        Get endpoint status and its served entities.

        Returns:
            Dict with ``name``, ``state`` (the endpoint's ``config_update``
            state, or "Unknown" when the endpoint reports no state) and
            ``served_entities`` (name, version and state of each entity).
        """
        endpoint = self.w.serving_endpoints.get(name=endpoint_name)
        # ``config`` can be absent on the returned endpoint; treat that as
        # "no served entities" instead of raising AttributeError.
        entities = (endpoint.config.served_entities or []) if endpoint.config else []
        return {
            "name": endpoint.name,
            "state": endpoint.state.config_update if endpoint.state else "Unknown",
            "served_entities": [
                {
                    "name": entity.name,
                    "version": entity.entity_version,
                    "state": entity.state,
                }
                for entity in entities
            ],
        }

    def _get_requirements(self, agent_class) -> List[str]:
        """
        Get pip requirements for an agent.

        The same fixed list is returned for every agent; ``agent_class`` is
        currently unused.
        """
        return [
            "mlflow>=2.10.0",
            "databricks-agents>=0.1.0",
            "langchain>=0.1.0",
            "openai>=1.6.0",
            "anthropic>=0.8.0",
            "pydantic>=2.5.0",
            "loguru>=0.7.0",
        ]
def deploy_all_agents():
    """
    Register, deploy and smoke-test every agent on Databricks Model Serving.

    Currently this covers the policy classifier only: it is registered,
    deployed to the ``policy-classifier-prod`` endpoint, a sample curl
    command is printed, and one test request is sent to the endpoint.

    Usage:
        python -m databricks.deployment
    """
    from agents.mlflow_classifier import PolicyClassifierAgent

    classifier_name = "policy_classifier"
    prod_endpoint = "policy-classifier-prod"
    deployer = AgentDeploymentManager()

    # Step 1: register the classifier.
    print("\n📦 Deploying Policy Classifier Agent...")
    registered_version = deployer.register_agent(
        agent_class=PolicyClassifierAgent,
        agent_name=classifier_name,
        description="Classifies government meeting documents for oral health policy topics",
        tags={"team": "advocacy", "domain": "oral_health"},
    )

    # Step 2: serve the version that was just registered.
    endpoint_url = deployer.deploy_agent(
        agent_name=classifier_name,
        endpoint_name=prod_endpoint,
        version=registered_version,
        workload_size="Small",
        scale_to_zero=True,
    )

    # Step 3: report the endpoint and show how to call it by hand.
    print("\n✅ Deployment Complete!")
    print(f" Endpoint: {endpoint_url}")
    print("\n🧪 Test with:")
    curl_example = f"""
curl -X POST {endpoint_url} \\
-H "Authorization: Bearer $DATABRICKS_TOKEN" \\
-H "Content-Type: application/json" \\
-d '{{"dataframe_records": [{{"document_id": "test", "title": "Meeting", "content": "Fluoride discussion..."}}]}}'
"""
    print(curl_example)

    # Step 4: send one sample document through the live endpoint.
    print("\n🧪 Testing endpoint...")
    sample_document = {
        "document_id": "test_001",
        "title": "City Council Meeting",
        "content": "Discussion on water fluoridation program",
    }
    result = deployer.test_endpoint(
        endpoint_name=prod_endpoint,
        test_input=sample_document,
    )
    print(f" Result: {result}")
# Script entry point: run the full deployment only when this module is
# executed directly, not when it is imported.
if __name__ == "__main__":
    deploy_all_agents()