""" MLflow-based agent foundation for Databricks Agent Bricks. Provides: - MLflow Pyfunc model wrappers for agents - Unity Catalog integration - Automatic tracing and observability - Model serving compatibility """ from typing import Any, Dict, List, Optional, Union from abc import ABC, abstractmethod import mlflow from mlflow.pyfunc import PythonModel from mlflow.models import infer_signature from mlflow.tracking import MlflowClient import pandas as pd from datetime import datetime from loguru import logger from agents.base import AgentRole, AgentMessage, AgentStatus from config import settings class MLflowAgentBase(PythonModel, ABC): """ Base class for agents that can be deployed via MLflow Model Serving. Integrates with: - Unity Catalog for governance - MLflow Tracking for experimentation - Databricks Model Serving for deployment - Mosaic AI Agent Framework for evaluation """ def __init__(self, agent_id: str, role: AgentRole): """ Initialize MLflow agent. Args: agent_id: Unique identifier for this agent role: Agent role in the pipeline """ super().__init__() self.agent_id = agent_id self.role = role self.status = AgentStatus.IDLE self.client = MlflowClient() def load_context(self, context): """ Load agent context from MLflow (called during model loading). Args: context: MLflow context with model artifacts """ logger.info(f"Loading {self.role.value} agent from MLflow context") # Load any model artifacts, configs, etc. pass @abstractmethod def _process_request(self, request: Dict[str, Any]) -> Dict[str, Any]: """ Process a single agent request. Args: request: Input request dictionary Returns: Response dictionary """ pass def predict( self, context, model_input: Union[pd.DataFrame, Dict[str, Any], List[Dict[str, Any]]] ) -> Union[pd.DataFrame, List[Dict[str, Any]]]: """ MLflow Pyfunc predict interface. This is the main entry point when the agent is deployed as a Model Serving endpoint. Args: context: MLflow context model_input: Input data (DataFrame, dict, or list of dicts) Returns: Predictions in same format as input """ # Enable MLflow tracing for observability with mlflow.start_span(name=f"{self.role.value}_agent") as span: span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_role", self.role.value) try: # Convert input to standard format if isinstance(model_input, pd.DataFrame): requests = model_input.to_dict('records') return_df = True elif isinstance(model_input, dict): requests = [model_input] return_df = False else: requests = model_input return_df = False # Process each request with tracing results = [] for idx, request in enumerate(requests): with mlflow.start_span(name=f"process_request_{idx}") as req_span: req_span.set_attribute("request_id", request.get("request_id", f"req_{idx}")) try: result = self._process_request(request) result["status"] = "success" result["agent_id"] = self.agent_id result["timestamp"] = datetime.utcnow().isoformat() results.append(result) req_span.set_attribute("status", "success") except Exception as e: error_result = { "status": "error", "error": str(e), "agent_id": self.agent_id, "timestamp": datetime.utcnow().isoformat() } results.append(error_result) req_span.set_attribute("status", "error") req_span.set_attribute("error", str(e)) logger.error(f"Error processing request {idx}: {e}") # Return in requested format if return_df: return pd.DataFrame(results) elif len(results) == 1 and not isinstance(model_input, list): return results[0] else: return results except Exception as e: span.set_attribute("status", "error") span.set_attribute("error", str(e)) logger.error(f"Error in {self.role.value} agent: {e}") raise def log_to_mlflow( self, model_name: str, artifact_path: str = "agent", registered_model_name: Optional[str] = None, **kwargs ): """ Log this agent to MLflow. Args: model_name: Name for the MLflow run artifact_path: Path within the run to store the model registered_model_name: Unity Catalog model name (e.g., "main.agents.scraper") **kwargs: Additional MLflow logging parameters """ with mlflow.start_run(run_name=model_name) as run: # Log agent metadata mlflow.log_param("agent_id", self.agent_id) mlflow.log_param("agent_role", self.role.value) mlflow.log_param("framework", "databricks-agent-bricks") # Create example input/output for signature example_input = self._get_example_input() example_output = self.predict(None, example_input) signature = infer_signature(example_input, example_output) # Log the model mlflow.pyfunc.log_model( artifact_path=artifact_path, python_model=self, signature=signature, registered_model_name=registered_model_name, **kwargs ) logger.info(f"Logged {self.role.value} agent to MLflow run {run.info.run_id}") if registered_model_name: logger.info(f"Registered model as {registered_model_name}") return run.info.run_id @abstractmethod def _get_example_input(self) -> Union[pd.DataFrame, Dict[str, Any]]: """ Get example input for MLflow signature inference. Returns: Example input data """ pass def deploy_to_model_serving( self, model_name: str, endpoint_name: str, workload_size: str = "Small", scale_to_zero: bool = True ) -> str: """ Deploy this agent to Databricks Model Serving. Args: model_name: Registered model name in Unity Catalog endpoint_name: Name for the serving endpoint workload_size: Endpoint size (Small, Medium, Large) scale_to_zero: Whether to scale to zero when idle Returns: Endpoint URL """ from databricks.sdk import WorkspaceClient from databricks.sdk.service.serving import ServedEntityInput, EndpointCoreConfigInput w = WorkspaceClient( host=settings.databricks_host, token=settings.databricks_token ) # Get latest model version latest_version = self.client.get_latest_versions(model_name, stages=["None"])[0].version # Create or update endpoint endpoint_config = EndpointCoreConfigInput( name=endpoint_name, served_entities=[ ServedEntityInput( entity_name=model_name, entity_version=latest_version, workload_size=workload_size, scale_to_zero_enabled=scale_to_zero ) ] ) try: endpoint = w.serving_endpoints.create_and_wait( name=endpoint_name, config=endpoint_config ) logger.info(f"Created endpoint: {endpoint_name}") except Exception as e: if "already exists" in str(e): endpoint = w.serving_endpoints.update_config_and_wait( name=endpoint_name, served_entities=endpoint_config.served_entities ) logger.info(f"Updated endpoint: {endpoint_name}") else: raise endpoint_url = f"{settings.databricks_host}/serving-endpoints/{endpoint_name}/invocations" return endpoint_url class MLflowChainAgent(MLflowAgentBase): """ Agent that uses LangChain with MLflow tracing. Provides integration with: - LangChain agents and chains - Automatic prompt logging - LLM call tracing - Tool usage tracking """ def __init__(self, agent_id: str, role: AgentRole): """Initialize LangChain-based agent.""" super().__init__(agent_id, role) self.chain = None def _setup_langchain_tracing(self): """Enable MLflow tracing for LangChain.""" mlflow.langchain.autolog() @abstractmethod def _build_chain(self): """ Build the LangChain chain for this agent. Returns: LangChain chain or agent """ pass def _process_request(self, request: Dict[str, Any]) -> Dict[str, Any]: """Process request through LangChain.""" if self.chain is None: self.chain = self._build_chain() with mlflow.start_span(name="langchain_invoke") as span: result = self.chain.invoke(request) # Log relevant metrics if hasattr(result, "llm_output"): span.set_attribute("tokens_used", result.llm_output.get("token_usage", {}).get("total_tokens", 0)) return result