""" Databricks-Compatible MLflow Agent — Data Engineering Knowledge Assistant ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ • Structured as an MLflow PyFunc model so it can be logged + served on Databricks • Uses Groq (llama-3.1-8b-instant) for ultra-low-latency responses • Streaming path: direct RAG (retrieve → stuff → stream) — simple, reliable • Sync path: tool-calling agent (search, code_gen) for richer Databricks demos """ from __future__ import annotations import os import json from typing import AsyncIterator, List, Dict, Optional from rag import DataEngineeringRAG # ────────────────────────────────────────────────────────────────────────────── # System prompt # ────────────────────────────────────────────────────────────────────────────── SYSTEM_PROMPT = """You are an elite Data Engineering Knowledge Assistant, \ specialising in production-grade data pipelines, architecture patterns, and Databricks. Your knowledge comes from "Data Engineering Design Patterns" — a comprehensive guide \ to solving real data engineering problems. Guidelines: 1. Ground every answer in the retrieved context provided below. 2. Give concrete, code-inclusive answers when relevant (PySpark / Python / SQL). 3. Reference specific patterns by name (Lambda, Kappa, Medallion, Lakehouse, CDC, etc.). 4. Be direct and technical — the user is a practising data engineer. 5. If the retrieved context doesn't cover the question, say so — never fabricate. 
Format:
- Direct answer first
- Code blocks with ```python or ```sql
- Pattern names in **bold**
- End with a "💡 Pro tip:" line when you have a non-obvious insight
"""

# ──────────────────────────────────────────────────────────────────────────────
# Tool schemas (used by sync invoke() for the Databricks demo path)
# ──────────────────────────────────────────────────────────────────────────────

TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "search_knowledge_base",
            "description": "Retrieve relevant chunks from the Data Engineering Design Patterns book.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"},
                    "k": {"type": "integer", "default": 5},
                },
                "required": ["query"],
            },
        },
    }
]

# ──────────────────────────────────────────────────────────────────────────────
# Agent
# ──────────────────────────────────────────────────────────────────────────────


class DataEngineeringAgent:
    def __init__(self, rag: DataEngineeringRAG, groq_api_key: str):
        self.rag = rag
        self.groq_api_key = groq_api_key
        self._sync_client = None
        self._async_client = None

    # ── Groq clients (lazy init) ──────────────────────────────────────────────
    def _get_sync_client(self):
        if self._sync_client is None:
            from groq import Groq

            self._sync_client = Groq(api_key=self.groq_api_key)
        return self._sync_client

    def _get_async_client(self):
        if self._async_client is None:
            from groq import AsyncGroq

            self._async_client = AsyncGroq(api_key=self.groq_api_key)
        return self._async_client

    # ── Context builder ───────────────────────────────────────────────────────
    # PDF extractors often emit these invisible / structural Unicode chars.
    # In containers with an ASCII-only default locale (common on minimal Docker
    # images), the HTTP client can fail with `UnicodeEncodeError: 'ascii' codec`
    # when serialising them. Strip them at the source.
    _UNICODE_SCRUB = str.maketrans({
        "\u2028": "\n",    # LINE SEPARATOR
        "\u2029": "\n\n",  # PARAGRAPH SEPARATOR
        "\u200b": "",      # ZERO WIDTH SPACE
        "\u200c": "",      # ZERO WIDTH NON-JOINER
        "\u200d": "",      # ZERO WIDTH JOINER
        "\ufeff": "",      # BYTE ORDER MARK
        "\x00": "",        # NULL
        "\xa0": " ",       # NON-BREAKING SPACE
    })

    @classmethod
    def _sanitize(cls, text: str) -> str:
        return (text or "").translate(cls._UNICODE_SCRUB)

    def _build_context(self, query: str, k: int = 5) -> str:
        """Retrieve top-k chunks and format as prompt context."""
        chunks = self.rag.search(query, k=k)
        if not chunks:
            return "(No relevant context found in the knowledge base.)"

        formatted = []
        for i, c in enumerate(chunks, 1):
            formatted.append(
                f"[Source {i} · Page {c['page']} · Relevance {c['score']:.2f}]\n"
                f"{self._sanitize(c['content'])}"
            )
        return "\n\n---\n\n".join(formatted)

    def _build_messages(
        self, user_message: str, history: List[Dict], inject_context: bool = True
    ) -> List[Dict]:
        """Build the chat-completions messages array."""
        system = SYSTEM_PROMPT
        if inject_context:
            context = self._build_context(user_message, k=5)
            system += f"\n\n━━━ RETRIEVED CONTEXT ━━━\n{context}\n━━━━━━━━━━━━━━━━━━━━━━━━"

        messages = [{"role": "system", "content": system}]
        # Keep last 3 exchanges (6 messages) for continuity
        for turn in history[-6:]:
            messages.append({"role": turn["role"], "content": turn["content"]})
        messages.append({"role": "user", "content": user_message})
        return messages

    # ── Async streaming (used by the FastAPI /api/chat endpoint) ──────────────
    async def astream(
        self, message: str, history: Optional[List[Dict]] = None
    ) -> AsyncIterator[str]:
        """
        Streaming RAG response. Yields string chunks as the model generates.
        First-token latency on Groq free tier: ~150-300 ms.
""" client = self._get_async_client() messages = self._build_messages(message, history or [], inject_context=True) try: stream = await client.chat.completions.create( model="llama-3.1-8b-instant", messages=messages, temperature=0.3, max_tokens=2048, stream=True, ) async for chunk in stream: delta = chunk.choices[0].delta.content if delta: yield delta except Exception as exc: # Expose the real error to the client so debugging is easy yield f"\n\n⚠️ **Agent error:** `{type(exc).__name__}: {exc}`\n\n" yield "Common causes: missing or invalid GROQ_API_KEY, Groq rate limit hit, network issue." # ── Sync invoke with tool use (Databricks / MLflow path) ────────────────── def invoke(self, message: str, history: Optional[List[Dict]] = None) -> str: """Single-turn synchronous call — used by the MLflow PyFunc wrapper.""" client = self._get_sync_client() messages = self._build_messages(message, history or [], inject_context=False) # Let the model decide if it wants to search response = client.chat.completions.create( model="llama-3.1-8b-instant", messages=messages, tools=TOOLS, tool_choice="auto", temperature=0.2, max_tokens=2048, ) msg = response.choices[0].message # Tool-resolution loop (max 3 iterations to prevent infinite cycles) for _ in range(3): if not msg.tool_calls: break messages.append(msg) for tc in msg.tool_calls: args = json.loads(tc.function.arguments) if tc.function.name == "search_knowledge_base": tool_result = self._build_context(args["query"], args.get("k", 5)) else: tool_result = f"Unknown tool: {tc.function.name}" messages.append( {"role": "tool", "tool_call_id": tc.id, "content": tool_result} ) response = client.chat.completions.create( model="llama-3.1-8b-instant", messages=messages, tools=TOOLS, tool_choice="auto", temperature=0.2, max_tokens=2048, ) msg = response.choices[0].message return msg.content or "(no content generated)" # ── MLflow PyFunc interface ─────────────────────────────────────────────── def predict(self, context, model_input) -> str: 
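        import pandas as pd

        if isinstance(model_input, pd.DataFrame):
            row = model_input.iloc[0]
            message = row.get("message", "")
            history = row.get("history", [])
        else:
            message = model_input.get("message", "")
            history = model_input.get("history", [])

        # Serving payloads may deliver history as a JSON string, as None/NaN when
        # missing, or as a NumPy array; normalise to a plain list of dicts.
        if isinstance(history, str):
            history = json.loads(history) if history.strip() else []
        if history is None or (isinstance(history, float) and pd.isna(history)):
            history = []
        history = list(history)

        return self.invoke(message=message, history=history)


# ──────────────────────────────────────────────────────────────────────────────
# MLflow wrapper (for Databricks Model Serving registration)
# ──────────────────────────────────────────────────────────────────────────────

try:
    # mlflow.pyfunc.log_model(python_model=...) expects a PythonModel subclass;
    # mlflow is only needed on the Databricks path, so keep it optional here.
    from mlflow.pyfunc import PythonModel
except ImportError:
    PythonModel = object


class DEAgentPyFunc(PythonModel):
    def load_context(self, context):
        pdf_path = context.artifacts.get(
            "pdf_path", "knowledge/data_engineering_patterns.pdf"
        )
        groq_key = os.environ.get("GROQ_API_KEY", "")
        self.rag = DataEngineeringRAG(pdf_path=pdf_path, groq_api_key=groq_key)
        self.rag.initialize()
        self.agent = DataEngineeringAgent(rag=self.rag, groq_api_key=groq_key)

    def predict(self, context, model_input, params=None):
        return self.agent.predict(context, model_input)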
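

# ──────────────────────────────────────────────────────────────────────────────
# Usage sketch (not executed): logging and calling the wrapper with MLflow
# ──────────────────────────────────────────────────────────────────────────────
# A minimal sketch, assuming MLflow 2.x and that the book PDF sits at the same
# default path load_context() falls back to. The run name, the "de_agent"
# artifact path and the <run_id> placeholder are illustrative, not part of
# this module.
#
#   import mlflow
#   import pandas as pd
#
#   with mlflow.start_run(run_name="de-knowledge-agent"):
#       mlflow.pyfunc.log_model(
#           artifact_path="de_agent",
#           python_model=DEAgentPyFunc(),
#           # Copied into the model; exposed as context.artifacts["pdf_path"]
#           artifacts={"pdf_path": "knowledge/data_engineering_patterns.pdf"},
#       )
#
#   model = mlflow.pyfunc.load_model("runs:/<run_id>/de_agent")
#   # predict() reads the first row; "history" may be a JSON string
#   model.predict(pd.DataFrame([{"message": "When should I use CDC?", "history": "[]"}]))
#
# load_context() reads GROQ_API_KEY from the environment, so the serving
# endpoint must define it. This module and rag.py must also be importable when
# the model loads (for example via log_model's code-path option, whose argument
# name differs between MLflow versions).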