# Data_eng_designer / agent.py
# Last change: "Update agent.py" by focustiki (commit 86d4a57, verified).
"""
Databricks-Compatible MLflow Agent — Data Engineering Knowledge Assistant
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
• Structured as an MLflow PyFunc model so it can be logged + served on Databricks
• Uses Groq (llama-3.1-8b-instant) for ultra-low-latency responses
• Streaming path: direct RAG (retrieve → stuff → stream) — simple, reliable
• Sync path: tool-calling agent (search_knowledge_base tool) for richer Databricks demos
"""
from __future__ import annotations
import os
import json
from typing import AsyncIterator, List, Dict, Optional
from rag import DataEngineeringRAG
# ──────────────────────────────────────────────────────────────────────────────
# System prompt
# ──────────────────────────────────────────────────────────────────────────────
# Shared system prompt. The streaming path (astream) appends a
# "RETRIEVED CONTEXT" section to it; the tool-calling path (invoke) sends it
# as-is and exposes the search_knowledge_base tool instead.
SYSTEM_PROMPT = """You are an elite Data Engineering Knowledge Assistant, \
specialising in production-grade data pipelines, architecture patterns, and Databricks.
Your knowledge comes from "Data Engineering Design Patterns" — a comprehensive guide \
to solving real data engineering problems.
Guidelines:
1. Ground every answer in the retrieved context provided below.
2. Give concrete, code-inclusive answers when relevant (PySpark / Python / SQL).
3. Reference specific patterns by name (Lambda, Kappa, Medallion, Lakehouse, CDC, etc.).
4. Be direct and technical — the user is a practising data engineer.
5. If the retrieved context doesn't cover the question, say so — never fabricate.
Format:
- Direct answer first
- Code blocks with ```python or ```sql
- Pattern names in **bold**
- End with a "💡 Pro tip:" line when you have a non-obvious insight
"""
# ──────────────────────────────────────────────────────────────────────────────
# Tool schemas (used by sync invoke() for the Databricks demo path)
# ──────────────────────────────────────────────────────────────────────────────
def _function_tool(name: str, description: str, parameters: dict) -> dict:
    """Wrap a JSON-schema parameter spec in the chat-completions "function" tool envelope."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }


# Tools offered to the model by the synchronous invoke() path. The model's
# JSON arguments for search_knowledge_base are routed to the RAG retriever.
TOOLS = [
    _function_tool(
        "search_knowledge_base",
        "Retrieve relevant chunks from the Data Engineering Design Patterns book.",
        {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "k": {"type": "integer", "default": 5},
            },
            "required": ["query"],
        },
    )
]
# ──────────────────────────────────────────────────────────────────────────────
# Agent
# ──────────────────────────────────────────────────────────────────────────────
class DataEngineeringAgent:
    """Groq-backed RAG agent over the "Data Engineering Design Patterns" book.

    Two entry points share the prompt/context helpers below:

    * ``astream`` — async streaming; retrieved chunks are appended to the
      system prompt before a single streamed model call.
    * ``invoke`` / ``predict`` — synchronous tool-calling loop (MLflow path);
      the model pulls context itself via the ``search_knowledge_base`` tool.
    """

    # Groq model shared by the streaming and tool-calling paths.
    MODEL = "llama-3.1-8b-instant"
    # Max tool-resolution rounds in invoke() before a text answer is forced.
    MAX_TOOL_ROUNDS = 3
    # Upper bound on the `k` the model may request from search_knowledge_base,
    # so a bad tool call cannot stuff an unbounded number of chunks into context.
    MAX_TOOL_K = 10
    # Prior chat messages forwarded to the model (3 user/assistant exchanges).
    HISTORY_MESSAGES = 6
    # Roles accepted from caller-supplied history; anything else is dropped.
    _HISTORY_ROLES = ("user", "assistant")

    def __init__(self, rag: DataEngineeringRAG, groq_api_key: str):
        """Store the retriever and API key; Groq clients are created lazily.

        Args:
            rag: Retriever whose ``search(query, k=...)`` returns chunk dicts with
                ``page``, ``score`` and ``content`` keys (as read by ``_build_context``).
            groq_api_key: Key handed to the Groq SDK clients.
        """
        self.rag = rag
        self.groq_api_key = groq_api_key
        self._sync_client = None
        self._async_client = None

    # ── Groq clients (lazy init) ──────────────────────────────────────────────
    def _get_sync_client(self):
        """Return the cached synchronous Groq client, creating it on first use."""
        if self._sync_client is None:
            # Deferred import: groq is only needed once a client is requested.
            from groq import Groq

            self._sync_client = Groq(api_key=self.groq_api_key)
        return self._sync_client

    def _get_async_client(self):
        """Return the cached async Groq client, creating it on first use."""
        if self._async_client is None:
            from groq import AsyncGroq

            self._async_client = AsyncGroq(api_key=self.groq_api_key)
        return self._async_client

    # ── Context builder ───────────────────────────────────────────────────────
    # PDF extractors often emit these invisible / structural Unicode chars.
    # In containers with an ASCII-only default locale (common on minimal Docker
    # images), the HTTP client can fail with `UnicodeEncodeError: 'ascii' codec`
    # when serialising them. Strip them at the source.
    _UNICODE_SCRUB = str.maketrans({
        "\u2028": "\n",      # LINE SEPARATOR
        "\u2029": "\n\n",    # PARAGRAPH SEPARATOR
        "\u200b": "",        # ZERO WIDTH SPACE
        "\u200c": "",        # ZERO WIDTH NON-JOINER
        "\u200d": "",        # ZERO WIDTH JOINER
        "\ufeff": "",        # BYTE ORDER MARK
        "\x00": "",          # NULL
        "\xa0": " ",         # NON-BREAKING SPACE
    })

    @classmethod
    def _sanitize(cls, text: str) -> str:
        """Apply ``_UNICODE_SCRUB`` to ``text``; ``None``/empty input yields ``""``."""
        return (text or "").translate(cls._UNICODE_SCRUB)

    def _build_context(self, query: str, k: int = 5) -> str:
        """Retrieve the top-``k`` chunks for ``query`` and format them as prompt context.

        Each chunk becomes a ``[Source i · Page p · Relevance s]`` header followed by
        its sanitised text; chunks are separated by ``---`` rules. A fixed placeholder
        string is returned when the retriever finds nothing.
        """
        chunks = self.rag.search(query, k=k)
        if not chunks:
            return "(No relevant context found in the knowledge base.)"
        return "\n\n---\n\n".join(
            f"[Source {i} · Page {c['page']} · Relevance {c['score']:.2f}]\n"
            f"{self._sanitize(c['content'])}"
            for i, c in enumerate(chunks, 1)
        )

    def _build_messages(
        self, user_message: str, history: List[Dict], inject_context: bool = True
    ) -> List[Dict]:
        """Build the chat-completions ``messages`` array.

        Args:
            user_message: The new user turn.
            history: Prior turns as ``{"role": ..., "content": ...}`` dicts. Entries
                that are not dicts, whose role is not user/assistant, or whose content
                is empty are skipped; the last ``HISTORY_MESSAGES`` of the rest are kept.
            inject_context: When True, retrieve context for ``user_message`` and append
                it to the system prompt (streaming path). The tool-calling path passes
                False and lets the model call ``search_knowledge_base`` instead.
        """
        system = SYSTEM_PROMPT
        if inject_context:
            context = self._build_context(user_message, k=5)
            system += f"\n\n━━━ RETRIEVED CONTEXT ━━━\n{context}\n━━━━━━━━━━━━━━━━━━━━━━━━"
        messages = [{"role": "system", "content": system}]
        # Filter before truncating so a client-supplied "system"/unknown-role turn
        # cannot override the prompt, and the model still sees up to
        # HISTORY_MESSAGES real turns.
        valid_turns = [
            turn
            for turn in history
            if isinstance(turn, dict)
            and turn.get("role") in self._HISTORY_ROLES
            and turn.get("content")
        ]
        keep_from = max(len(valid_turns) - self.HISTORY_MESSAGES, 0)
        for turn in valid_turns[keep_from:]:
            messages.append(
                {"role": turn["role"], "content": self._sanitize(str(turn["content"]))}
            )
        messages.append({"role": "user", "content": self._sanitize(user_message)})
        return messages

    # ── Async streaming (used by the FastAPI /api/chat endpoint) ──────────────
    async def astream(
        self, message: str, history: Optional[List[Dict]] = None
    ) -> AsyncIterator[str]:
        """Stream a RAG answer for ``message``, yielding text deltas as they arrive.

        Client creation, retrieval and the Groq call all run inside the ``try`` so
        that any failure is reported to the client as trailing error chunks instead
        of aborting the stream mid-response.
        First-token latency on Groq free tier: ~150-300 ms.
        """
        try:
            client = self._get_async_client()
            # NOTE(review): _build_messages calls the retriever synchronously, which
            # blocks the event loop during the search; consider offloading it to a
            # worker thread once rag.search is confirmed thread-safe.
            messages = self._build_messages(message, history or [], inject_context=True)
            stream = await client.chat.completions.create(
                model=self.MODEL,
                messages=messages,
                temperature=0.3,
                max_tokens=2048,
                stream=True,
            )
            async for chunk in stream:
                # Skip chunks with an empty `choices` list instead of raising IndexError.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta
        except Exception as exc:
            # Top-level boundary: expose the real error to the client so debugging is easy.
            yield f"\n\n⚠️ **Agent error:** `{type(exc).__name__}: {exc}`\n\n"
            yield "Common causes: missing or invalid GROQ_API_KEY, Groq rate limit hit, network issue."

    # ── Sync invoke with tool use (Databricks / MLflow path) ──────────────────
    def invoke(self, message: str, history: Optional[List[Dict]] = None) -> str:
        """Answer ``message`` synchronously, letting the model call tools first.

        The model may request ``search_knowledge_base`` for up to ``MAX_TOOL_ROUNDS``
        rounds. If it still asks for tools after that, one final call with
        ``tool_choice="none"`` forces a text answer from the context gathered so far
        (previously this returned the "(no content generated)" placeholder).

        Returns:
            The model's final text, or ``"(no content generated)"`` if it is empty.
        """
        client = self._get_sync_client()
        messages = self._build_messages(message, history or [], inject_context=False)
        msg = self._complete(client, messages)
        for _ in range(self.MAX_TOOL_ROUNDS):
            if not msg.tool_calls:
                break
            messages.append(msg)
            for tc in msg.tool_calls:
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": self._run_tool(tc.function.name, tc.function.arguments),
                    }
                )
            msg = self._complete(client, messages)
        if msg.tool_calls:
            # Round budget exhausted: leave the last tool request unanswered (it is
            # not appended) and demand a plain-text answer instead.
            msg = self._complete(client, messages, tool_choice="none")
        return msg.content or "(no content generated)"

    def _complete(self, client, messages: List[Dict], tool_choice: str = "auto"):
        """Run one non-streaming chat completion with ``TOOLS`` exposed; return its message."""
        response = client.chat.completions.create(
            model=self.MODEL,
            messages=messages,
            tools=TOOLS,
            tool_choice=tool_choice,
            temperature=0.2,
            max_tokens=2048,
        )
        return response.choices[0].message

    def _run_tool(self, name: str, raw_arguments: Optional[str]) -> str:
        """Execute one model-requested tool call and return its textual result.

        Malformed arguments are reported back to the model as text rather than
        raised, so one bad tool call from the LLM does not abort the request.
        """
        if name != "search_knowledge_base":
            return f"Unknown tool: {name}"
        try:
            args = json.loads(raw_arguments or "{}")
        except json.JSONDecodeError as exc:
            return f"Invalid arguments for {name}: {exc}"
        if not isinstance(args, dict):
            return f"Invalid arguments for {name}: expected a JSON object"
        query = args.get("query")
        if not isinstance(query, str) or not query.strip():
            return f"Invalid arguments for {name}: 'query' must be a non-empty string"
        return self._build_context(query, self._coerce_k(args.get("k", 5)))

    @classmethod
    def _coerce_k(cls, value, default: int = 5) -> int:
        """Convert a model-supplied ``k`` to an int clamped to ``[1, MAX_TOOL_K]``."""
        try:
            k = int(value)
        except (TypeError, ValueError):
            return default
        return max(1, min(k, cls.MAX_TOOL_K))

    # ── MLflow PyFunc interface ───────────────────────────────────────────────
    def predict(self, context, model_input) -> str:
        """MLflow PyFunc entry point: unpack one request and delegate to ``invoke``.

        Args:
            context: MLflow model context (not used here).
            model_input: A ``pandas.DataFrame`` (only the first row is read) or a
                mapping, with a ``message`` field and an optional ``history`` field.
                ``history`` may be a list of turn dicts or a JSON-encoded string in
                either form; missing/NaN values are treated as empty.
        """
        import pandas as pd

        if isinstance(model_input, pd.DataFrame):
            row = model_input.iloc[0]
            message = row.get("message", "")
            history = row.get("history", [])
        else:
            message = model_input.get("message", "")
            history = model_input.get("history", [])
        return self.invoke(
            message="" if self._is_missing(message) else message,
            history=self._coerce_history(history),
        )

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for ``None`` and float NaN (pandas' placeholder for empty cells)."""
        # NaN is the only float value that is not equal to itself.
        return value is None or (isinstance(value, float) and value != value)

    @classmethod
    def _coerce_history(cls, history) -> List[Dict]:
        """Normalise a request's ``history`` field into a list of turn dicts."""
        if isinstance(history, str):
            # Tabular inputs often carry history as JSON text; blank means "none".
            history = json.loads(history) if history.strip() else None
        if cls._is_missing(history):
            return []
        return list(history)
# ──────────────────────────────────────────────────────────────────────────────
# MLflow wrapper (for Databricks Model Serving registration)
# ──────────────────────────────────────────────────────────────────────────────
class DEAgentPyFunc:
    """MLflow PyFunc-style wrapper exposing ``DataEngineeringAgent`` to Model Serving.

    NOTE(review): ``mlflow.pyfunc.log_model(python_model=...)`` normally expects an
    ``mlflow.pyfunc.PythonModel`` subclass; this class implements the same two hooks
    (``load_context`` / ``predict``) but does not inherit from it — confirm it is
    subclassed or adapted where the model is logged.
    """

    def load_context(self, context):
        """Build the RAG retriever and agent once, when the model is loaded.

        The PDF path comes from the ``pdf_path`` artifact (falling back to the
        bundled default) and the Groq key from the ``GROQ_API_KEY`` env var.
        """
        knowledge_pdf = context.artifacts.get(
            "pdf_path", "knowledge/data_engineering_patterns.pdf"
        )
        api_key = os.environ.get("GROQ_API_KEY", "")
        self.rag = DataEngineeringRAG(pdf_path=knowledge_pdf, groq_api_key=api_key)
        self.rag.initialize()
        self.agent = DataEngineeringAgent(rag=self.rag, groq_api_key=api_key)

    def predict(self, context, model_input):
        """Hand the serving request straight to the wrapped agent and return its answer."""
        agent = self.agent
        return agent.predict(context, model_input)