Spaces:
Sleeping
Sleeping
"""
Databricks-Compatible MLflow Agent — Data Engineering Knowledge Assistant
─────────────────────────────────────────────────────────────────────────
• Structured as an MLflow PyFunc-style model so it can be logged + served on Databricks
• Uses Groq (llama-3.1-8b-instant) for ultra-low-latency responses
• Streaming path: direct RAG (retrieve → stuff → stream) — simple, reliable
• Sync path: tool-calling agent (search_knowledge_base tool) for richer Databricks demos
"""
| from __future__ import annotations | |
| import os | |
| import json | |
| from typing import AsyncIterator, List, Dict, Optional | |
| from rag import DataEngineeringRAG | |
# ──────────────────────────────────────────────────────────────────────────────
# System prompt
# ──────────────────────────────────────────────────────────────────────────────
# Shared by both agent paths: the streaming path appends retrieved context to
# this prompt, while the tool-calling path receives context as results of the
# search_knowledge_base tool — guideline 1 therefore names both sources.
SYSTEM_PROMPT = """You are an elite Data Engineering Knowledge Assistant, \
specialising in production-grade data pipelines, architecture patterns, and Databricks.
Your knowledge comes from "Data Engineering Design Patterns" — a comprehensive guide \
to solving real data engineering problems.
Guidelines:
1. Ground every answer in the retrieved context (appended below, or returned by the \
search_knowledge_base tool).
2. Give concrete, code-inclusive answers when relevant (PySpark / Python / SQL).
3. Reference specific patterns by name (Lambda, Kappa, Medallion, Lakehouse, CDC, etc.).
4. Be direct and technical — the user is a practising data engineer.
5. If the retrieved context doesn't cover the question, say so — never fabricate.
Format:
- Direct answer first
- Code blocks with ```python or ```sql
- Pattern names in **bold**
- End with a "💡 Pro tip:" line when you have a non-obvious insight
"""
# ──────────────────────────────────────────────────────────────────────────────
# Tool schemas (used by sync invoke() for the Databricks demo path)
# ──────────────────────────────────────────────────────────────────────────────
# OpenAI-style function-tool schema passed to Groq chat completions.
# Parameter descriptions are included because the model chooses tool arguments
# from them; the `k` bounds mirror the clamping the agent applies on receipt.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "search_knowledge_base",
            "description": "Retrieve relevant chunks from the Data Engineering Design Patterns book.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Natural-language search query describing the topic to look up.",
                    },
                    "k": {
                        "type": "integer",
                        "description": "Number of chunks to retrieve.",
                        "default": 5,
                        "minimum": 1,
                        "maximum": 20,
                    },
                },
                "required": ["query"],
            },
        },
    }
]
# ──────────────────────────────────────────────────────────────────────────────
# Agent
# ──────────────────────────────────────────────────────────────────────────────
class DataEngineeringAgent:
    """Chat agent that answers data-engineering questions from a RAG index.

    Two call paths share the prompt-building helpers below:

    * ``astream`` — async; retrieves context up front, appends it to the
      system prompt and streams the Groq completion chunk by chunk.
    * ``invoke`` / ``predict`` — sync; sends no context up front and lets the
      model call the ``search_knowledge_base`` tool (from ``TOOLS``) instead.

    Groq clients are created lazily on first use, so constructing the agent
    does not import ``groq``.
    """

    # Groq chat model used by both paths unless overridden in __init__.
    DEFAULT_MODEL = "llama-3.1-8b-instant"
    model = DEFAULT_MODEL
    # Most recent history messages forwarded to the model (3 exchanges).
    MAX_HISTORY_MESSAGES = 6
    # Upper bound on tool-resolution rounds in invoke(); the last round is
    # requested with tool_choice="none" so the model must answer in text.
    MAX_TOOL_ROUNDS = 3
    # Default and clamp range for the `k` argument of search_knowledge_base.
    DEFAULT_K = 5
    MIN_K = 1
    MAX_K = 20
    # Only these roles are accepted from caller-supplied history; anything else
    # (e.g. an injected "system" or orphaned "tool" turn) is dropped.
    _HISTORY_ROLES = frozenset({"user", "assistant"})

    def __init__(
        self,
        rag: DataEngineeringRAG,
        groq_api_key: str,
        model: str = DEFAULT_MODEL,
    ):
        """Store the retriever and credentials; Groq clients are built on demand.

        Args:
            rag: Retriever exposing ``search(query, k=...)`` that returns a list
                of chunk dicts with ``page``, ``score`` and ``content`` keys.
            groq_api_key: API key handed to the Groq clients.
            model: Groq chat model name used for every completion.
        """
        self.rag = rag
        self.groq_api_key = groq_api_key
        self.model = model
        self._sync_client = None
        self._async_client = None

    # ── Groq clients (lazy init) ──────────────────────────────────────────────
    def _get_sync_client(self):
        """Return the cached synchronous ``Groq`` client, creating it on first use."""
        if self._sync_client is None:
            from groq import Groq
            self._sync_client = Groq(api_key=self.groq_api_key)
        return self._sync_client

    def _get_async_client(self):
        """Return the cached ``AsyncGroq`` client, creating it on first use."""
        if self._async_client is None:
            from groq import AsyncGroq
            self._async_client = AsyncGroq(api_key=self.groq_api_key)
        return self._async_client

    # ── Context builder ───────────────────────────────────────────────────────
    # PDF extractors often emit these invisible / structural Unicode chars.
    # In containers with an ASCII-only default locale (common on minimal Docker
    # images), the HTTP client can fail with `UnicodeEncodeError: 'ascii' codec`
    # when serialising them. Strip them at the source.
    # NOTE: only the characters listed here are mapped; every other non-ASCII
    # character passes through unchanged.
    _UNICODE_SCRUB = str.maketrans({
        "\u2028": "\n",    # LINE SEPARATOR
        "\u2029": "\n\n",  # PARAGRAPH SEPARATOR
        "\u200b": "",      # ZERO WIDTH SPACE
        "\u200c": "",      # ZERO WIDTH NON-JOINER
        "\u200d": "",      # ZERO WIDTH JOINER
        "\ufeff": "",      # BYTE ORDER MARK
        "\x00": "",        # NULL
        "\xa0": " ",       # NON-BREAKING SPACE
    })

    @classmethod
    def _sanitize(cls, text: Optional[str]) -> str:
        """Return *text* with ``_UNICODE_SCRUB`` applied; ``None``/empty gives ``""``."""
        return (text or "").translate(cls._UNICODE_SCRUB)

    def _build_context(self, query: str, k: int = DEFAULT_K) -> str:
        """Retrieve the top-*k* chunks for *query* and format them as prompt context.

        Each chunk is rendered as a ``[Source i · Page p · Relevance s]`` header
        followed by its sanitised text; chunks are separated by ``---`` rules.

        Returns:
            The formatted context, or a fixed placeholder sentence when the
            retriever returns no chunks.
        """
        chunks = self.rag.search(query, k=k)
        if not chunks:
            return "(No relevant context found in the knowledge base.)"
        return "\n\n---\n\n".join(
            f"[Source {i} · Page {c['page']} · Relevance {c['score']:.2f}]\n"
            f"{self._sanitize(c['content'])}"
            for i, c in enumerate(chunks, 1)
        )

    # ── History / input normalisation ─────────────────────────────────────────
    @staticmethod
    def _normalize_history(history) -> List[Dict]:
        """Coerce the shapes a ``history`` value can arrive in to a plain list.

        Accepts ``None``, a list/tuple, a JSON-encoded string (e.g. from an
        MLflow DataFrame column) or an array-like exposing ``tolist()``
        (numpy / pandas). Blank strings, NaN placeholders and any other
        scalar become ``[]``.

        Raises:
            json.JSONDecodeError: if a non-blank string is not valid JSON.
        """
        if history is None:
            return []
        if isinstance(history, str):
            if not history.strip():
                return []
            history = json.loads(history)
        if isinstance(history, (list, tuple)):
            return list(history)
        to_list = getattr(history, "tolist", None)
        if callable(to_list):
            converted = to_list()
            if isinstance(converted, list):
                return converted
        return []

    @classmethod
    def _recent_turns(cls, history) -> List[Dict]:
        """Return the last ``MAX_HISTORY_MESSAGES`` well-formed user/assistant turns.

        Each kept turn is reduced to ``role`` and ``content``; entries that are
        not mappings, carry another role, or have non-string content are skipped
        before the window is applied.
        """
        from collections.abc import Mapping

        limit = cls.MAX_HISTORY_MESSAGES
        if limit <= 0:
            return []
        valid = [
            {"role": turn["role"], "content": turn["content"]}
            for turn in cls._normalize_history(history)
            if isinstance(turn, Mapping)
            and isinstance(turn.get("role"), str)
            and turn.get("role") in cls._HISTORY_ROLES
            and isinstance(turn.get("content"), str)
        ]
        return valid[-limit:]

    @classmethod
    def _parse_model_input(cls, model_input) -> tuple:
        """Extract ``(message, history)`` from an MLflow ``predict`` payload.

        *model_input* is a pandas DataFrame (only its first row is read) or a
        dict-like with ``message`` / ``history`` keys. A missing, ``None`` or
        NaN message becomes ``""``; history is normalised with
        ``_normalize_history`` on both paths.

        Raises:
            ValueError: if *model_input* is an empty DataFrame.
        """
        import math

        import pandas as pd

        if isinstance(model_input, pd.DataFrame):
            if model_input.empty:
                raise ValueError("model_input DataFrame has no rows")
            record = model_input.iloc[0]
        else:
            record = model_input
        message = record.get("message", "")
        if message is None or (isinstance(message, float) and math.isnan(message)):
            message = ""
        return message, cls._normalize_history(record.get("history", []))

    def _build_messages(
        self, user_message: str, history: List[Dict], inject_context: bool = True
    ) -> List[Dict]:
        """Build the chat-completions ``messages`` array.

        Layout: system prompt (plus retrieved context when *inject_context*),
        then the recent history turns, then the new user message.
        """
        system = SYSTEM_PROMPT
        if inject_context:
            context = self._build_context(user_message, k=self.DEFAULT_K)
            system += (
                "\n\n─── RETRIEVED CONTEXT ───\n"
                f"{context}\n"
                "─────────────────────────"
            )
        messages = [{"role": "system", "content": system}]
        messages.extend(self._recent_turns(history))
        messages.append({"role": "user", "content": user_message})
        return messages

    # ── Async streaming (used by the FastAPI /api/chat endpoint) ──────────────
    async def astream(
        self, message: str, history: Optional[List[Dict]] = None
    ) -> AsyncIterator[str]:
        """Stream a RAG answer, yielding text chunks as the model generates them.

        Context is retrieved once and stuffed into the system prompt; no tools
        are offered. Any exception — including retrieval or client-construction
        failures — is reported to the consumer as a final error chunk instead
        of aborting the stream. First-token latency on the Groq free tier is
        roughly 150-300 ms (author's estimate).
        """
        try:
            client = self._get_async_client()
            messages = self._build_messages(message, history, inject_context=True)
            stream = await client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.3,
                max_tokens=2048,
                stream=True,
            )
            async for chunk in stream:
                # Skip chunks without choices instead of raising IndexError.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta
        except Exception as exc:
            # Expose the real error to the client so debugging is easy
            yield f"\n\n⚠️ **Agent error:** `{type(exc).__name__}: {exc}`\n\n"
            yield "Common causes: missing or invalid GROQ_API_KEY, Groq rate limit hit, network issue."

    # ── Sync invoke with tool use (Databricks / MLflow path) ──────────────────
    @classmethod
    def _coerce_k(cls, raw) -> int:
        """Turn a model-supplied ``k`` into an int clamped to ``[MIN_K, MAX_K]``.

        Missing or non-numeric values fall back to ``DEFAULT_K``.
        """
        try:
            k = int(raw)
        except (TypeError, ValueError, OverflowError):
            return cls.DEFAULT_K
        return max(cls.MIN_K, min(cls.MAX_K, k))

    def _run_tool(self, name: str, raw_arguments: Optional[str]) -> str:
        """Execute one tool call and return its result as tool-message content.

        Malformed calls (unknown tool, invalid JSON, missing/empty ``query``)
        produce an explanatory string instead of raising, so the model sees
        the problem in the conversation rather than the whole request failing.
        """
        if name != "search_knowledge_base":
            return f"Unknown tool: {name}"
        try:
            args = json.loads(raw_arguments or "{}")
        except (TypeError, json.JSONDecodeError) as exc:
            return f"Invalid arguments for {name}: {exc}"
        query = args.get("query") if isinstance(args, dict) else None
        if not isinstance(query, str) or not query.strip():
            return f"Invalid arguments for {name}: a non-empty 'query' string is required."
        return self._build_context(query, self._coerce_k(args.get("k")))

    def _complete(self, client, messages: List[Dict], tool_choice: str):
        """Run one non-streaming, tool-enabled completion and return its message."""
        response = client.chat.completions.create(
            model=self.model,
            messages=messages,
            tools=TOOLS,
            tool_choice=tool_choice,
            temperature=0.2,
            max_tokens=2048,
        )
        return response.choices[0].message

    def invoke(self, message: str, history: Optional[List[Dict]] = None) -> str:
        """Answer *message* synchronously, letting the model call the search tool.

        Up to ``MAX_TOOL_ROUNDS`` rounds of tool calls are resolved. The final
        round is requested with ``tool_choice="none"`` so a model that keeps
        asking for tools still ends with a text answer.

        Returns:
            The assistant's text, or ``"(no content generated)"`` when empty.
        """
        client = self._get_sync_client()
        messages = self._build_messages(message, history, inject_context=False)
        # Let the model decide if it wants to search
        msg = self._complete(client, messages, tool_choice="auto")
        for round_no in range(1, self.MAX_TOOL_ROUNDS + 1):
            if not msg.tool_calls:
                break
            # Re-send the assistant turn (SDK message object, as before) so the
            # tool results appended below can reference its tool_call ids.
            messages.append(msg)
            for tc in msg.tool_calls:
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": self._run_tool(tc.function.name, tc.function.arguments),
                    }
                )
            is_last_round = round_no == self.MAX_TOOL_ROUNDS
            msg = self._complete(
                client, messages, tool_choice="none" if is_last_round else "auto"
            )
        return msg.content or "(no content generated)"

    # ── MLflow PyFunc interface ───────────────────────────────────────────────
    def predict(self, context, model_input) -> str:
        """MLflow-style entry point: parse *model_input* and delegate to ``invoke``.

        *context* is accepted for MLflow signature compatibility and unused.
        NOTE: only the first row of a DataFrame input is answered; any further
        rows are ignored.
        """
        message, history = self._parse_model_input(model_input)
        return self.invoke(message=message, history=history)
# ──────────────────────────────────────────────────────────────────────────────
# MLflow wrapper (for Databricks Model Serving registration)
# ──────────────────────────────────────────────────────────────────────────────
class DEAgentPyFunc:
    """MLflow model wrapper that builds the RAG index and agent in ``load_context``.

    NOTE(review): this class implements MLflow's ``load_context`` / ``predict``
    hooks but does not subclass ``mlflow.pyfunc.PythonModel``; confirm the MLflow
    version used for ``log_model`` accepts it as-is.
    """

    # Artifact key looked up in ``context.artifacts`` and the fallback PDF path.
    PDF_ARTIFACT_KEY = "pdf_path"
    DEFAULT_PDF_PATH = "knowledge/data_engineering_patterns.pdf"

    def load_context(self, context):
        """Build and initialise the RAG index, then wrap it in the agent.

        The PDF path comes from the ``pdf_path`` artifact when present, else
        ``DEFAULT_PDF_PATH``; the Groq key is read from ``GROQ_API_KEY``.
        """
        artifacts = getattr(context, "artifacts", None) or {}
        pdf_path = artifacts.get(self.PDF_ARTIFACT_KEY, self.DEFAULT_PDF_PATH)
        groq_key = os.environ.get("GROQ_API_KEY", "")
        if not groq_key:
            import warnings

            # Loading still proceeds; Groq requests are expected to fail later.
            warnings.warn(
                "GROQ_API_KEY is not set; Groq requests will fail until it is provided.",
                RuntimeWarning,
                stacklevel=2,
            )
        self.rag = DataEngineeringRAG(pdf_path=pdf_path, groq_api_key=groq_key)
        self.rag.initialize()
        self.agent = DataEngineeringAgent(rag=self.rag, groq_api_key=groq_key)

    def predict(self, context, model_input, params=None):
        """Delegate to the agent's ``predict``.

        ``params`` is accepted (and currently ignored) so MLflow versions that
        pass inference params can call this method.
        """
        return self.agent.predict(context, model_input)